61a52bc4b7b07e0ce15f9252e28d00656d187f38
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Type / Env.hs
1 {-# OPTIONS_GHC -fno-warn-missing-signatures #-}
2 {-# OPTIONS_GHC -XNoMonoLocalBinds #-}
3 -- Roman likes local bindings
4 -- If this module lives on I'd like to get rid of this flag in due course
5
6 module Vectorise.Type.Env ( 
7         vectTypeEnv,
8 )
9 where
10 import Vectorise.Env
11 import Vectorise.Vect
12 import Vectorise.Monad
13 import Vectorise.Builtins
14 import Vectorise.Type.TyConDecl
15 import Vectorise.Type.Classify
16 import Vectorise.Type.PADict
17 import Vectorise.Type.PData
18 import Vectorise.Type.PRepr
19 import Vectorise.Type.Repr
20 import Vectorise.Utils
21
22 import HscTypes
23 import CoreSyn
24 import CoreUtils
25 import CoreUnfold
26 import DataCon
27 import TyCon
28 import Type
29 import FamInstEnv
30 import OccName
31 import Id
32 import MkId
33 import Var
34 import NameEnv
35
36 import Unique
37 import UniqFM
38 import Util
39 import Outputable
40 import FastString
41 import MonadUtils
42 import Control.Monad
43 import Data.List
44
45 debug           = False
46 dtrace s x      = if debug then pprTrace "VectType" s x else x
47
48 -- | Vectorise a type environment.
49 --   The type environment contains all the type things defined in a module.
50 vectTypeEnv 
51         :: TypeEnv
52         -> VM ( TypeEnv                 -- Vectorised type environment.
53               , [FamInst]               -- New type family instances.
54               , [(Var, CoreExpr)])      -- New top level bindings.
55         
56 vectTypeEnv env
57  = dtrace (ppr env)
58  $ do
59       cs <- readGEnv $ mk_map . global_tycons
60
61       -- Split the list of TyCons into the ones we have to vectorise vs the
62       -- ones we can pass through unchanged. We also pass through algebraic 
63       -- types that use non Haskell98 features, as we don't handle those.
64       let tycons               = typeEnvTyCons env
65           groups               = tyConGroups tycons
66
67       let (conv_tcs, keep_tcs) = classifyTyCons cs groups
68           orig_tcs             = keep_tcs ++ conv_tcs
69           keep_dcs             = concatMap tyConDataCons keep_tcs
70
71       -- Just use the unvectorised versions of these constructors in vectorised code.
72       zipWithM_ defTyCon   keep_tcs keep_tcs
73       zipWithM_ defDataCon keep_dcs keep_dcs
74
75       -- Vectorise all the declarations.
76       new_tcs      <- vectTyConDecls conv_tcs
77
78       -- We don't need to make new representation types for dictionary
79       -- constructors. The constructors are always fully applied, and we don't 
80       -- need to lift them to arrays as a dictionary of a particular type
81       -- always has the same value.
82       let vect_tcs  = filter (not . isClassTyCon) 
83                     $ keep_tcs ++ new_tcs
84
85       reprs <- mapM tyConRepr vect_tcs
86       repr_tcs  <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
87       pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
88       updGEnv $ extendFamEnv
89               $ map mkLocalFamInst
90               $ repr_tcs ++ pdata_tcs
91
92       -- Create PRepr and PData instances for the vectorised types.
93       -- We get back the binds for the instance functions, 
94       -- and some new type constructors for the representation types.
95       (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
96         do
97           defTyConPAs (zipLazy vect_tcs dfuns')
98           reprs     <- mapM tyConRepr vect_tcs
99
100           dfuns     <- sequence 
101                     $  zipWith5 buildTyConBindings
102                                orig_tcs
103                                vect_tcs
104                                repr_tcs
105                                pdata_tcs
106                                reprs
107
108           binds     <- takeHoisted
109           return (dfuns, binds, repr_tcs ++ pdata_tcs)
110
111       -- The new type constructors are the vectorised versions of the originals, 
112       -- plus the new type constructors that we use for the representations.
113       let all_new_tcs = new_tcs ++ inst_tcs
114
115       let new_env     =  extendTypeEnvList env
116                       $  map ATyCon all_new_tcs
117                       ++ [ADataCon dc | tc <- all_new_tcs
118                                       , dc <- tyConDataCons tc]
119
120       return (new_env, map mkLocalFamInst inst_tcs, binds)
121
122    where
123     mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
124
125
126
127 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
128 buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr
129  = do vectDataConWorkers orig_tc vect_tc pdata_tc
130       buildPADict vect_tc prepr_tc pdata_tc repr
131
132
133 vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
134 vectDataConWorkers orig_tc vect_tc arr_tc
135  = do bs <- sequence
136           . zipWith3 def_worker  (tyConDataCons orig_tc) rep_tys
137           $ zipWith4 mk_data_con (tyConDataCons vect_tc)
138                                  rep_tys
139                                  (inits rep_tys)
140                                  (tail $ tails rep_tys)
141       mapM_ (uncurry hoistBinding) bs
142  where
143     tyvars   = tyConTyVars vect_tc
144     var_tys  = mkTyVarTys tyvars
145     ty_args  = map Type var_tys
146     res_ty   = mkTyConApp vect_tc var_tys
147
148     cons     = tyConDataCons vect_tc
149     arity    = length cons
150     [arr_dc] = tyConDataCons arr_tc
151
152     rep_tys  = map dataConRepArgTys $ tyConDataCons vect_tc
153
154
155     mk_data_con con tys pre post
156       = liftM2 (,) (vect_data_con con)
157                    (lift_data_con tys pre post (mkDataConTag con))
158
159     sel_replicate len tag
160       | arity > 1 = do
161                       rep <- builtin (selReplicate arity)
162                       return [rep `mkApps` [len, tag]]
163
164       | otherwise = return []
165
166     vect_data_con con = return $ mkConApp con ty_args
167     lift_data_con tys pre_tys post_tys tag
168       = do
169           len  <- builtin liftingContext
170           args <- mapM (newLocalVar (fsLit "xs"))
171                   =<< mapM mkPDataType tys
172
173           sel  <- sel_replicate (Var len) tag
174
175           pre   <- mapM emptyPD (concat pre_tys)
176           post  <- mapM emptyPD (concat post_tys)
177
178           return . mkLams (len : args)
179                  . wrapFamInstBody arr_tc var_tys
180                  . mkConApp arr_dc
181                  $ ty_args ++ sel ++ pre ++ map Var args ++ post
182
183     def_worker data_con arg_tys mk_body
184       = do
185           arity <- polyArity tyvars
186           body <- closedV
187                 . inBind orig_worker
188                 . polyAbstract tyvars $ \args ->
189                   liftM (mkLams (tyvars ++ args) . vectorised)
190                 $ buildClosures tyvars [] arg_tys res_ty mk_body
191
192           raw_worker <- cloneId mkVectOcc orig_worker (exprType body)
193           let vect_worker = raw_worker `setIdUnfolding`
194                               mkInlineUnfolding (Just arity) body
195           defGlobalVar orig_worker vect_worker
196           return (vect_worker, body)
197       where
198         orig_worker = dataConWorkId data_con
199