PA dictionary generation
[ghc-hetmet.git] / compiler / vectorise / VectType.hs
1 module VectType ( vectTyCon, vectType, vectTypeEnv )
2 where
3
4 #include "HsVersions.h"
5
6 import VectMonad
7 import VectUtils
8
9 import HscTypes          ( TypeEnv, extendTypeEnvList, typeEnvTyCons )
10 import CoreSyn
11 import CoreUtils
12 import DataCon
13 import TyCon
14 import Type
15 import TypeRep
16 import Coercion
17 import FamInstEnv        ( FamInst, mkLocalFamInst )
18 import InstEnv           ( Instance )
19 import OccName
20 import MkId
21 import BasicTypes        ( StrictnessMark(..), boolToRecFlag )
22 import Var               ( Var )
23 import Id                ( mkWildId )
24 import Name              ( Name )
25 import NameEnv
26 import TysWiredIn        ( intTy, intDataCon )
27 import TysPrim           ( intPrimTy )
28
29 import Unique
30 import UniqFM
31 import UniqSet
32 import Digraph           ( SCC(..), stronglyConnComp )
33
34 import Outputable
35
36 import Control.Monad  ( liftM, liftM2, zipWithM, zipWithM_ )
37 import Data.List      ( inits, tails )
38
39 -- ----------------------------------------------------------------------------
40 -- Types
41
42 vectTyCon :: TyCon -> VM TyCon
43 vectTyCon tc
44   | isFunTyCon tc        = builtin closureTyCon
45   | isBoxedTupleTyCon tc = return tc
46   | isUnLiftedTyCon tc   = return tc
47   | otherwise = do
48                   r <- lookupTyCon tc
49                   case r of
50                     Just tc' -> return tc'
51
52                     -- FIXME: just for now
53                     Nothing  -> pprTrace "ccTyCon:" (ppr tc) $ return tc
54
55 vectType :: Type -> VM Type
56 vectType ty | Just ty' <- coreView ty = vectType ty'
57 vectType (TyVarTy tv) = return $ TyVarTy tv
58 vectType (AppTy ty1 ty2) = liftM2 AppTy (vectType ty1) (vectType ty2)
59 vectType (TyConApp tc tys) = liftM2 TyConApp (vectTyCon tc) (mapM vectType tys)
60 vectType (FunTy ty1 ty2)   = liftM2 TyConApp (builtin closureTyCon)
61                                              (mapM vectType [ty1,ty2])
62 vectType ty@(ForAllTy _ _)
63   = do
64       mdicts   <- mapM paDictArgType tyvars
65       mono_ty' <- vectType mono_ty
66       return $ tyvars `mkForAllTys` ([dict | Just dict <- mdicts] `mkFunTys` mono_ty')
67   where
68     (tyvars, mono_ty) = splitForAllTys ty
69
70 vectType ty = pprPanic "vectType:" (ppr ty)
71
72 -- ----------------------------------------------------------------------------
73 -- Type definitions
74
75 type TyConGroup = ([TyCon], UniqSet TyCon)
76
77 vectTypeEnv :: TypeEnv -> VM (TypeEnv, [FamInst], [Instance])
78 vectTypeEnv env
79   = do
80       cs <- readGEnv $ mk_map . global_tycons
81       let (conv_tcs, keep_tcs) = classifyTyCons cs groups
82           keep_dcs             = concatMap tyConDataCons keep_tcs
83       zipWithM_ defTyCon   keep_tcs keep_tcs
84       zipWithM_ defDataCon keep_dcs keep_dcs
85       vect_tcs <- vectTyConDecls conv_tcs
86       parr_tcs1 <- zipWithM buildPArrayTyCon keep_tcs keep_tcs
87       parr_tcs2 <- zipWithM buildPArrayTyCon conv_tcs vect_tcs
88       let new_tcs = vect_tcs ++ parr_tcs1 ++ parr_tcs2
89
90       let new_env = extendTypeEnvList env
91                        (map ATyCon new_tcs
92                         ++ [ADataCon dc | tc <- new_tcs
93                                         , dc <- tyConDataCons tc])
94
95       return (new_env, map mkLocalFamInst (parr_tcs1 ++ parr_tcs2), [])
96   where
97     tycons = typeEnvTyCons env
98     groups = tyConGroups tycons
99
100     mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
101
102     keep_tc tc = let dcs = tyConDataCons tc
103                  in
104                  defTyCon tc tc >> zipWithM_ defDataCon dcs dcs
105
106
107 vectTyConDecls :: [TyCon] -> VM [TyCon]
108 vectTyConDecls tcs = fixV $ \tcs' ->
109   do
110     mapM_ (uncurry defTyCon) (lazy_zip tcs tcs')
111     mapM vectTyConDecl tcs
112   where
113     lazy_zip [] _ = []
114     lazy_zip (x:xs) ~(y:ys) = (x,y) : lazy_zip xs ys
115
116 vectTyConDecl :: TyCon -> VM TyCon
117 vectTyConDecl tc
118   = do
119       name' <- cloneName mkVectTyConOcc name
120       rhs'  <- vectAlgTyConRhs (algTyConRhs tc)
121
122       return $ mkAlgTyCon name'
123                           kind
124                           tyvars
125                           []              -- no stupid theta
126                           rhs'
127                           []              -- no selector ids
128                           NoParentTyCon   -- FIXME
129                           rec_flag        -- FIXME: is this ok?
130                           False           -- FIXME: no generics
131                           False           -- not GADT syntax
132   where
133     name   = tyConName tc
134     kind   = tyConKind tc
135     tyvars = tyConTyVars tc
136     rec_flag = boolToRecFlag (isRecursiveTyCon tc)
137
138 vectAlgTyConRhs :: AlgTyConRhs -> VM AlgTyConRhs
139 vectAlgTyConRhs (DataTyCon { data_cons = data_cons
140                            , is_enum   = is_enum
141                            })
142   = do
143       data_cons' <- mapM vectDataCon data_cons
144       zipWithM_ defDataCon data_cons data_cons'
145       return $ DataTyCon { data_cons = data_cons'
146                          , is_enum   = is_enum
147                          }
148
149 vectDataCon :: DataCon -> VM DataCon
150 vectDataCon dc
151   | not . null $ dataConExTyVars dc = pprPanic "vectDataCon: existentials" (ppr dc)
152   | not . null $ dataConEqSpec   dc = pprPanic "vectDataCon: eq spec" (ppr dc)
153   | otherwise
154   = do
155       name'    <- cloneName mkVectDataConOcc name
156       tycon'   <- vectTyCon tycon
157       arg_tys  <- mapM vectType rep_arg_tys
158       wrk_name <- cloneName mkDataConWorkerOcc name'
159
160       let ids      = mkDataConIds (panic "vectDataCon: wrapper id")
161                                   wrk_name
162                                   data_con
163           data_con = mkDataCon name'
164                                False           -- not infix
165                                (map (const NotMarkedStrict) arg_tys)
166                                []              -- no labelled fields
167                                univ_tvs
168                                []              -- no existential tvs for now
169                                []              -- no eq spec for now
170                                []              -- no theta
171                                arg_tys
172                                tycon'
173                                []              -- no stupid theta
174                                ids
175       return data_con
176   where
177     name        = dataConName dc
178     univ_tvs    = dataConUnivTyVars dc
179     rep_arg_tys = dataConRepArgTys dc
180     tycon       = dataConTyCon dc
181
182 buildPArrayTyCon :: TyCon -> TyCon -> VM TyCon
183 buildPArrayTyCon orig_tc vect_tc = fixV $ \repr_tc ->
184   do
185     name'  <- cloneName mkPArrayTyConOcc orig_name
186     parent <- buildPArrayParentInfo orig_name vect_tc repr_tc
187     rhs    <- buildPArrayTyConRhs orig_name vect_tc repr_tc
188
189     return $ mkAlgTyCon name'
190                         kind
191                         tyvars
192                         []              -- no stupid theta
193                         rhs
194                         []              -- no selector ids
195                         parent
196                         rec_flag        -- FIXME: is this ok?
197                         False           -- FIXME: no generics
198                         False           -- not GADT syntax
199   where
200     orig_name = tyConName orig_tc
201     name   = tyConName vect_tc
202     kind   = tyConKind vect_tc
203     tyvars = tyConTyVars vect_tc
204     rec_flag = boolToRecFlag (isRecursiveTyCon vect_tc)
205     
206
207 buildPArrayParentInfo :: Name -> TyCon -> TyCon -> VM TyConParent
208 buildPArrayParentInfo orig_name vect_tc repr_tc
209   = do
210       parray_tc <- builtin parrayTyCon
211       co_name <- cloneName mkInstTyCoOcc (tyConName repr_tc)
212
213       let inst_tys = [mkTyConApp vect_tc (map mkTyVarTy tyvars)]
214
215       return . FamilyTyCon parray_tc inst_tys
216              $ mkFamInstCoercion co_name
217                                  tyvars
218                                  parray_tc
219                                  inst_tys
220                                  repr_tc
221   where
222     tyvars = tyConTyVars vect_tc
223
224 buildPArrayTyConRhs :: Name -> TyCon -> TyCon -> VM AlgTyConRhs
225 buildPArrayTyConRhs orig_name vect_tc repr_tc
226   = do
227       data_con <- buildPArrayDataCon orig_name vect_tc repr_tc
228       return $ DataTyCon { data_cons = [data_con], is_enum = False }
229
230 buildPArrayDataCon :: Name -> TyCon -> TyCon -> VM DataCon
231 buildPArrayDataCon orig_name vect_tc repr_tc
232   = do
233       dc_name  <- cloneName mkPArrayDataConOcc orig_name
234       shape_ty <- mkPArrayType intTy   -- FIXME: we want to unbox this!
235       repr_tys <- mapM mkPArrayType types
236       wrk_name <- cloneName mkDataConWorkerOcc  dc_name
237       wrp_name <- cloneName mkDataConWrapperOcc dc_name
238
239       let ids      = mkDataConIds wrp_name wrk_name data_con
240           data_con = mkDataCon dc_name
241                                False
242                                (MarkedStrict : map (const NotMarkedStrict) repr_tys)
243                                []
244                                (tyConTyVars vect_tc)
245                                []
246                                []
247                                []
248                                (shape_ty : repr_tys)
249                                repr_tc
250                                []
251                                ids
252
253       return data_con
254   where
255     types = [ty | dc <- tyConDataCons vect_tc
256                 , ty <- dataConRepArgTys dc]
257
258 buildPADict :: Var -> TyCon -> TyCon -> VM [(Var, CoreExpr)]
259 buildPADict var vect_tc arr_tc
260   = localV . abstractOverTyVars (tyConTyVars arr_tc) $ \abstract ->
261     do
262       meth_binds <- mapM (mk_method abstract) paMethods
263       let meth_vars = map (Var . fst) meth_binds
264       meth_exprs <- mapM (`applyToTypes` arg_tys) meth_vars
265
266       pa_dc <- builtin paDictDataCon
267       let dict = mkConApp pa_dc (Type (mkTyConApp vect_tc arg_tys) : meth_exprs)
268       return $ (var, dict) : meth_binds
269   where
270     tvs = tyConTyVars arr_tc
271     arg_tys = mkTyVarTys tvs
272
273     mk_method abstract (name, build)
274       = localV
275       $ do
276           body <- liftM abstract $ build vect_tc arr_tc
277           var <- newLocalVar name (exprType body)
278           return (var, mkInlineMe body)
279           
280 paMethods = [(FSLIT("lengthPA"),    buildLengthPA),
281              (FSLIT("replicatePA"), buildReplicatePA)]
282
283 buildLengthPA :: TyCon -> TyCon -> VM CoreExpr
284 buildLengthPA _ arr_tc
285   = do
286       arg   <- newLocalVar FSLIT("xs") arg_ty
287       shape <- newLocalVar FSLIT("sel") shape_ty
288       body  <- lengthPA (Var shape)
289       return . Lam arg
290              $ Case (Var arg) (mkWildId arg_ty) intPrimTy
291                     [(DataAlt repr_dc, shape : map mkWildId repr_tys, body)]
292   where
293     arg_ty = mkTyConApp arr_tc . mkTyVarTys $ tyConTyVars arr_tc
294     [repr_dc] = tyConDataCons arr_tc
295     shape_ty : repr_tys = dataConRepArgTys repr_dc
296
297
298 -- data T = C0 t1 ... tm
299 --          ...
300 --          Ck u1 ... un
301 --
302 -- data [:T:] = A ![:Int:] [:t1:] ... [:un:]
303 --
304 -- replicatePA :: Int# -> T -> [:T:]
305 -- replicatePA n# t
306 --   = let c = case t of
307 --               C0 _ ... _ -> 0
308 --               ...
309 --               Ck _ ... _ -> k
310 --
311 --         xs1 = case t of
312 --                 C0 x1 _ ... _ -> replicatePA @t1 n# x1
313 --                 _             -> emptyPA @t1
314 --
315 --         ...
316 --
317 --         ysn = case t of
318 --                 Ck _ ... _ yn -> replicatePA @un n# yn
319 --                 _             -> emptyPA @un
320 --     in
321 --     A (replicatePA @Int n# c) xs1 ... ysn
322 --
323 --
324
325 buildReplicatePA :: TyCon -> TyCon -> VM CoreExpr
326 buildReplicatePA vect_tc arr_tc
327   = do
328       len_var <- newLocalVar FSLIT("n") intPrimTy
329       val_var <- newLocalVar FSLIT("x") val_ty
330
331       let len = Var len_var
332           val = Var val_var
333
334       shape <- replicatePA len (ctr_num val)
335       reprs <- liftM concat $ mapM (mk_comp_arrs len val) vect_dcs
336       
337       return . mkLams [len_var, val_var]
338              $ mkConApp arr_dc (map (Type . TyVarTy) (tyConTyVars arr_tc) ++ (shape : reprs))
339   where
340     val_ty = mkTyConApp vect_tc . mkTyVarTys $ tyConTyVars arr_tc
341     wild   = mkWildId val_ty
342     vect_dcs = tyConDataCons vect_tc
343     [arr_dc] = tyConDataCons arr_tc
344
345     ctr_num val = Case val wild intTy (zipWith ctr_num_alt vect_dcs [0..])
346     ctr_num_alt dc i = (DataAlt dc, map mkWildId (dataConRepArgTys dc),
347                                     mkConApp intDataCon [mkIntLitInt i])
348
349
350     mk_comp_arrs len val dc = let tys = dataConRepArgTys dc
351                                   wilds = map mkWildId tys
352                               in
353                               sequence (zipWith3 (mk_comp_arr len val dc)
354                                        tys (inits wilds) (tails wilds))
355
356     mk_comp_arr len val dc ty pre (_:post)
357       = do
358           var   <- newLocalVar FSLIT("x") ty
359           rep   <- replicatePA len (Var var)
360           empty <- emptyPA ty
361           arr_ty <- mkPArrayType ty
362
363           return $ Case val wild arr_ty
364                      [(DataAlt dc, pre ++ (var : post), rep), (DEFAULT, [], empty)]
365
366 -- | Split the given tycons into two sets depending on whether they have to be
367 -- converted (first list) or not (second list). The first argument contains
368 -- information about the conversion status of external tycons:
369 -- 
370 --   * tycons which have converted versions are mapped to True
371 --   * tycons which are not changed by vectorisation are mapped to False
372 --   * tycons which can't be converted are not elements of the map
373 --
374 classifyTyCons :: UniqFM Bool -> [TyConGroup] -> ([TyCon], [TyCon])
375 classifyTyCons = classify [] []
376   where
377     classify conv keep cs [] = (conv, keep)
378     classify conv keep cs ((tcs, ds) : rs)
379       | can_convert && must_convert
380         = classify (tcs ++ conv) keep (cs `addListToUFM` [(tc,True) | tc <- tcs]) rs
381       | can_convert
382         = classify conv (tcs ++ keep) (cs `addListToUFM` [(tc,False) | tc <- tcs]) rs
383       | otherwise
384         = classify conv keep cs rs
385       where
386         refs = ds `delListFromUniqSet` tcs
387
388         can_convert  = isNullUFM (refs `minusUFM` cs) && all convertable tcs
389         must_convert = foldUFM (||) False (intersectUFM_C const cs refs)
390
391         convertable tc = isDataTyCon tc && all isVanillaDataCon (tyConDataCons tc)
392     
393 -- | Compute mutually recursive groups of tycons in topological order
394 --
395 tyConGroups :: [TyCon] -> [TyConGroup]
396 tyConGroups tcs = map mk_grp (stronglyConnComp edges)
397   where
398     edges = [((tc, ds), tc, uniqSetToList ds) | tc <- tcs
399                                 , let ds = tyConsOfTyCon tc]
400
401     mk_grp (AcyclicSCC (tc, ds)) = ([tc], ds)
402     mk_grp (CyclicSCC els)       = (tcs, unionManyUniqSets dss)
403       where
404         (tcs, dss) = unzip els
405
406 tyConsOfTyCon :: TyCon -> UniqSet TyCon
407 tyConsOfTyCon 
408   = tyConsOfTypes . concatMap dataConRepArgTys . tyConDataCons
409
410 tyConsOfType :: Type -> UniqSet TyCon
411 tyConsOfType ty
412   | Just ty' <- coreView ty    = tyConsOfType ty'
413 tyConsOfType (TyVarTy v)       = emptyUniqSet
414 tyConsOfType (TyConApp tc tys) = extend (tyConsOfTypes tys)
415   where
416     extend | isUnLiftedTyCon tc
417            || isTupleTyCon   tc = id
418
419            | otherwise          = (`addOneToUniqSet` tc)
420
421 tyConsOfType (AppTy a b)       = tyConsOfType a `unionUniqSets` tyConsOfType b
422 tyConsOfType (FunTy a b)       = (tyConsOfType a `unionUniqSets` tyConsOfType b)
423                                  `addOneToUniqSet` funTyCon
424 tyConsOfType (ForAllTy _ ty)   = tyConsOfType ty
425 tyConsOfType other             = pprPanic "ClosureConv.tyConsOfType" $ ppr other
426
427 tyConsOfTypes :: [Type] -> UniqSet TyCon
428 tyConsOfTypes = unionManyUniqSets . map tyConsOfType
429