Use dataConTag in flattened array representation
[ghc-hetmet.git] / compiler / vectorise / VectType.hs
1 module VectType ( vectTyCon, vectType, vectTypeEnv,
2                    PAInstance, buildPADict )
3 where
4
5 #include "HsVersions.h"
6
7 import VectMonad
8 import VectUtils
9 import VectCore
10
11 import HscTypes          ( TypeEnv, extendTypeEnvList, typeEnvTyCons )
12 import CoreSyn
13 import CoreUtils
14 import DataCon
15 import TyCon
16 import Type
17 import TypeRep
18 import Coercion
19 import FamInstEnv        ( FamInst, mkLocalFamInst )
20 import InstEnv           ( Instance, mkLocalInstance, instanceDFunId )
21 import OccName
22 import MkId
23 import BasicTypes        ( StrictnessMark(..), OverlapFlag(..), boolToRecFlag )
24 import Var               ( Var )
25 import Id                ( mkWildId )
26 import Name              ( Name, getOccName )
27 import NameEnv
28 import TysWiredIn        ( intTy, intDataCon )
29 import TysPrim           ( intPrimTy )
30
31 import Unique
32 import UniqFM
33 import UniqSet
34 import Digraph           ( SCC(..), stronglyConnComp )
35
36 import Outputable
37
38 import Control.Monad  ( liftM, liftM2, zipWithM, zipWithM_ )
39 import Data.List      ( inits, tails, zipWith4 )
40
41 -- ----------------------------------------------------------------------------
42 -- Types
43
44 vectTyCon :: TyCon -> VM TyCon
45 vectTyCon tc
46   | isFunTyCon tc        = builtin closureTyCon
47   | isBoxedTupleTyCon tc = return tc
48   | isUnLiftedTyCon tc   = return tc
49   | otherwise = do
50                   r <- lookupTyCon tc
51                   case r of
52                     Just tc' -> return tc'
53
54                     -- FIXME: just for now
55                     Nothing  -> pprTrace "ccTyCon:" (ppr tc) $ return tc
56
57 vectType :: Type -> VM Type
58 vectType ty | Just ty' <- coreView ty = vectType ty'
59 vectType (TyVarTy tv) = return $ TyVarTy tv
60 vectType (AppTy ty1 ty2) = liftM2 AppTy (vectType ty1) (vectType ty2)
61 vectType (TyConApp tc tys) = liftM2 TyConApp (vectTyCon tc) (mapM vectType tys)
62 vectType (FunTy ty1 ty2)   = liftM2 TyConApp (builtin closureTyCon)
63                                              (mapM vectType [ty1,ty2])
64 vectType ty@(ForAllTy _ _)
65   = do
66       mdicts   <- mapM paDictArgType tyvars
67       mono_ty' <- vectType mono_ty
68       return $ tyvars `mkForAllTys` ([dict | Just dict <- mdicts] `mkFunTys` mono_ty')
69   where
70     (tyvars, mono_ty) = splitForAllTys ty
71
72 vectType ty = pprPanic "vectType:" (ppr ty)
73
74 -- ----------------------------------------------------------------------------
75 -- Type definitions
76
77 type TyConGroup = ([TyCon], UniqSet TyCon)
78
79 data PAInstance = PAInstance {
80                     painstDFun      :: Var
81                   , painstOrigTyCon :: TyCon
82                   , painstVectTyCon :: TyCon
83                   , painstArrTyCon  :: TyCon
84                   }
85
86 vectTypeEnv :: TypeEnv -> VM (TypeEnv, [FamInst], [(Var, CoreExpr)])
87 vectTypeEnv env
88   = do
89       cs <- readGEnv $ mk_map . global_tycons
90       let (conv_tcs, keep_tcs) = classifyTyCons cs groups
91           keep_dcs             = concatMap tyConDataCons keep_tcs
92       zipWithM_ defTyCon   keep_tcs keep_tcs
93       zipWithM_ defDataCon keep_dcs keep_dcs
94       new_tcs <- vectTyConDecls conv_tcs
95
96       let orig_tcs = keep_tcs ++ conv_tcs
97           vect_tcs  = keep_tcs ++ new_tcs
98
99       parr_tcs <- zipWithM buildPArrayTyCon orig_tcs vect_tcs
100       dfuns    <- mapM mkPADFun vect_tcs
101       defTyConPAs (zip vect_tcs dfuns)
102       binds    <- sequence (zipWith4 buildTyConBindings orig_tcs vect_tcs parr_tcs dfuns)
103       
104       let all_new_tcs = new_tcs ++ parr_tcs
105
106       let new_env = extendTypeEnvList env
107                        (map ATyCon all_new_tcs
108                         ++ [ADataCon dc | tc <- all_new_tcs
109                                         , dc <- tyConDataCons tc])
110
111       return (new_env, map mkLocalFamInst parr_tcs, concat binds)
112   where
113     tycons = typeEnvTyCons env
114     groups = tyConGroups tycons
115
116     mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
117
118     keep_tc tc = let dcs = tyConDataCons tc
119                  in
120                  defTyCon tc tc >> zipWithM_ defDataCon dcs dcs
121
122
123 vectTyConDecls :: [TyCon] -> VM [TyCon]
124 vectTyConDecls tcs = fixV $ \tcs' ->
125   do
126     mapM_ (uncurry defTyCon) (lazy_zip tcs tcs')
127     mapM vectTyConDecl tcs
128   where
129     lazy_zip [] _ = []
130     lazy_zip (x:xs) ~(y:ys) = (x,y) : lazy_zip xs ys
131
132 vectTyConDecl :: TyCon -> VM TyCon
133 vectTyConDecl tc
134   = do
135       name' <- cloneName mkVectTyConOcc name
136       rhs'  <- vectAlgTyConRhs (algTyConRhs tc)
137
138       return $ mkAlgTyCon name'
139                           kind
140                           tyvars
141                           []              -- no stupid theta
142                           rhs'
143                           []              -- no selector ids
144                           NoParentTyCon   -- FIXME
145                           rec_flag        -- FIXME: is this ok?
146                           False           -- FIXME: no generics
147                           False           -- not GADT syntax
148   where
149     name   = tyConName tc
150     kind   = tyConKind tc
151     tyvars = tyConTyVars tc
152     rec_flag = boolToRecFlag (isRecursiveTyCon tc)
153
154 vectAlgTyConRhs :: AlgTyConRhs -> VM AlgTyConRhs
155 vectAlgTyConRhs (DataTyCon { data_cons = data_cons
156                            , is_enum   = is_enum
157                            })
158   = do
159       data_cons' <- mapM vectDataCon data_cons
160       zipWithM_ defDataCon data_cons data_cons'
161       return $ DataTyCon { data_cons = data_cons'
162                          , is_enum   = is_enum
163                          }
164
165 vectDataCon :: DataCon -> VM DataCon
166 vectDataCon dc
167   | not . null $ dataConExTyVars dc = pprPanic "vectDataCon: existentials" (ppr dc)
168   | not . null $ dataConEqSpec   dc = pprPanic "vectDataCon: eq spec" (ppr dc)
169   | otherwise
170   = do
171       name'    <- cloneName mkVectDataConOcc name
172       tycon'   <- vectTyCon tycon
173       arg_tys  <- mapM vectType rep_arg_tys
174       wrk_name <- cloneName mkDataConWorkerOcc name'
175
176       let ids      = mkDataConIds (panic "vectDataCon: wrapper id")
177                                   wrk_name
178                                   data_con
179           data_con = mkDataCon name'
180                                False           -- not infix
181                                (map (const NotMarkedStrict) arg_tys)
182                                []              -- no labelled fields
183                                univ_tvs
184                                []              -- no existential tvs for now
185                                []              -- no eq spec for now
186                                []              -- no theta
187                                arg_tys
188                                tycon'
189                                []              -- no stupid theta
190                                ids
191       return data_con
192   where
193     name        = dataConName dc
194     univ_tvs    = dataConUnivTyVars dc
195     rep_arg_tys = dataConRepArgTys dc
196     tycon       = dataConTyCon dc
197
198 buildPArrayTyCon :: TyCon -> TyCon -> VM TyCon
199 buildPArrayTyCon orig_tc vect_tc = fixV $ \repr_tc ->
200   do
201     name'  <- cloneName mkPArrayTyConOcc orig_name
202     parent <- buildPArrayParentInfo orig_name vect_tc repr_tc
203     rhs    <- buildPArrayTyConRhs orig_name vect_tc repr_tc
204
205     return $ mkAlgTyCon name'
206                         kind
207                         tyvars
208                         []              -- no stupid theta
209                         rhs
210                         []              -- no selector ids
211                         parent
212                         rec_flag        -- FIXME: is this ok?
213                         False           -- FIXME: no generics
214                         False           -- not GADT syntax
215   where
216     orig_name = tyConName orig_tc
217     name   = tyConName vect_tc
218     kind   = tyConKind vect_tc
219     tyvars = tyConTyVars vect_tc
220     rec_flag = boolToRecFlag (isRecursiveTyCon vect_tc)
221     
222
223 buildPArrayParentInfo :: Name -> TyCon -> TyCon -> VM TyConParent
224 buildPArrayParentInfo orig_name vect_tc repr_tc
225   = do
226       parray_tc <- builtin parrayTyCon
227       co_name <- cloneName mkInstTyCoOcc (tyConName repr_tc)
228
229       let inst_tys = [mkTyConApp vect_tc (map mkTyVarTy tyvars)]
230
231       return . FamilyTyCon parray_tc inst_tys
232              $ mkFamInstCoercion co_name
233                                  tyvars
234                                  parray_tc
235                                  inst_tys
236                                  repr_tc
237   where
238     tyvars = tyConTyVars vect_tc
239
240 buildPArrayTyConRhs :: Name -> TyCon -> TyCon -> VM AlgTyConRhs
241 buildPArrayTyConRhs orig_name vect_tc repr_tc
242   = do
243       data_con <- buildPArrayDataCon orig_name vect_tc repr_tc
244       return $ DataTyCon { data_cons = [data_con], is_enum = False }
245
246 buildPArrayDataCon :: Name -> TyCon -> TyCon -> VM DataCon
247 buildPArrayDataCon orig_name vect_tc repr_tc
248   = do
249       dc_name  <- cloneName mkPArrayDataConOcc orig_name
250       shape    <- tyConShape vect_tc
251       repr_tys <- mapM mkPArrayType types
252       wrk_name <- cloneName mkDataConWorkerOcc  dc_name
253       wrp_name <- cloneName mkDataConWrapperOcc dc_name
254
255       let ids      = mkDataConIds wrp_name wrk_name data_con
256           data_con = mkDataCon dc_name
257                                False
258                                (shapeStrictness shape ++ map (const NotMarkedStrict) repr_tys)
259                                []
260                                (tyConTyVars vect_tc)
261                                []
262                                []
263                                []
264                                (shapeReprTys shape ++ repr_tys)
265                                repr_tc
266                                []
267                                ids
268
269       return data_con
270   where
271     types = [ty | dc <- tyConDataCons vect_tc
272                 , ty <- dataConRepArgTys dc]
273
274 mkPADFun :: TyCon -> VM Var
275 mkPADFun vect_tc
276   = newExportedVar (mkPADFunOcc $ getOccName vect_tc) =<< paDFunType vect_tc
277
278 data Shape = Shape {
279                shapeReprTys    :: [Type]
280              , shapeStrictness :: [StrictnessMark]
281              , shapeLength     :: [CoreExpr] -> VM CoreExpr
282              , shapeReplicate  :: CoreExpr -> CoreExpr -> VM [CoreExpr]
283              }
284
285 tyConShape :: TyCon -> VM Shape
286 tyConShape vect_tc
287   | isProductTyCon vect_tc
288   = return $ Shape {
289                 shapeReprTys    = [intPrimTy]
290               , shapeStrictness = [NotMarkedStrict]
291               , shapeLength     = \[len] -> return len
292               , shapeReplicate  = \len _ -> return [len]
293               }
294
295   | otherwise
296   = do
297       repr_ty <- mkPArrayType intTy   -- FIXME: we want to unbox this
298       return $ Shape {
299                  shapeReprTys    = [repr_ty]
300                , shapeStrictness = [MarkedStrict]
301                , shapeLength     = \[sel] -> lengthPA sel
302                , shapeReplicate  = \len n -> do
303                                                e <- replicatePA len n
304                                                return [e]
305                }
306
307 buildTyConBindings :: TyCon -> TyCon -> TyCon -> Var -> VM [(Var, CoreExpr)]
308 buildTyConBindings orig_tc vect_tc arr_tc dfun
309   = do
310       shape <- tyConShape vect_tc
311       sequence_ (zipWith4 (vectDataConWorker shape vect_tc arr_tc arr_dc)
312                           orig_dcs
313                           vect_dcs
314                           (inits repr_tys)
315                           (tails repr_tys))
316       dict <- buildPADict shape vect_tc arr_tc dfun
317       binds <- takeHoisted
318       return $ (dfun, dict) : binds
319   where
320     orig_dcs = tyConDataCons orig_tc
321     vect_dcs = tyConDataCons vect_tc
322     [arr_dc] = tyConDataCons arr_tc
323
324     repr_tys = map dataConRepArgTys vect_dcs
325
326 vectDataConWorker :: Shape -> TyCon -> TyCon -> DataCon
327                   -> DataCon -> DataCon -> [[Type]] -> [[Type]]
328                   -> VM ()
329 vectDataConWorker shape vect_tc arr_tc arr_dc orig_dc vect_dc pre (dc_tys : post)
330   = do
331       clo <- closedV
332            . inBind orig_worker
333            . polyAbstract tvs $ \abstract ->
334              liftM (abstract . vectorised)
335            $ buildClosures tvs [] dc_tys res_ty (liftM2 (,) mk_vect mk_lift)
336
337       worker <- cloneId mkVectOcc orig_worker (exprType clo)
338       hoistBinding worker clo
339       defGlobalVar orig_worker worker
340       return ()
341   where
342     tvs     = tyConTyVars vect_tc
343     arg_tys = mkTyVarTys tvs
344     res_ty  = mkTyConApp vect_tc arg_tys
345
346     orig_worker = dataConWorkId orig_dc
347
348     mk_vect = return . mkConApp vect_dc $ map Type arg_tys
349     mk_lift = do
350                 len     <- newLocalVar FSLIT("n") intPrimTy
351                 arr_tys <- mapM mkPArrayType dc_tys
352                 args    <- mapM (newLocalVar FSLIT("xs")) arr_tys
353                 shapes  <- shapeReplicate shape
354                                           (Var len)
355                                           (mkIntLitInt $ dataConTag vect_dc)
356                 
357                 empty_pre  <- mapM emptyPA (concat pre)
358                 empty_post <- mapM emptyPA (concat post)
359
360                 return . mkLams (len : args)
361                        . wrapFamInstBody arr_tc arg_tys
362                        . mkConApp arr_dc
363                        $ map Type arg_tys ++ shapes
364                                           ++ empty_pre
365                                           ++ map Var args
366                                           ++ empty_post
367
368 buildPADict :: Shape -> TyCon -> TyCon -> Var -> VM CoreExpr
369 buildPADict shape vect_tc arr_tc dfun
370   = polyAbstract tvs $ \abstract ->
371     do
372       meth_binds <- mapM (mk_method shape) paMethods
373       let meth_exprs = map (Var . fst) meth_binds
374
375       pa_dc <- builtin paDataCon
376       let dict = mkConApp pa_dc (Type (mkTyConApp vect_tc arg_tys) : meth_exprs)
377           body = Let (Rec meth_binds) dict
378       return . mkInlineMe $ abstract body
379   where
380     tvs = tyConTyVars arr_tc
381     arg_tys = mkTyVarTys tvs
382
383     mk_method shape (name, build)
384       = localV
385       $ do
386           body <- build shape vect_tc arr_tc
387           var  <- newLocalVar name (exprType body)
388           return (var, mkInlineMe body)
389           
390 paMethods = [(FSLIT("lengthPA"),    buildLengthPA),
391              (FSLIT("replicatePA"), buildReplicatePA)]
392
393 buildLengthPA :: Shape -> TyCon -> TyCon -> VM CoreExpr
394 buildLengthPA shape vect_tc arr_tc
395   = do
396       parr_ty <- mkPArrayType (mkTyConApp vect_tc arg_tys)
397       arg    <- newLocalVar FSLIT("xs") parr_ty
398       shapes <- mapM (newLocalVar FSLIT("sh")) shape_tys
399       wilds  <- mapM newDummyVar repr_tys
400       let scrut    = unwrapFamInstScrut arr_tc arg_tys (Var arg)
401           scrut_ty = exprType scrut
402
403       body <- shapeLength shape (map Var shapes)
404
405       return . Lam arg
406              $ Case scrut (mkWildId scrut_ty) intPrimTy
407                     [(DataAlt repr_dc, shapes ++ wilds, body)]
408   where
409     arg_tys = mkTyVarTys $ tyConTyVars arr_tc
410     [repr_dc] = tyConDataCons arr_tc
411     
412     shape_tys = shapeReprTys shape
413     repr_tys  = drop (length shape_tys) (dataConRepArgTys repr_dc)
414
415 -- data T = C0 t1 ... tm
416 --          ...
417 --          Ck u1 ... un
418 --
419 -- data [:T:] = A ![:Int:] [:t1:] ... [:un:]
420 --
421 -- replicatePA :: Int# -> T -> [:T:]
422 -- replicatePA n# t
423 --   = let c = case t of
424 --               C0 _ ... _ -> 0
425 --               ...
426 --               Ck _ ... _ -> k
427 --
428 --         xs1 = case t of
429 --                 C0 x1 _ ... _ -> replicatePA @t1 n# x1
430 --                 _             -> emptyPA @t1
431 --
432 --         ...
433 --
434 --         ysn = case t of
435 --                 Ck _ ... _ yn -> replicatePA @un n# yn
436 --                 _             -> emptyPA @un
437 --     in
438 --     A (replicatePA @Int n# c) xs1 ... ysn
439 --
440 --
441
442 buildReplicatePA :: Shape -> TyCon -> TyCon -> VM CoreExpr
443 buildReplicatePA shape vect_tc arr_tc
444   = do
445       len_var <- newLocalVar FSLIT("n") intPrimTy
446       val_var <- newLocalVar FSLIT("x") val_ty
447
448       let len = Var len_var
449           val = Var val_var
450
451       shape_reprs <- shapeReplicate shape len (ctr_num val)
452       reprs <- liftM concat $ mapM (mk_comp_arrs len val) vect_dcs
453
454       return . mkLams [len_var, val_var]
455              . wrapFamInstBody arr_tc arg_tys
456              $ mkConApp arr_dc (map Type arg_tys ++ shape_reprs ++ reprs)
457   where
458     arg_tys = mkTyVarTys (tyConTyVars arr_tc)
459     val_ty  = mkTyConApp vect_tc arg_tys
460     wild    = mkWildId val_ty
461     vect_dcs = tyConDataCons vect_tc
462     [arr_dc] = tyConDataCons arr_tc
463
464     ctr_num val = Case val wild intTy (zipWith ctr_num_alt vect_dcs [0..])
465     ctr_num_alt dc i = (DataAlt dc, map mkWildId (dataConRepArgTys dc),
466                                     mkConApp intDataCon [mkIntLitInt i])
467
468
469     mk_comp_arrs len val dc = let tys = dataConRepArgTys dc
470                                   wilds = map mkWildId tys
471                               in
472                               sequence (zipWith3 (mk_comp_arr len val dc)
473                                        tys (inits wilds) (tails wilds))
474
475     mk_comp_arr len val dc ty pre (_:post)
476       = do
477           var   <- newLocalVar FSLIT("x") ty
478           rep   <- replicatePA len (Var var)
479           empty <- emptyPA ty
480           arr_ty <- mkPArrayType ty
481
482           return $ Case val wild arr_ty
483                      [(DEFAULT, [], empty), (DataAlt dc, pre ++ (var : post), rep)]
484
485 -- | Split the given tycons into two sets depending on whether they have to be
486 -- converted (first list) or not (second list). The first argument contains
487 -- information about the conversion status of external tycons:
488 -- 
489 --   * tycons which have converted versions are mapped to True
490 --   * tycons which are not changed by vectorisation are mapped to False
491 --   * tycons which can't be converted are not elements of the map
492 --
493 classifyTyCons :: UniqFM Bool -> [TyConGroup] -> ([TyCon], [TyCon])
494 classifyTyCons = classify [] []
495   where
496     classify conv keep cs [] = (conv, keep)
497     classify conv keep cs ((tcs, ds) : rs)
498       | can_convert && must_convert
499         = classify (tcs ++ conv) keep (cs `addListToUFM` [(tc,True) | tc <- tcs]) rs
500       | can_convert
501         = classify conv (tcs ++ keep) (cs `addListToUFM` [(tc,False) | tc <- tcs]) rs
502       | otherwise
503         = classify conv keep cs rs
504       where
505         refs = ds `delListFromUniqSet` tcs
506
507         can_convert  = isNullUFM (refs `minusUFM` cs) && all convertable tcs
508         must_convert = foldUFM (||) False (intersectUFM_C const cs refs)
509
510         convertable tc = isDataTyCon tc && all isVanillaDataCon (tyConDataCons tc)
511     
512 -- | Compute mutually recursive groups of tycons in topological order
513 --
514 tyConGroups :: [TyCon] -> [TyConGroup]
515 tyConGroups tcs = map mk_grp (stronglyConnComp edges)
516   where
517     edges = [((tc, ds), tc, uniqSetToList ds) | tc <- tcs
518                                 , let ds = tyConsOfTyCon tc]
519
520     mk_grp (AcyclicSCC (tc, ds)) = ([tc], ds)
521     mk_grp (CyclicSCC els)       = (tcs, unionManyUniqSets dss)
522       where
523         (tcs, dss) = unzip els
524
525 tyConsOfTyCon :: TyCon -> UniqSet TyCon
526 tyConsOfTyCon 
527   = tyConsOfTypes . concatMap dataConRepArgTys . tyConDataCons
528
529 tyConsOfType :: Type -> UniqSet TyCon
530 tyConsOfType ty
531   | Just ty' <- coreView ty    = tyConsOfType ty'
532 tyConsOfType (TyVarTy v)       = emptyUniqSet
533 tyConsOfType (TyConApp tc tys) = extend (tyConsOfTypes tys)
534   where
535     extend | isUnLiftedTyCon tc
536            || isTupleTyCon   tc = id
537
538            | otherwise          = (`addOneToUniqSet` tc)
539
540 tyConsOfType (AppTy a b)       = tyConsOfType a `unionUniqSets` tyConsOfType b
541 tyConsOfType (FunTy a b)       = (tyConsOfType a `unionUniqSets` tyConsOfType b)
542                                  `addOneToUniqSet` funTyCon
543 tyConsOfType (ForAllTy _ ty)   = tyConsOfType ty
544 tyConsOfType other             = pprPanic "ClosureConv.tyConsOfType" $ ppr other
545
546 tyConsOfTypes :: [Type] -> UniqSet TyCon
547 tyConsOfTypes = unionManyUniqSets . map tyConsOfType
548