Finish breaking up vectoriser utils
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Type / Env.hs
1 {-# OPTIONS -fno-warn-missing-signatures #-}
2
3 module Vectorise.Type.Env ( 
4         vectTypeEnv,
5 )
6 where
7 import Vectorise.Env
8 import Vectorise.Vect
9 import Vectorise.Monad
10 import Vectorise.Builtins
11 import Vectorise.Type.TyConDecl
12 import Vectorise.Type.Classify
13 import Vectorise.Type.PADict
14 import Vectorise.Type.PData
15 import Vectorise.Type.PRepr
16 import Vectorise.Type.Repr
17 import Vectorise.Utils
18
19 import HscTypes
20 import CoreSyn
21 import CoreUtils
22 import CoreUnfold
23 import DataCon
24 import TyCon
25 import Type
26 import FamInstEnv
27 import OccName
28 import Id
29 import MkId
30 import Var
31 import NameEnv
32
33 import Unique
34 import UniqFM
35 import Util
36 import Outputable
37 import FastString
38 import MonadUtils
39 import Control.Monad
40 import Data.List
41
42 debug           = False
43 dtrace s x      = if debug then pprTrace "VectType" s x else x
44
45
46 -- | Vectorise a type environment.
47 --   The type environment contains all the type things defined in a module.
48 vectTypeEnv 
49         :: TypeEnv
50         -> VM ( TypeEnv                 -- Vectorised type environment.
51               , [FamInst]               -- New type family instances.
52               , [(Var, CoreExpr)])      -- New top level bindings.
53         
54 vectTypeEnv env
55  = dtrace (ppr env)
56  $ do
57       cs <- readGEnv $ mk_map . global_tycons
58
59       -- Split the list of TyCons into the ones we have to vectorise vs the
60       -- ones we can pass through unchanged. We also pass through algebraic 
61       -- types that use non Haskell98 features, as we don't handle those.
62       let (conv_tcs, keep_tcs) = classifyTyCons cs groups
63           keep_dcs             = concatMap tyConDataCons keep_tcs
64
65       zipWithM_ defTyCon   keep_tcs keep_tcs
66       zipWithM_ defDataCon keep_dcs keep_dcs
67
68       new_tcs <- vectTyConDecls conv_tcs
69
70       let orig_tcs = keep_tcs ++ conv_tcs
71
72       -- We don't need to make new representation types for dictionary
73       -- constructors. The constructors are always fully applied, and we don't 
74       -- need to lift them to arrays as a dictionary of a particular type
75       -- always has the same value.
76       let vect_tcs = filter (not . isClassTyCon) 
77                    $ keep_tcs ++ new_tcs
78
79       (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
80         do
81           defTyConPAs (zipLazy vect_tcs dfuns')
82           reprs     <- mapM tyConRepr vect_tcs
83           repr_tcs  <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
84           pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
85
86           dfuns     <- sequence 
87                     $  zipWith5 buildTyConBindings
88                                orig_tcs
89                                vect_tcs
90                                repr_tcs
91                                pdata_tcs
92                                reprs
93
94           binds     <- takeHoisted
95           return (dfuns, binds, repr_tcs ++ pdata_tcs)
96
97       let all_new_tcs = new_tcs ++ inst_tcs
98
99       let new_env = extendTypeEnvList env
100                        (map ATyCon all_new_tcs
101                         ++ [ADataCon dc | tc <- all_new_tcs
102                                         , dc <- tyConDataCons tc])
103
104       return (new_env, map mkLocalFamInst inst_tcs, binds)
105   where
106     tycons = typeEnvTyCons env
107     groups = tyConGroups tycons
108
109     mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
110
111
112
113 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
114 buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr
115  = do vectDataConWorkers orig_tc vect_tc pdata_tc
116       buildPADict vect_tc prepr_tc pdata_tc repr
117
118
119 vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
120 vectDataConWorkers orig_tc vect_tc arr_tc
121  = do bs <- sequence
122           . zipWith3 def_worker  (tyConDataCons orig_tc) rep_tys
123           $ zipWith4 mk_data_con (tyConDataCons vect_tc)
124                                  rep_tys
125                                  (inits rep_tys)
126                                  (tail $ tails rep_tys)
127       mapM_ (uncurry hoistBinding) bs
128  where
129     tyvars   = tyConTyVars vect_tc
130     var_tys  = mkTyVarTys tyvars
131     ty_args  = map Type var_tys
132     res_ty   = mkTyConApp vect_tc var_tys
133
134     cons     = tyConDataCons vect_tc
135     arity    = length cons
136     [arr_dc] = tyConDataCons arr_tc
137
138     rep_tys  = map dataConRepArgTys $ tyConDataCons vect_tc
139
140
141     mk_data_con con tys pre post
142       = liftM2 (,) (vect_data_con con)
143                    (lift_data_con tys pre post (mkDataConTag con))
144
145     sel_replicate len tag
146       | arity > 1 = do
147                       rep <- builtin (selReplicate arity)
148                       return [rep `mkApps` [len, tag]]
149
150       | otherwise = return []
151
152     vect_data_con con = return $ mkConApp con ty_args
153     lift_data_con tys pre_tys post_tys tag
154       = do
155           len  <- builtin liftingContext
156           args <- mapM (newLocalVar (fsLit "xs"))
157                   =<< mapM mkPDataType tys
158
159           sel  <- sel_replicate (Var len) tag
160
161           pre   <- mapM emptyPD (concat pre_tys)
162           post  <- mapM emptyPD (concat post_tys)
163
164           return . mkLams (len : args)
165                  . wrapFamInstBody arr_tc var_tys
166                  . mkConApp arr_dc
167                  $ ty_args ++ sel ++ pre ++ map Var args ++ post
168
169     def_worker data_con arg_tys mk_body
170       = do
171           arity <- polyArity tyvars
172           body <- closedV
173                 . inBind orig_worker
174                 . polyAbstract tyvars $ \args ->
175                   liftM (mkLams (tyvars ++ args) . vectorised)
176                 $ buildClosures tyvars [] arg_tys res_ty mk_body
177
178           raw_worker <- cloneId mkVectOcc orig_worker (exprType body)
179           let vect_worker = raw_worker `setIdUnfolding`
180                               mkInlineRule body (Just arity)
181           defGlobalVar orig_worker vect_worker
182           return (vect_worker, body)
183       where
184         orig_worker = dataConWorkId data_con
185