8e26ed9788b19f883ba9d9b3746e7eb45055e550
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Type / Env.hs
1 {-# OPTIONS_GHC -fno-warn-missing-signatures #-}
2 #if __GLASGOW_HASKELL__ >= 611
3 {-# OPTIONS_GHC -XNoMonoLocalBinds #-}
4 #endif
5 -- Roman likes local bindings
6 -- If this module lives on I'd like to get rid of this flag in due course
7
8 module Vectorise.Type.Env ( 
9         vectTypeEnv,
10 )
11 where
12 import Vectorise.Env
13 import Vectorise.Vect
14 import Vectorise.Monad
15 import Vectorise.Builtins
16 import Vectorise.Type.TyConDecl
17 import Vectorise.Type.Classify
18 import Vectorise.Type.PADict
19 import Vectorise.Type.PData
20 import Vectorise.Type.PRepr
21 import Vectorise.Type.Repr
22 import Vectorise.Utils
23
24 import HscTypes
25 import CoreSyn
26 import CoreUtils
27 import CoreUnfold
28 import DataCon
29 import TyCon
30 import Type
31 import FamInstEnv
32 import OccName
33 import Id
34 import MkId
35 import Var
36 import NameEnv
37
38 import Unique
39 import UniqFM
40 import Util
41 import Outputable
42 import FastString
43 import MonadUtils
44 import Control.Monad
45 import Data.List
46
47 debug           = False
48 dtrace s x      = if debug then pprTrace "VectType" s x else x
49
50
51 -- | Vectorise a type environment.
52 --   The type environment contains all the type things defined in a module.
53 vectTypeEnv 
54         :: TypeEnv
55         -> VM ( TypeEnv                 -- Vectorised type environment.
56               , [FamInst]               -- New type family instances.
57               , [(Var, CoreExpr)])      -- New top level bindings.
58         
59 vectTypeEnv env
60  = dtrace (ppr env)
61  $ do
62       cs <- readGEnv $ mk_map . global_tycons
63
64       -- Split the list of TyCons into the ones we have to vectorise vs the
65       -- ones we can pass through unchanged. We also pass through algebraic 
66       -- types that use non Haskell98 features, as we don't handle those.
67       let (conv_tcs, keep_tcs) = classifyTyCons cs groups
68           keep_dcs             = concatMap tyConDataCons keep_tcs
69
70       zipWithM_ defTyCon   keep_tcs keep_tcs
71       zipWithM_ defDataCon keep_dcs keep_dcs
72
73       new_tcs <- vectTyConDecls conv_tcs
74
75       let orig_tcs = keep_tcs ++ conv_tcs
76
77       -- We don't need to make new representation types for dictionary
78       -- constructors. The constructors are always fully applied, and we don't 
79       -- need to lift them to arrays as a dictionary of a particular type
80       -- always has the same value.
81       let vect_tcs = filter (not . isClassTyCon) 
82                    $ keep_tcs ++ new_tcs
83
84       (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
85         do
86           defTyConPAs (zipLazy vect_tcs dfuns')
87           reprs     <- mapM tyConRepr vect_tcs
88           repr_tcs  <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
89           pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
90
91           dfuns     <- sequence 
92                     $  zipWith5 buildTyConBindings
93                                orig_tcs
94                                vect_tcs
95                                repr_tcs
96                                pdata_tcs
97                                reprs
98
99           binds     <- takeHoisted
100           return (dfuns, binds, repr_tcs ++ pdata_tcs)
101
102       let all_new_tcs = new_tcs ++ inst_tcs
103
104       let new_env = extendTypeEnvList env
105                        (map ATyCon all_new_tcs
106                         ++ [ADataCon dc | tc <- all_new_tcs
107                                         , dc <- tyConDataCons tc])
108
109       return (new_env, map mkLocalFamInst inst_tcs, binds)
110   where
111     tycons = typeEnvTyCons env
112     groups = tyConGroups tycons
113
114     mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
115
116
117
118 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
119 buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr
120  = do vectDataConWorkers orig_tc vect_tc pdata_tc
121       buildPADict vect_tc prepr_tc pdata_tc repr
122
123
124 vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
125 vectDataConWorkers orig_tc vect_tc arr_tc
126  = do bs <- sequence
127           . zipWith3 def_worker  (tyConDataCons orig_tc) rep_tys
128           $ zipWith4 mk_data_con (tyConDataCons vect_tc)
129                                  rep_tys
130                                  (inits rep_tys)
131                                  (tail $ tails rep_tys)
132       mapM_ (uncurry hoistBinding) bs
133  where
134     tyvars   = tyConTyVars vect_tc
135     var_tys  = mkTyVarTys tyvars
136     ty_args  = map Type var_tys
137     res_ty   = mkTyConApp vect_tc var_tys
138
139     cons     = tyConDataCons vect_tc
140     arity    = length cons
141     [arr_dc] = tyConDataCons arr_tc
142
143     rep_tys  = map dataConRepArgTys $ tyConDataCons vect_tc
144
145
146     mk_data_con con tys pre post
147       = liftM2 (,) (vect_data_con con)
148                    (lift_data_con tys pre post (mkDataConTag con))
149
150     sel_replicate len tag
151       | arity > 1 = do
152                       rep <- builtin (selReplicate arity)
153                       return [rep `mkApps` [len, tag]]
154
155       | otherwise = return []
156
157     vect_data_con con = return $ mkConApp con ty_args
158     lift_data_con tys pre_tys post_tys tag
159       = do
160           len  <- builtin liftingContext
161           args <- mapM (newLocalVar (fsLit "xs"))
162                   =<< mapM mkPDataType tys
163
164           sel  <- sel_replicate (Var len) tag
165
166           pre   <- mapM emptyPD (concat pre_tys)
167           post  <- mapM emptyPD (concat post_tys)
168
169           return . mkLams (len : args)
170                  . wrapFamInstBody arr_tc var_tys
171                  . mkConApp arr_dc
172                  $ ty_args ++ sel ++ pre ++ map Var args ++ post
173
174     def_worker data_con arg_tys mk_body
175       = do
176           arity <- polyArity tyvars
177           body <- closedV
178                 . inBind orig_worker
179                 . polyAbstract tyvars $ \args ->
180                   liftM (mkLams (tyvars ++ args) . vectorised)
181                 $ buildClosures tyvars [] arg_tys res_ty mk_body
182
183           raw_worker <- cloneId mkVectOcc orig_worker (exprType body)
184           let vect_worker = raw_worker `setIdUnfolding`
185                               mkInlineRule body (Just arity)
186           defGlobalVar orig_worker vect_worker
187           return (vect_worker, body)
188       where
189         orig_worker = dataConWorkId data_con
190