e61aae5b3aff7dde75b6f964300fdc0f7e57c12d
[ghc-hetmet.git] / compiler / vectorise / VectType.hs
1 {-# OPTIONS -fno-warn-missing-signatures #-}
2
3 module VectType ( 
4         vectTyCon,
5         vectAndLiftType,
6         vectType,
7         vectTypeEnv,
8         buildPADict,
9         fromVect
10 )
11 where
12 import VectUtils
13 import Vectorise.Env
14 import Vectorise.Convert
15 import Vectorise.Vect
16 import Vectorise.Monad
17 import Vectorise.Builtins
18 import Vectorise.Type.Type
19 import Vectorise.Type.TyConDecl
20 import Vectorise.Type.Classify
21 import Vectorise.Type.Repr
22 import Vectorise.Type.PADict
23 import Vectorise.Utils.Closure
24 import Vectorise.Utils.Hoisting
25
26 import HscTypes          ( TypeEnv, extendTypeEnvList, typeEnvTyCons )
27 import BasicTypes
28 import CoreSyn
29 import CoreUtils
30 import CoreUnfold
31 import BuildTyCl
32 import DataCon
33 import TyCon
34 import Type
35 import Coercion
36 import FamInstEnv        ( FamInst, mkLocalFamInst )
37 import OccName
38 import Id
39 import MkId
40 import Var
41 import Name              ( Name, getOccName )
42 import NameEnv
43
44 import Unique
45 import UniqFM
46 import Util
47
48 import Outputable
49 import FastString
50
51 import MonadUtils
52 import Control.Monad
53 import Data.List
54
55 debug           = False
56 dtrace s x      = if debug then pprTrace "VectType" s x else x
57
58
59 -- | Vectorise a type environment.
60 --   The type environment contains all the type things defined in a module.
61 vectTypeEnv 
62         :: TypeEnv
63         -> VM ( TypeEnv                 -- Vectorised type environment.
64               , [FamInst]               -- New type family instances.
65               , [(Var, CoreExpr)])      -- New top level bindings.
66         
67 vectTypeEnv env
68  = dtrace (ppr env)
69  $ do
70       cs <- readGEnv $ mk_map . global_tycons
71
72       -- Split the list of TyCons into the ones we have to vectorise vs the
73       -- ones we can pass through unchanged. We also pass through algebraic 
74       -- types that use non Haskell98 features, as we don't handle those.
75       let (conv_tcs, keep_tcs) = classifyTyCons cs groups
76           keep_dcs             = concatMap tyConDataCons keep_tcs
77
78       zipWithM_ defTyCon   keep_tcs keep_tcs
79       zipWithM_ defDataCon keep_dcs keep_dcs
80
81       new_tcs <- vectTyConDecls conv_tcs
82
83       let orig_tcs = keep_tcs ++ conv_tcs
84
85       -- We don't need to make new representation types for dictionary
86       -- constructors. The constructors are always fully applied, and we don't 
87       -- need to lift them to arrays as a dictionary of a particular type
88       -- always has the same value.
89       let vect_tcs = filter (not . isClassTyCon) 
90                    $ keep_tcs ++ new_tcs
91
92       (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
93         do
94           defTyConPAs (zipLazy vect_tcs dfuns')
95           reprs     <- mapM tyConRepr vect_tcs
96           repr_tcs  <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
97           pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
98
99           dfuns     <- sequence 
100                     $  zipWith5 buildTyConBindings
101                                orig_tcs
102                                vect_tcs
103                                repr_tcs
104                                pdata_tcs
105                                reprs
106
107           binds     <- takeHoisted
108           return (dfuns, binds, repr_tcs ++ pdata_tcs)
109
110       let all_new_tcs = new_tcs ++ inst_tcs
111
112       let new_env = extendTypeEnvList env
113                        (map ATyCon all_new_tcs
114                         ++ [ADataCon dc | tc <- all_new_tcs
115                                         , dc <- tyConDataCons tc])
116
117       return (new_env, map mkLocalFamInst inst_tcs, binds)
118   where
119     tycons = typeEnvTyCons env
120     groups = tyConGroups tycons
121
122     mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
123
124
125 mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
126 mk_fam_inst fam_tc arg_tc
127   = (fam_tc, [mkTyConApp arg_tc . mkTyVarTys $ tyConTyVars arg_tc])
128
129
130 buildPReprTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
131 buildPReprTyCon orig_tc vect_tc repr
132   = do
133       name     <- cloneName mkPReprTyConOcc (tyConName orig_tc)
134       -- rhs_ty   <- buildPReprType vect_tc
135       rhs_ty   <- sumReprType repr
136       prepr_tc <- builtin preprTyCon
137       liftDs $ buildSynTyCon name
138                              tyvars
139                              (SynonymTyCon rhs_ty)
140                              (typeKind rhs_ty)
141                              (Just $ mk_fam_inst prepr_tc vect_tc)
142   where
143     tyvars = tyConTyVars vect_tc
144
145
146
147 buildPRDict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
148 buildPRDict vect_tc prepr_tc _ r
149   = do
150       dict <- sum_dict r
151       pr_co <- mkBuiltinCo prTyCon
152       let co = mkAppCoercion pr_co
153              . mkSymCoercion
154              $ mkTyConApp arg_co ty_args
155       return (mkCoerce co dict)
156   where
157     ty_args = mkTyVarTys (tyConTyVars vect_tc)
158     Just arg_co = tyConFamilyCoercion_maybe prepr_tc
159
160     sum_dict EmptySum = prDFunOfTyCon =<< builtin voidTyCon
161     sum_dict (UnarySum r) = con_dict r
162     sum_dict (Sum { repr_sum_tc  = sum_tc
163                   , repr_con_tys = tys
164                   , repr_cons    = cons
165                   })
166       = do
167           dicts <- mapM con_dict cons
168           dfun  <- prDFunOfTyCon sum_tc
169           return $ dfun `mkTyApps` tys `mkApps` dicts
170
171     con_dict (ConRepr _ r) = prod_dict r
172
173     prod_dict EmptyProd = prDFunOfTyCon =<< builtin voidTyCon
174     prod_dict (UnaryProd r) = comp_dict r
175     prod_dict (Prod { repr_tup_tc   = tup_tc
176                     , repr_comp_tys = tys
177                     , repr_comps    = comps })
178       = do
179           dicts <- mapM comp_dict comps
180           dfun <- prDFunOfTyCon tup_tc
181           return $ dfun `mkTyApps` tys `mkApps` dicts
182
183     comp_dict (Keep _ pr) = return pr
184     comp_dict (Wrap ty)   = wrapPR ty
185
186
187 buildPDataTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
188 buildPDataTyCon orig_tc vect_tc repr = fixV $ \repr_tc ->
189   do
190     name' <- cloneName mkPDataTyConOcc orig_name
191     rhs   <- buildPDataTyConRhs orig_name vect_tc repr_tc repr
192     pdata <- builtin pdataTyCon
193
194     liftDs $ buildAlgTyCon name'
195                            tyvars
196                            []          -- no stupid theta
197                            rhs
198                            rec_flag    -- FIXME: is this ok?
199                            False       -- FIXME: no generics
200                            False       -- not GADT syntax
201                            (Just $ mk_fam_inst pdata vect_tc)
202   where
203     orig_name = tyConName orig_tc
204     tyvars = tyConTyVars vect_tc
205     rec_flag = boolToRecFlag (isRecursiveTyCon vect_tc)
206
207
208 buildPDataTyConRhs :: Name -> TyCon -> TyCon -> SumRepr -> VM AlgTyConRhs
209 buildPDataTyConRhs orig_name vect_tc repr_tc repr
210   = do
211       data_con <- buildPDataDataCon orig_name vect_tc repr_tc repr
212       return $ DataTyCon { data_cons = [data_con], is_enum = False }
213
214 buildPDataDataCon :: Name -> TyCon -> TyCon -> SumRepr -> VM DataCon
215 buildPDataDataCon orig_name vect_tc repr_tc repr
216   = do
217       dc_name  <- cloneName mkPDataDataConOcc orig_name
218       comp_tys <- sum_tys repr
219
220       liftDs $ buildDataCon dc_name
221                             False                  -- not infix
222                             (map (const HsNoBang) comp_tys)
223                             []                     -- no field labels
224                             tvs
225                             []                     -- no existentials
226                             []                     -- no eq spec
227                             []                     -- no context
228                             comp_tys
229                             (mkFamilyTyConApp repr_tc (mkTyVarTys tvs))
230                             repr_tc
231   where
232     tvs   = tyConTyVars vect_tc
233
234     sum_tys EmptySum = return []
235     sum_tys (UnarySum r) = con_tys r
236     sum_tys (Sum { repr_sel_ty = sel_ty
237                  , repr_cons   = cons })
238       = liftM (sel_ty :) (concatMapM con_tys cons)
239
240     con_tys (ConRepr _ r) = prod_tys r
241
242     prod_tys EmptyProd = return []
243     prod_tys (UnaryProd r) = liftM singleton (comp_ty r)
244     prod_tys (Prod { repr_comps = comps }) = mapM comp_ty comps
245
246     comp_ty r = mkPDataType (compOrigType r)
247
248
249 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr 
250                    -> VM Var
251 buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr
252   = do
253       vectDataConWorkers orig_tc vect_tc pdata_tc
254       buildPADict vect_tc prepr_tc pdata_tc repr
255
256 vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
257 vectDataConWorkers orig_tc vect_tc arr_tc
258   = do
259       bs <- sequence
260           . zipWith3 def_worker  (tyConDataCons orig_tc) rep_tys
261           $ zipWith4 mk_data_con (tyConDataCons vect_tc)
262                                  rep_tys
263                                  (inits rep_tys)
264                                  (tail $ tails rep_tys)
265       mapM_ (uncurry hoistBinding) bs
266   where
267     tyvars   = tyConTyVars vect_tc
268     var_tys  = mkTyVarTys tyvars
269     ty_args  = map Type var_tys
270     res_ty   = mkTyConApp vect_tc var_tys
271
272     cons     = tyConDataCons vect_tc
273     arity    = length cons
274     [arr_dc] = tyConDataCons arr_tc
275
276     rep_tys  = map dataConRepArgTys $ tyConDataCons vect_tc
277
278
279     mk_data_con con tys pre post
280       = liftM2 (,) (vect_data_con con)
281                    (lift_data_con tys pre post (mkDataConTag con))
282
283     sel_replicate len tag
284       | arity > 1 = do
285                       rep <- builtin (selReplicate arity)
286                       return [rep `mkApps` [len, tag]]
287
288       | otherwise = return []
289
290     vect_data_con con = return $ mkConApp con ty_args
291     lift_data_con tys pre_tys post_tys tag
292       = do
293           len  <- builtin liftingContext
294           args <- mapM (newLocalVar (fsLit "xs"))
295                   =<< mapM mkPDataType tys
296
297           sel  <- sel_replicate (Var len) tag
298
299           pre   <- mapM emptyPD (concat pre_tys)
300           post  <- mapM emptyPD (concat post_tys)
301
302           return . mkLams (len : args)
303                  . wrapFamInstBody arr_tc var_tys
304                  . mkConApp arr_dc
305                  $ ty_args ++ sel ++ pre ++ map Var args ++ post
306
307     def_worker data_con arg_tys mk_body
308       = do
309           arity <- polyArity tyvars
310           body <- closedV
311                 . inBind orig_worker
312                 . polyAbstract tyvars $ \args ->
313                   liftM (mkLams (tyvars ++ args) . vectorised)
314                 $ buildClosures tyvars [] arg_tys res_ty mk_body
315
316           raw_worker <- cloneId mkVectOcc orig_worker (exprType body)
317           let vect_worker = raw_worker `setIdUnfolding`
318                               mkInlineRule body (Just arity)
319           defGlobalVar orig_worker vect_worker
320           return (vect_worker, body)
321       where
322         orig_worker = dataConWorkId data_con
323
324 buildPADict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
325 buildPADict vect_tc prepr_tc arr_tc repr
326   = polyAbstract tvs $ \args ->
327     do
328       method_ids <- mapM (method args) paMethods
329
330       pa_tc  <- builtin paTyCon
331       pa_dc  <- builtin paDataCon
332       let dict = mkLams (tvs ++ args)
333                $ mkConApp pa_dc
334                $ Type inst_ty : map (method_call args) method_ids
335
336           dfun_ty = mkForAllTys tvs
337                   $ mkFunTys (map varType args) (mkTyConApp pa_tc [inst_ty])
338
339       raw_dfun <- newExportedVar dfun_name dfun_ty
340       let dfun = raw_dfun `setIdUnfolding` mkDFunUnfolding dfun_ty (map Var method_ids)
341                           `setInlinePragma` dfunInlinePragma
342
343       hoistBinding dfun dict
344       return dfun
345   where
346     tvs = tyConTyVars vect_tc
347     arg_tys = mkTyVarTys tvs
348     inst_ty = mkTyConApp vect_tc arg_tys
349
350     dfun_name = mkPADFunOcc (getOccName vect_tc)
351
352     method args (name, build)
353       = localV
354       $ do
355           expr <- build vect_tc prepr_tc arr_tc repr
356           let body = mkLams (tvs ++ args) expr
357           raw_var <- newExportedVar (method_name name) (exprType body)
358           let var = raw_var
359                       `setIdUnfolding` mkInlineRule body (Just (length args))
360                       `setInlinePragma` alwaysInlinePragma
361           hoistBinding var body
362           return var
363
364     method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args)
365
366     method_name name = mkVarOcc $ occNameString dfun_name ++ ('$' : name)
367
368
369 paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
370 paMethods = [("dictPRepr",    buildPRDict),
371              ("toPRepr",      buildToPRepr),
372              ("fromPRepr",    buildFromPRepr),
373              ("toArrPRepr",   buildToArrPRepr),
374              ("fromArrPRepr", buildFromArrPRepr)]
375
376