43ff97c69a9fdc06c6199e8ef28d426e744ebf5c
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Type / Env.hs
1 {-# OPTIONS_GHC -fno-warn-missing-signatures #-}
2 #if __GLASGOW_HASKELL__ >= 611
3 {-# OPTIONS_GHC -XNoMonoLocalBinds #-}
4 #endif
5 -- Roman likes local bindings
6 -- If this module lives on I'd like to get rid of this flag in due course
7
8 module Vectorise.Type.Env ( 
9         vectTypeEnv,
10 )
11 where
12 import Vectorise.Env
13 import Vectorise.Vect
14 import Vectorise.Monad
15 import Vectorise.Builtins
16 import Vectorise.Type.TyConDecl
17 import Vectorise.Type.Classify
18 import Vectorise.Type.PADict
19 import Vectorise.Type.PData
20 import Vectorise.Type.PRepr
21 import Vectorise.Type.Repr
22 import Vectorise.Utils
23
24 import HscTypes
25 import CoreSyn
26 import CoreUtils
27 import CoreUnfold
28 import DataCon
29 import TyCon
30 import Type
31 import FamInstEnv
32 import OccName
33 import Id
34 import MkId
35 import Var
36 import NameEnv
37
38 import Unique
39 import UniqFM
40 import Util
41 import Outputable
42 import FastString
43 import MonadUtils
44 import Control.Monad
45 import Data.List
46
47 debug           = False
48 dtrace s x      = if debug then pprTrace "VectType" s x else x
49
50 -- | Vectorise a type environment.
51 --   The type environment contains all the type things defined in a module.
52 vectTypeEnv 
53         :: TypeEnv
54         -> VM ( TypeEnv                 -- Vectorised type environment.
55               , [FamInst]               -- New type family instances.
56               , [(Var, CoreExpr)])      -- New top level bindings.
57         
58 vectTypeEnv env
59  = dtrace (ppr env)
60  $ do
61       cs <- readGEnv $ mk_map . global_tycons
62
63       -- Split the list of TyCons into the ones we have to vectorise vs the
64       -- ones we can pass through unchanged. We also pass through algebraic 
65       -- types that use non Haskell98 features, as we don't handle those.
66       let tycons               = typeEnvTyCons env
67           groups               = tyConGroups tycons
68
69       let (conv_tcs, keep_tcs) = classifyTyCons cs groups
70           orig_tcs             = keep_tcs ++ conv_tcs
71           keep_dcs             = concatMap tyConDataCons keep_tcs
72
73       -- Just use the unvectorised versions of these constructors in vectorised code.
74       zipWithM_ defTyCon   keep_tcs keep_tcs
75       zipWithM_ defDataCon keep_dcs keep_dcs
76
77       -- Vectorise all the declarations.
78       new_tcs      <- vectTyConDecls conv_tcs
79
80       -- We don't need to make new representation types for dictionary
81       -- constructors. The constructors are always fully applied, and we don't 
82       -- need to lift them to arrays as a dictionary of a particular type
83       -- always has the same value.
84       let vect_tcs  = filter (not . isClassTyCon) 
85                     $ keep_tcs ++ new_tcs
86
87       -- Create PRepr and PData instances for the vectorised types.
88       -- We get back the binds for the instance functions, 
89       -- and some new type constructors for the representation types.
90       (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
91         do
92           defTyConPAs (zipLazy vect_tcs dfuns')
93           reprs     <- mapM tyConRepr vect_tcs
94           repr_tcs  <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
95           pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
96
97           dfuns     <- sequence 
98                     $  zipWith5 buildTyConBindings
99                                orig_tcs
100                                vect_tcs
101                                repr_tcs
102                                pdata_tcs
103                                reprs
104
105           binds     <- takeHoisted
106           return (dfuns, binds, repr_tcs ++ pdata_tcs)
107
108       -- The new type constructors are the vectorised versions of the originals, 
109       -- plus the new type constructors that we use for the representations.
110       let all_new_tcs = new_tcs ++ inst_tcs
111
112       let new_env     =  extendTypeEnvList env
113                       $  map ATyCon all_new_tcs
114                       ++ [ADataCon dc | tc <- all_new_tcs
115                                       , dc <- tyConDataCons tc]
116
117       return (new_env, map mkLocalFamInst inst_tcs, binds)
118
119    where
120     mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
121
122
123
124 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
125 buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr
126  = do vectDataConWorkers orig_tc vect_tc pdata_tc
127       buildPADict vect_tc prepr_tc pdata_tc repr
128
129
130 vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
131 vectDataConWorkers orig_tc vect_tc arr_tc
132  = do bs <- sequence
133           . zipWith3 def_worker  (tyConDataCons orig_tc) rep_tys
134           $ zipWith4 mk_data_con (tyConDataCons vect_tc)
135                                  rep_tys
136                                  (inits rep_tys)
137                                  (tail $ tails rep_tys)
138       mapM_ (uncurry hoistBinding) bs
139  where
140     tyvars   = tyConTyVars vect_tc
141     var_tys  = mkTyVarTys tyvars
142     ty_args  = map Type var_tys
143     res_ty   = mkTyConApp vect_tc var_tys
144
145     cons     = tyConDataCons vect_tc
146     arity    = length cons
147     [arr_dc] = tyConDataCons arr_tc
148
149     rep_tys  = map dataConRepArgTys $ tyConDataCons vect_tc
150
151
152     mk_data_con con tys pre post
153       = liftM2 (,) (vect_data_con con)
154                    (lift_data_con tys pre post (mkDataConTag con))
155
156     sel_replicate len tag
157       | arity > 1 = do
158                       rep <- builtin (selReplicate arity)
159                       return [rep `mkApps` [len, tag]]
160
161       | otherwise = return []
162
163     vect_data_con con = return $ mkConApp con ty_args
164     lift_data_con tys pre_tys post_tys tag
165       = do
166           len  <- builtin liftingContext
167           args <- mapM (newLocalVar (fsLit "xs"))
168                   =<< mapM mkPDataType tys
169
170           sel  <- sel_replicate (Var len) tag
171
172           pre   <- mapM emptyPD (concat pre_tys)
173           post  <- mapM emptyPD (concat post_tys)
174
175           return . mkLams (len : args)
176                  . wrapFamInstBody arr_tc var_tys
177                  . mkConApp arr_dc
178                  $ ty_args ++ sel ++ pre ++ map Var args ++ post
179
180     def_worker data_con arg_tys mk_body
181       = do
182           arity <- polyArity tyvars
183           body <- closedV
184                 . inBind orig_worker
185                 . polyAbstract tyvars $ \args ->
186                   liftM (mkLams (tyvars ++ args) . vectorised)
187                 $ buildClosures tyvars [] arg_tys res_ty mk_body
188
189           raw_worker <- cloneId mkVectOcc orig_worker (exprType body)
190           let vect_worker = raw_worker `setIdUnfolding`
191                               mkInlineUnfolding (Just arity) body
192           defGlobalVar orig_worker vect_worker
193           return (vect_worker, body)
194       where
195         orig_worker = dataConWorkId data_con
196