99c17464dd81075851cc6e1145c3d4c25e03ae07
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Type / Env.hs
1 {-# OPTIONS_GHC -fno-warn-missing-signatures #-}
2 {-# OPTIONS_GHC -XNoMonoLocalBinds #-}
3 -- Roman likes local bindings
4 -- If this module lives on I'd like to get rid of this flag in due course
5
6 module Vectorise.Type.Env ( 
7         vectTypeEnv,
8 )
9 where
10 import Vectorise.Env
11 import Vectorise.Vect
12 import Vectorise.Monad
13 import Vectorise.Builtins
14 import Vectorise.Type.TyConDecl
15 import Vectorise.Type.Classify
16 import Vectorise.Type.PADict
17 import Vectorise.Type.PData
18 import Vectorise.Type.PRepr
19 import Vectorise.Type.Repr
20 import Vectorise.Utils
21
22 import HscTypes
23 import CoreSyn
24 import CoreUtils
25 import CoreUnfold
26 import DataCon
27 import TyCon
28 import Type
29 import FamInstEnv
30 import OccName
31 import Id
32 import MkId
33 import Var
34 import NameEnv
35
36 import Unique
37 import UniqFM
38 import Util
39 import Outputable
40 import FastString
41 import MonadUtils
42 import Control.Monad
43 import Data.List
44
45 debug           = False
46 dtrace s x      = if debug then pprTrace "VectType" s x else x
47
48 -- | Vectorise a type environment.
49 --   The type environment contains all the type things defined in a module.
50 vectTypeEnv 
51         :: TypeEnv
52         -> VM ( TypeEnv                 -- Vectorised type environment.
53               , [FamInst]               -- New type family instances.
54               , [(Var, CoreExpr)])      -- New top level bindings.
55         
56 vectTypeEnv env
57  = dtrace (ppr env)
58  $ do
59       cs <- readGEnv $ mk_map . global_tycons
60
61       -- Split the list of TyCons into the ones we have to vectorise vs the
62       -- ones we can pass through unchanged. We also pass through algebraic 
63       -- types that use non Haskell98 features, as we don't handle those.
64       let tycons               = typeEnvTyCons env
65           groups               = tyConGroups tycons
66
67       let (conv_tcs, keep_tcs) = classifyTyCons cs groups
68           orig_tcs             = keep_tcs ++ conv_tcs
69           keep_dcs             = concatMap tyConDataCons keep_tcs
70
71       -- Just use the unvectorised versions of these constructors in vectorised code.
72       zipWithM_ defTyCon   keep_tcs keep_tcs
73       zipWithM_ defDataCon keep_dcs keep_dcs
74
75       -- Vectorise all the declarations.
76       new_tcs      <- vectTyConDecls conv_tcs
77
78       -- We don't need to make new representation types for dictionary
79       -- constructors. The constructors are always fully applied, and we don't 
80       -- need to lift them to arrays as a dictionary of a particular type
81       -- always has the same value.
82       let vect_tcs  = filter (not . isClassTyCon) 
83                     $ keep_tcs ++ new_tcs
84
85       -- Create PRepr and PData instances for the vectorised types.
86       -- We get back the binds for the instance functions, 
87       -- and some new type constructors for the representation types.
88       (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
89         do
90           defTyConPAs (zipLazy vect_tcs dfuns')
91           reprs     <- mapM tyConRepr vect_tcs
92           repr_tcs  <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
93           pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
94
95           dfuns     <- sequence 
96                     $  zipWith5 buildTyConBindings
97                                orig_tcs
98                                vect_tcs
99                                repr_tcs
100                                pdata_tcs
101                                reprs
102
103           binds     <- takeHoisted
104           return (dfuns, binds, repr_tcs ++ pdata_tcs)
105
106       -- The new type constructors are the vectorised versions of the originals, 
107       -- plus the new type constructors that we use for the representations.
108       let all_new_tcs = new_tcs ++ inst_tcs
109
110       let new_env     =  extendTypeEnvList env
111                       $  map ATyCon all_new_tcs
112                       ++ [ADataCon dc | tc <- all_new_tcs
113                                       , dc <- tyConDataCons tc]
114
115       return (new_env, map mkLocalFamInst inst_tcs, binds)
116
117    where
118     mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
119
120
121
122 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
123 buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr
124  = do vectDataConWorkers orig_tc vect_tc pdata_tc
125       buildPADict vect_tc prepr_tc pdata_tc repr
126
127
128 vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
129 vectDataConWorkers orig_tc vect_tc arr_tc
130  = do bs <- sequence
131           . zipWith3 def_worker  (tyConDataCons orig_tc) rep_tys
132           $ zipWith4 mk_data_con (tyConDataCons vect_tc)
133                                  rep_tys
134                                  (inits rep_tys)
135                                  (tail $ tails rep_tys)
136       mapM_ (uncurry hoistBinding) bs
137  where
138     tyvars   = tyConTyVars vect_tc
139     var_tys  = mkTyVarTys tyvars
140     ty_args  = map Type var_tys
141     res_ty   = mkTyConApp vect_tc var_tys
142
143     cons     = tyConDataCons vect_tc
144     arity    = length cons
145     [arr_dc] = tyConDataCons arr_tc
146
147     rep_tys  = map dataConRepArgTys $ tyConDataCons vect_tc
148
149
150     mk_data_con con tys pre post
151       = liftM2 (,) (vect_data_con con)
152                    (lift_data_con tys pre post (mkDataConTag con))
153
154     sel_replicate len tag
155       | arity > 1 = do
156                       rep <- builtin (selReplicate arity)
157                       return [rep `mkApps` [len, tag]]
158
159       | otherwise = return []
160
161     vect_data_con con = return $ mkConApp con ty_args
162     lift_data_con tys pre_tys post_tys tag
163       = do
164           len  <- builtin liftingContext
165           args <- mapM (newLocalVar (fsLit "xs"))
166                   =<< mapM mkPDataType tys
167
168           sel  <- sel_replicate (Var len) tag
169
170           pre   <- mapM emptyPD (concat pre_tys)
171           post  <- mapM emptyPD (concat post_tys)
172
173           return . mkLams (len : args)
174                  . wrapFamInstBody arr_tc var_tys
175                  . mkConApp arr_dc
176                  $ ty_args ++ sel ++ pre ++ map Var args ++ post
177
178     def_worker data_con arg_tys mk_body
179       = do
180           arity <- polyArity tyvars
181           body <- closedV
182                 . inBind orig_worker
183                 . polyAbstract tyvars $ \args ->
184                   liftM (mkLams (tyvars ++ args) . vectorised)
185                 $ buildClosures tyvars [] arg_tys res_ty mk_body
186
187           raw_worker <- cloneId mkVectOcc orig_worker (exprType body)
188           let vect_worker = raw_worker `setIdUnfolding`
189                               mkInlineUnfolding (Just arity) body
190           defGlobalVar orig_worker vect_worker
191           return (vect_worker, body)
192       where
193         orig_worker = dataConWorkId data_con
194