Sort all the PADict/PData/PRDict/PRepr stuff into their own modules
[ghc-hetmet.git] / compiler / vectorise / VectType.hs
1 {-# OPTIONS -fno-warn-missing-signatures #-}
2
3 module VectType ( 
4         vectTyCon,
5         vectAndLiftType,
6         vectType,
7         vectTypeEnv,
8         buildPADict,
9         fromVect
10 )
11 where
12 import VectUtils
13 import Vectorise.Env
14 import Vectorise.Convert
15 import Vectorise.Vect
16 import Vectorise.Monad
17 import Vectorise.Builtins
18 import Vectorise.Type.Type
19 import Vectorise.Type.TyConDecl
20 import Vectorise.Type.Classify
21 import Vectorise.Type.PADict
22 import Vectorise.Type.PData
23 import Vectorise.Type.PRepr
24 import Vectorise.Type.Repr
25 import Vectorise.Utils.Closure
26 import Vectorise.Utils.Hoisting
27
28 import HscTypes
29 import CoreSyn
30 import CoreUtils
31 import CoreUnfold
32 import DataCon
33 import TyCon
34 import Type
35 import FamInstEnv
36 import OccName
37 import Id
38 import MkId
39 import Var
40 import NameEnv
41
42 import Unique
43 import UniqFM
44 import Util
45 import Outputable
46 import FastString
47 import MonadUtils
48 import Control.Monad
49 import Data.List
50
51 debug           = False
52 dtrace s x      = if debug then pprTrace "VectType" s x else x
53
54
55 -- | Vectorise a type environment.
56 --   The type environment contains all the type things defined in a module.
57 vectTypeEnv 
58         :: TypeEnv
59         -> VM ( TypeEnv                 -- Vectorised type environment.
60               , [FamInst]               -- New type family instances.
61               , [(Var, CoreExpr)])      -- New top level bindings.
62         
63 vectTypeEnv env
64  = dtrace (ppr env)
65  $ do
66       cs <- readGEnv $ mk_map . global_tycons
67
68       -- Split the list of TyCons into the ones we have to vectorise vs the
69       -- ones we can pass through unchanged. We also pass through algebraic 
70       -- types that use non Haskell98 features, as we don't handle those.
71       let (conv_tcs, keep_tcs) = classifyTyCons cs groups
72           keep_dcs             = concatMap tyConDataCons keep_tcs
73
74       zipWithM_ defTyCon   keep_tcs keep_tcs
75       zipWithM_ defDataCon keep_dcs keep_dcs
76
77       new_tcs <- vectTyConDecls conv_tcs
78
79       let orig_tcs = keep_tcs ++ conv_tcs
80
81       -- We don't need to make new representation types for dictionary
82       -- constructors. The constructors are always fully applied, and we don't 
83       -- need to lift them to arrays as a dictionary of a particular type
84       -- always has the same value.
85       let vect_tcs = filter (not . isClassTyCon) 
86                    $ keep_tcs ++ new_tcs
87
88       (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
89         do
90           defTyConPAs (zipLazy vect_tcs dfuns')
91           reprs     <- mapM tyConRepr vect_tcs
92           repr_tcs  <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
93           pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
94
95           dfuns     <- sequence 
96                     $  zipWith5 buildTyConBindings
97                                orig_tcs
98                                vect_tcs
99                                repr_tcs
100                                pdata_tcs
101                                reprs
102
103           binds     <- takeHoisted
104           return (dfuns, binds, repr_tcs ++ pdata_tcs)
105
106       let all_new_tcs = new_tcs ++ inst_tcs
107
108       let new_env = extendTypeEnvList env
109                        (map ATyCon all_new_tcs
110                         ++ [ADataCon dc | tc <- all_new_tcs
111                                         , dc <- tyConDataCons tc])
112
113       return (new_env, map mkLocalFamInst inst_tcs, binds)
114   where
115     tycons = typeEnvTyCons env
116     groups = tyConGroups tycons
117
118     mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
119
120
121
122 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
123 buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr
124  = do vectDataConWorkers orig_tc vect_tc pdata_tc
125       buildPADict vect_tc prepr_tc pdata_tc repr
126
127
128 vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
129 vectDataConWorkers orig_tc vect_tc arr_tc
130  = do bs <- sequence
131           . zipWith3 def_worker  (tyConDataCons orig_tc) rep_tys
132           $ zipWith4 mk_data_con (tyConDataCons vect_tc)
133                                  rep_tys
134                                  (inits rep_tys)
135                                  (tail $ tails rep_tys)
136       mapM_ (uncurry hoistBinding) bs
137  where
138     tyvars   = tyConTyVars vect_tc
139     var_tys  = mkTyVarTys tyvars
140     ty_args  = map Type var_tys
141     res_ty   = mkTyConApp vect_tc var_tys
142
143     cons     = tyConDataCons vect_tc
144     arity    = length cons
145     [arr_dc] = tyConDataCons arr_tc
146
147     rep_tys  = map dataConRepArgTys $ tyConDataCons vect_tc
148
149
150     mk_data_con con tys pre post
151       = liftM2 (,) (vect_data_con con)
152                    (lift_data_con tys pre post (mkDataConTag con))
153
154     sel_replicate len tag
155       | arity > 1 = do
156                       rep <- builtin (selReplicate arity)
157                       return [rep `mkApps` [len, tag]]
158
159       | otherwise = return []
160
161     vect_data_con con = return $ mkConApp con ty_args
162     lift_data_con tys pre_tys post_tys tag
163       = do
164           len  <- builtin liftingContext
165           args <- mapM (newLocalVar (fsLit "xs"))
166                   =<< mapM mkPDataType tys
167
168           sel  <- sel_replicate (Var len) tag
169
170           pre   <- mapM emptyPD (concat pre_tys)
171           post  <- mapM emptyPD (concat post_tys)
172
173           return . mkLams (len : args)
174                  . wrapFamInstBody arr_tc var_tys
175                  . mkConApp arr_dc
176                  $ ty_args ++ sel ++ pre ++ map Var args ++ post
177
178     def_worker data_con arg_tys mk_body
179       = do
180           arity <- polyArity tyvars
181           body <- closedV
182                 . inBind orig_worker
183                 . polyAbstract tyvars $ \args ->
184                   liftM (mkLams (tyvars ++ args) . vectorised)
185                 $ buildClosures tyvars [] arg_tys res_ty mk_body
186
187           raw_worker <- cloneId mkVectOcc orig_worker (exprType body)
188           let vect_worker = raw_worker `setIdUnfolding`
189                               mkInlineRule body (Just arity)
190           defGlobalVar orig_worker vect_worker
191           return (vect_worker, body)
192       where
193         orig_worker = dataConWorkId data_con
194