Rewrite generation of PA dictionaries
[ghc-hetmet.git] / compiler / vectorise / VectType.hs
1 module VectType ( vectTyCon, vectType, vectTypeEnv,
2                    PAInstance, buildPADict )
3 where
4
5 #include "HsVersions.h"
6
7 import VectMonad
8 import VectUtils
9 import VectCore
10
11 import HscTypes          ( TypeEnv, extendTypeEnvList, typeEnvTyCons )
12 import CoreSyn
13 import CoreUtils
14 import BuildTyCl
15 import DataCon
16 import TyCon
17 import Type
18 import TypeRep
19 import Coercion
20 import FamInstEnv        ( FamInst, mkLocalFamInst )
21 import InstEnv           ( Instance, mkLocalInstance, instanceDFunId )
22 import OccName
23 import MkId
24 import BasicTypes        ( StrictnessMark(..), OverlapFlag(..), boolToRecFlag )
25 import Var               ( Var )
26 import Id                ( mkWildId )
27 import Name              ( Name, getOccName )
28 import NameEnv
29 import TysWiredIn        ( unitTy, unitTyCon, intTy, intDataCon, unitDataConId )
30 import TysPrim           ( intPrimTy )
31
32 import Unique
33 import UniqFM
34 import UniqSet
35 import Digraph           ( SCC(..), stronglyConnComp )
36
37 import Outputable
38
39 import Control.Monad  ( liftM, liftM2, zipWithM, zipWithM_, mapAndUnzipM )
40 import Data.List      ( inits, tails, zipWith4, zipWith5 )
41
42 -- ----------------------------------------------------------------------------
43 -- Types
44
45 vectTyCon :: TyCon -> VM TyCon
46 vectTyCon tc
47   | isFunTyCon tc        = builtin closureTyCon
48   | isBoxedTupleTyCon tc = return tc
49   | isUnLiftedTyCon tc   = return tc
50   | otherwise = do
51                   r <- lookupTyCon tc
52                   case r of
53                     Just tc' -> return tc'
54
55                     -- FIXME: just for now
56                     Nothing  -> pprTrace "ccTyCon:" (ppr tc) $ return tc
57
58 vectType :: Type -> VM Type
59 vectType ty | Just ty' <- coreView ty = vectType ty'
60 vectType (TyVarTy tv) = return $ TyVarTy tv
61 vectType (AppTy ty1 ty2) = liftM2 AppTy (vectType ty1) (vectType ty2)
62 vectType (TyConApp tc tys) = liftM2 TyConApp (vectTyCon tc) (mapM vectType tys)
63 vectType (FunTy ty1 ty2)   = liftM2 TyConApp (builtin closureTyCon)
64                                              (mapM vectType [ty1,ty2])
65 vectType ty@(ForAllTy _ _)
66   = do
67       mdicts   <- mapM paDictArgType tyvars
68       mono_ty' <- vectType mono_ty
69       return $ tyvars `mkForAllTys` ([dict | Just dict <- mdicts] `mkFunTys` mono_ty')
70   where
71     (tyvars, mono_ty) = splitForAllTys ty
72
73 vectType ty = pprPanic "vectType:" (ppr ty)
74
75 -- ----------------------------------------------------------------------------
76 -- Type definitions
77
78 type TyConGroup = ([TyCon], UniqSet TyCon)
79
80 data PAInstance = PAInstance {
81                     painstDFun      :: Var
82                   , painstOrigTyCon :: TyCon
83                   , painstVectTyCon :: TyCon
84                   , painstArrTyCon  :: TyCon
85                   }
86
87 vectTypeEnv :: TypeEnv -> VM (TypeEnv, [FamInst], [(Var, CoreExpr)])
88 vectTypeEnv env
89   = do
90       cs <- readGEnv $ mk_map . global_tycons
91       let (conv_tcs, keep_tcs) = classifyTyCons cs groups
92           keep_dcs             = concatMap tyConDataCons keep_tcs
93       zipWithM_ defTyCon   keep_tcs keep_tcs
94       zipWithM_ defDataCon keep_dcs keep_dcs
95       new_tcs <- vectTyConDecls conv_tcs
96
97       let orig_tcs = keep_tcs ++ conv_tcs
98           vect_tcs  = keep_tcs ++ new_tcs
99
100       repr_tcs <- zipWithM buildPReprTyCon   orig_tcs vect_tcs
101       parr_tcs <- zipWithM buildPArrayTyCon orig_tcs vect_tcs
102       dfuns    <- mapM mkPADFun vect_tcs
103       defTyConPAs (zip vect_tcs dfuns)
104       binds    <- sequence (zipWith5 buildTyConBindings orig_tcs
105                                                         vect_tcs
106                                                         repr_tcs
107                                                         parr_tcs
108                                                         dfuns)
109
110       let all_new_tcs = new_tcs ++ repr_tcs ++ parr_tcs
111
112       let new_env = extendTypeEnvList env
113                        (map ATyCon all_new_tcs
114                         ++ [ADataCon dc | tc <- all_new_tcs
115                                         , dc <- tyConDataCons tc])
116
117       return (new_env, map mkLocalFamInst (repr_tcs ++ parr_tcs), concat binds)
118   where
119     tycons = typeEnvTyCons env
120     groups = tyConGroups tycons
121
122     mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
123
124     keep_tc tc = let dcs = tyConDataCons tc
125                  in
126                  defTyCon tc tc >> zipWithM_ defDataCon dcs dcs
127
128
129 vectTyConDecls :: [TyCon] -> VM [TyCon]
130 vectTyConDecls tcs = fixV $ \tcs' ->
131   do
132     mapM_ (uncurry defTyCon) (lazy_zip tcs tcs')
133     mapM vectTyConDecl tcs
134   where
135     lazy_zip [] _ = []
136     lazy_zip (x:xs) ~(y:ys) = (x,y) : lazy_zip xs ys
137
138 vectTyConDecl :: TyCon -> VM TyCon
139 vectTyConDecl tc
140   = do
141       name' <- cloneName mkVectTyConOcc name
142       rhs'  <- vectAlgTyConRhs (algTyConRhs tc)
143
144       liftDs $ buildAlgTyCon name'
145                              tyvars
146                              []           -- no stupid theta
147                              rhs'
148                              rec_flag     -- FIXME: is this ok?
149                              False        -- FIXME: no generics
150                              False        -- not GADT syntax
151                              Nothing      -- not a family instance
152   where
153     name   = tyConName tc
154     tyvars = tyConTyVars tc
155     rec_flag = boolToRecFlag (isRecursiveTyCon tc)
156
157 vectAlgTyConRhs :: AlgTyConRhs -> VM AlgTyConRhs
158 vectAlgTyConRhs (DataTyCon { data_cons = data_cons
159                            , is_enum   = is_enum
160                            })
161   = do
162       data_cons' <- mapM vectDataCon data_cons
163       zipWithM_ defDataCon data_cons data_cons'
164       return $ DataTyCon { data_cons = data_cons'
165                          , is_enum   = is_enum
166                          }
167
168 vectDataCon :: DataCon -> VM DataCon
169 vectDataCon dc
170   | not . null $ dataConExTyVars dc = pprPanic "vectDataCon: existentials" (ppr dc)
171   | not . null $ dataConEqSpec   dc = pprPanic "vectDataCon: eq spec" (ppr dc)
172   | otherwise
173   = do
174       name'    <- cloneName mkVectDataConOcc name
175       tycon'   <- vectTyCon tycon
176       arg_tys  <- mapM vectType rep_arg_tys
177
178       liftDs $ buildDataCon name'
179                             False           -- not infix
180                             (map (const NotMarkedStrict) arg_tys)
181                             []              -- no labelled fields
182                             univ_tvs
183                             []              -- no existential tvs for now
184                             []              -- no eq spec for now
185                             []              -- no context
186                             arg_tys
187                             tycon'
188   where
189     name        = dataConName dc
190     univ_tvs    = dataConUnivTyVars dc
191     rep_arg_tys = dataConRepArgTys dc
192     tycon       = dataConTyCon dc
193
194 mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
195 mk_fam_inst fam_tc arg_tc
196   = (fam_tc, [mkTyConApp arg_tc . mkTyVarTys $ tyConTyVars arg_tc])
197
198 buildPReprTyCon :: TyCon -> TyCon -> VM TyCon
199 buildPReprTyCon orig_tc vect_tc
200   = do
201       name     <- cloneName mkPReprTyConOcc (tyConName orig_tc)
202       rhs_ty   <- buildPReprType vect_tc
203       prepr_tc <- builtin preprTyCon
204       liftDs $ buildSynTyCon name
205                              tyvars
206                              (SynonymTyCon rhs_ty)
207                              (Just $ mk_fam_inst prepr_tc vect_tc)
208   where
209     tyvars = tyConTyVars vect_tc
210
211
212 data Repr = ProdRepr {
213               prod_components   :: [Type]
214             , prod_tycon        :: TyCon
215             , prod_data_con     :: DataCon
216             , prod_arr_tycon    :: TyCon
217             , prod_arr_data_con :: DataCon
218             }
219
220           | SumRepr {
221               sum_components    :: [Repr]
222             , sum_tycon         :: TyCon
223             , sum_arr_tycon     :: TyCon
224             , sum_arr_data_con  :: DataCon
225             }
226
227 mkProduct :: [Type] -> VM Repr
228 mkProduct tys
229   = do
230       tycon <- builtin (prodTyCon arity)
231       let [data_con] = tyConDataCons tycon
232
233       (arr_tycon, _) <- parrayReprTyCon $ mkTyConApp tycon tys
234       let [arr_data_con] = tyConDataCons arr_tycon
235
236       return $ ProdRepr {
237                  prod_components   = tys
238                , prod_tycon        = tycon
239                , prod_data_con     = data_con
240                , prod_arr_tycon    = arr_tycon
241                , prod_arr_data_con = arr_data_con
242                }
243   where
244     arity = length tys
245
246 mkSum :: [Repr] -> VM Repr
247 mkSum [repr] = return repr
248 mkSum reprs
249   = do
250       tycon <- builtin (sumTyCon arity)
251       (arr_tycon, _) <- parrayReprTyCon
252                       . mkTyConApp tycon
253                       $ map reprType reprs
254
255       let [arr_data_con] = tyConDataCons arr_tycon
256
257       return $ SumRepr {
258                  sum_components   = reprs
259                , sum_tycon        = tycon
260                , sum_arr_tycon    = arr_tycon
261                , sum_arr_data_con = arr_data_con
262                }
263   where
264     arity = length reprs
265
266 reprProducts :: Repr -> [Repr]
267 reprProducts (SumRepr { sum_components = rs }) = rs
268 reprProducts repr                              = [repr]
269
270 reprType :: Repr -> Type
271 reprType (ProdRepr { prod_tycon = tycon, prod_components = tys })
272   = mkTyConApp tycon tys
273 reprType (SumRepr { sum_tycon = tycon, sum_components = reprs })
274   = mkTyConApp tycon (map reprType reprs)
275
276 arrReprType :: Repr -> VM Type
277 arrReprType = mkPArrayType . reprType
278
279 reprTys :: Repr -> [[Type]]
280 reprTys (SumRepr { sum_components = prods }) = map prodTys prods
281 reprTys prod                                 = [prodTys prod]
282
283 prodTys (ProdRepr { prod_components = tys }) = tys
284
285 reprVars :: Repr -> VM [[Var]]
286 reprVars = mapM (mapM (newLocalVar FSLIT("r"))) . reprTys
287
288 arrShapeTys :: Repr -> VM [Type]
289 arrShapeTys (SumRepr  {})
290   = do
291       uarr <- builtin uarrTyCon
292       return [intPrimTy, mkTyConApp uarr [intTy]]
293 arrShapeTys repr = return [intPrimTy]
294
295 arrShapeVars :: Repr -> VM [Var]
296 arrShapeVars repr = mapM (newLocalVar FSLIT("sh")) =<< arrShapeTys repr
297
298 arrReprTys :: Repr -> VM [[Type]]
299 arrReprTys (SumRepr { sum_components = prods })
300   = mapM arrProdTys prods
301 arrReprTys prod
302   = do
303       tys <- arrProdTys prod
304       return [tys]
305
306 arrProdTys (ProdRepr { prod_components = tys })
307   = mapM mkPArrayType (mk_types tys)
308   where
309     mk_types []  = [unitTy]
310     mk_types tys = tys
311
312 arrReprVars :: Repr -> VM [[Var]]
313 arrReprVars repr
314   = mapM (mapM (newLocalVar FSLIT("rs"))) =<< arrReprTys repr
315
316 mkRepr :: TyCon -> VM Repr
317 mkRepr vect_tc
318   = mkSum
319   =<< mapM mkProduct (map dataConRepArgTys $ tyConDataCons vect_tc)
320
321 buildPReprType :: TyCon -> VM Type
322 buildPReprType = liftM reprType . mkRepr
323
324 buildToPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
325 buildToPRepr repr vect_tc prepr_tc _
326   = do
327       arg    <- newLocalVar FSLIT("x") arg_ty
328       result <- to_repr repr (Var arg)
329
330       return . Lam arg
331              . wrapFamInstBody prepr_tc var_tys
332              $ result
333   where
334     var_tys = mkTyVarTys $ tyConTyVars vect_tc
335     arg_ty  = mkTyConApp vect_tc var_tys
336     res_ty  = reprType repr
337
338     cons    = tyConDataCons vect_tc
339     [con]   = cons
340
341     to_repr (SumRepr { sum_components = prods
342                      , sum_tycon      = tycon })
343             expr
344       = do
345           (vars, bodies) <- mapAndUnzipM prod_alt prods
346           return . Case expr (mkWildId (exprType expr)) res_ty
347                  $ zipWith4 mk_alt cons vars (tyConDataCons tycon) bodies
348       where
349         mk_alt con vars sum_con body
350           = (DataAlt con, vars, mkConApp sum_con (ty_args ++ [body]))
351
352         ty_args = map (Type . reprType) prods
353
354     to_repr prod expr
355       = do
356           (vars, body) <- prod_alt prod
357           return $ Case expr (mkWildId (exprType expr)) res_ty
358                    [(DataAlt con, vars, body)]
359
360     prod_alt (ProdRepr { prod_components = tys
361                        , prod_data_con   = data_con })
362       = do
363           vars <- mapM (newLocalVar FSLIT("r")) tys
364           return (vars, mkConApp data_con (map Type tys ++ map Var vars))
365
366 buildFromPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
367 buildFromPRepr repr vect_tc prepr_tc _
368   = do
369       arg_ty <- mkPReprType res_ty
370       arg    <- newLocalVar FSLIT("x") arg_ty
371
372       liftM (Lam arg)
373            . from_repr repr
374            $ unwrapFamInstScrut prepr_tc var_tys (Var arg)
375   where
376     var_tys = mkTyVarTys $ tyConTyVars vect_tc
377     res_ty  = mkTyConApp vect_tc var_tys
378
379     cons    = map (`mkConApp` map Type var_tys) (tyConDataCons vect_tc)
380     [con]   = cons
381
382     from_repr repr@(SumRepr { sum_components = prods
383                             , sum_tycon      = tycon })
384               expr
385       = do
386           vars   <- mapM (newLocalVar FSLIT("x")) (map reprType prods)
387           bodies <- sequence . zipWith3 from_prod prods cons
388                              $ map Var vars
389           return . Case expr (mkWildId (reprType repr)) res_ty
390                  $ zipWith3 sum_alt (tyConDataCons tycon) vars bodies
391       where
392         sum_alt data_con var body = (DataAlt data_con, [var], body)
393
394     from_repr repr expr = from_prod repr con expr
395
396     from_prod prod@(ProdRepr { prod_components = tys
397                              , prod_data_con   = data_con })
398               con
399               expr
400       = do
401           vars <- mapM (newLocalVar FSLIT("y")) tys
402           return $ Case expr (mkWildId (reprType prod)) res_ty
403                    [(DataAlt data_con, vars, con `mkVarApps` vars)]
404
405 buildToArrPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
406 buildToArrPRepr repr vect_tc prepr_tc arr_tc
407   = do
408       arg_ty     <- mkPArrayType el_ty
409       arg        <- newLocalVar FSLIT("xs") arg_ty
410
411       res_ty     <- mkPArrayType (reprType repr)
412
413       shape_vars <- arrShapeVars repr
414       repr_vars  <- arrReprVars  repr
415
416       parray_co  <- mkBuiltinCo parrayTyCon
417
418       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
419           co           = mkAppCoercion parray_co
420                        . mkSymCoercion
421                        $ mkTyConApp repr_co var_tys
422
423           scrut   = unwrapFamInstScrut arr_tc var_tys (Var arg)
424
425       result <- to_repr shape_vars repr_vars repr
426
427       return . Lam arg
428              . mkCoerce co
429              $ Case scrut (mkWildId (mkTyConApp arr_tc var_tys)) res_ty
430                [(DataAlt arr_dc, shape_vars ++ concat repr_vars, result)]
431   where
432     var_tys = mkTyVarTys $ tyConTyVars vect_tc
433     el_ty   = mkTyConApp vect_tc var_tys
434
435     [arr_dc] = tyConDataCons arr_tc
436
437     to_repr shape_vars@(len_var : _)
438             repr_vars
439             (SumRepr { sum_components   = prods
440                      , sum_arr_tycon    = tycon
441                      , sum_arr_data_con = data_con })
442       = do
443           exprs <- zipWithM (to_prod len_var) repr_vars prods
444
445           return . wrapFamInstBody tycon tys
446                  . mkConApp data_con
447                  $ map Type tys ++ map Var shape_vars ++ exprs
448       where
449         tys = map reprType prods
450
451     to_repr [len_var] [repr_vars] prod = to_prod len_var repr_vars prod
452
453     to_prod len_var
454             repr_vars
455             (ProdRepr { prod_components   = tys
456                       , prod_arr_tycon    = tycon
457                       , prod_arr_data_con = data_con })
458       = return . wrapFamInstBody tycon tys
459                . mkConApp data_con
460                $ map Type tys ++ map Var (len_var : repr_vars)
461
462 buildFromArrPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
463 buildFromArrPRepr repr vect_tc prepr_tc arr_tc
464   = do
465       arg_ty     <- mkPArrayType =<< mkPReprType el_ty
466       arg        <- newLocalVar FSLIT("xs") arg_ty
467
468       res_ty     <- mkPArrayType el_ty
469
470       shape_vars <- arrShapeVars repr
471       repr_vars  <- arrReprVars  repr
472
473       parray_co  <- mkBuiltinCo parrayTyCon
474
475       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
476           co           = mkAppCoercion parray_co
477                        $ mkTyConApp repr_co var_tys
478
479           scrut  = mkCoerce co (Var arg)
480
481           result = wrapFamInstBody arr_tc var_tys
482                  . mkConApp arr_dc
483                  $ map Type var_tys ++ map Var (shape_vars ++ concat repr_vars)
484
485       liftM (Lam arg)
486             (from_repr repr scrut shape_vars repr_vars res_ty result)
487   where
488     var_tys = mkTyVarTys $ tyConTyVars vect_tc
489     el_ty   = mkTyConApp vect_tc var_tys
490
491     [arr_dc] = tyConDataCons arr_tc
492
493     from_repr (SumRepr { sum_components   = prods
494                        , sum_arr_tycon    = tycon
495                        , sum_arr_data_con = data_con })
496               expr
497               shape_vars
498               repr_vars
499               res_ty
500               body
501       = do
502           vars <- mapM (newLocalVar FSLIT("xs")) =<< mapM arrReprType prods
503           result <- go prods repr_vars vars body
504
505           let scrut = unwrapFamInstScrut tycon ty_args expr
506           return . Case scrut (mkWildId scrut_ty) res_ty
507                  $ [(DataAlt data_con, shape_vars ++ vars, result)]
508       where
509         ty_args  = map reprType prods
510         scrut_ty = mkTyConApp tycon ty_args
511
512         go [] [] [] body = return body
513         go (prod : prods) (repr_vars : rss) (var : vars) body
514           = do
515               shape_vars <- mapM (newLocalVar FSLIT("s")) =<< arrShapeTys prod
516
517               from_prod prod (Var var) shape_vars repr_vars res_ty
518                 =<< go prods rss vars body
519
520     from_repr repr expr shape_vars [repr_vars] res_ty body
521       = from_prod repr expr shape_vars repr_vars res_ty body
522
523     from_prod prod@(ProdRepr { prod_components = tys
524                              , prod_arr_tycon  = tycon
525                              , prod_arr_data_con = data_con })
526               expr
527               shape_vars
528               repr_vars
529               res_ty
530               body
531       = do
532           let scrut    = unwrapFamInstScrut tycon tys expr
533               scrut_ty = mkTyConApp tycon tys
534           ty <- arrReprType prod
535
536           return $ Case scrut (mkWildId scrut_ty) res_ty
537                    [(DataAlt data_con, shape_vars ++ repr_vars, body)]
538
539 buildPRDictRepr :: Repr -> VM CoreExpr
540 buildPRDictRepr (ProdRepr {
541                    prod_components = tys
542                  , prod_tycon      = tycon
543                  })
544   = do
545       prs  <- mapM mkPR tys
546       dfun <- prDFunOfTyCon tycon
547       return $ dfun `mkTyApps` tys `mkApps` prs
548
549 buildPRDictRepr (SumRepr {
550                    sum_components = prods
551                  , sum_tycon      = tycon })
552   = do
553       prs  <- mapM buildPRDictRepr prods
554       dfun <- prDFunOfTyCon tycon
555       return $ dfun `mkTyApps` map reprType prods `mkApps` prs
556
557 buildPRDict :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
558 buildPRDict repr vect_tc prepr_tc _
559   = do
560       dict  <- buildPRDictRepr repr
561
562       pr_co <- mkBuiltinCo prTyCon
563       let co = mkAppCoercion pr_co
564              . mkSymCoercion
565              $ mkTyConApp arg_co var_tys
566
567       return $ mkCoerce co dict
568   where
569     var_tys = mkTyVarTys $ tyConTyVars vect_tc
570
571     Just arg_co = tyConFamilyCoercion_maybe prepr_tc
572
573 buildPArrayTyCon :: TyCon -> TyCon -> VM TyCon
574 buildPArrayTyCon orig_tc vect_tc = fixV $ \repr_tc ->
575   do
576     name'  <- cloneName mkPArrayTyConOcc orig_name
577     rhs    <- buildPArrayTyConRhs orig_name vect_tc repr_tc
578     parray <- builtin parrayTyCon
579
580     liftDs $ buildAlgTyCon name'
581                            tyvars
582                            []          -- no stupid theta
583                            rhs
584                            rec_flag    -- FIXME: is this ok?
585                            False       -- FIXME: no generics
586                            False       -- not GADT syntax
587                            (Just $ mk_fam_inst parray vect_tc)
588   where
589     orig_name = tyConName orig_tc
590     tyvars = tyConTyVars vect_tc
591     rec_flag = boolToRecFlag (isRecursiveTyCon vect_tc)
592
593
594 buildPArrayTyConRhs :: Name -> TyCon -> TyCon -> VM AlgTyConRhs
595 buildPArrayTyConRhs orig_name vect_tc repr_tc
596   = do
597       data_con <- buildPArrayDataCon orig_name vect_tc repr_tc
598       return $ DataTyCon { data_cons = [data_con], is_enum = False }
599
600 buildPArrayDataCon :: Name -> TyCon -> TyCon -> VM DataCon
601 buildPArrayDataCon orig_name vect_tc repr_tc
602   = do
603       dc_name  <- cloneName mkPArrayDataConOcc orig_name
604       repr     <- mkRepr vect_tc
605
606       shape_tys <- arrShapeTys repr
607       repr_tys  <- arrReprTys  repr
608
609       let tys = shape_tys ++ concat repr_tys
610
611       liftDs $ buildDataCon dc_name
612                             False                  -- not infix
613                             (map (const NotMarkedStrict) tys)
614                             []                     -- no field labels
615                             (tyConTyVars vect_tc)
616                             []                     -- no existentials
617                             []                     -- no eq spec
618                             []                     -- no context
619                             tys
620                             repr_tc
621
622 mkPADFun :: TyCon -> VM Var
623 mkPADFun vect_tc
624   = newExportedVar (mkPADFunOcc $ getOccName vect_tc) =<< paDFunType vect_tc
625
626 data Shape = Shape {
627                shapeReprTys    :: [Type]
628              , shapeStrictness :: [StrictnessMark]
629              , shapeLength     :: [CoreExpr] -> VM CoreExpr
630              , shapeReplicate  :: CoreExpr -> CoreExpr -> VM [CoreExpr]
631              }
632
633 tyConShape :: TyCon -> VM Shape
634 tyConShape vect_tc
635   | isProductTyCon vect_tc
636   = return $ Shape {
637                 shapeReprTys    = [intPrimTy]
638               , shapeStrictness = [NotMarkedStrict]
639               , shapeLength     = \[len] -> return len
640               , shapeReplicate  = \len _ -> return [len]
641               }
642
643   | otherwise
644   = do
645       repr_ty <- mkPArrayType intTy   -- FIXME: we want to unbox this
646       return $ Shape {
647                  shapeReprTys    = [repr_ty]
648                , shapeStrictness = [MarkedStrict]
649                , shapeLength     = \[sel] -> lengthPA sel
650                , shapeReplicate  = \len n -> do
651                                                e <- replicatePA len n
652                                                return [e]
653                }
654
655 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> Var
656                    -> VM [(Var, CoreExpr)]
657 buildTyConBindings orig_tc vect_tc prepr_tc arr_tc dfun
658   = do
659       shape <- tyConShape vect_tc
660       repr  <- mkRepr vect_tc
661       sequence_ (zipWith4 (vectDataConWorker shape vect_tc arr_tc arr_dc)
662                           orig_dcs
663                           vect_dcs
664                           (inits repr_tys)
665                           (tails repr_tys))
666       dict <- buildPADict repr vect_tc prepr_tc arr_tc dfun
667       binds <- takeHoisted
668       return $ (dfun, dict) : binds
669   where
670     orig_dcs = tyConDataCons orig_tc
671     vect_dcs = tyConDataCons vect_tc
672     [arr_dc] = tyConDataCons arr_tc
673
674     repr_tys = map dataConRepArgTys vect_dcs
675
676 vectDataConWorker :: Shape -> TyCon -> TyCon -> DataCon
677                   -> DataCon -> DataCon -> [[Type]] -> [[Type]]
678                   -> VM ()
679 vectDataConWorker shape vect_tc arr_tc arr_dc orig_dc vect_dc pre (dc_tys : post)
680   = do
681       clo <- closedV
682            . inBind orig_worker
683            . polyAbstract tvs $ \abstract ->
684              liftM (abstract . vectorised)
685            $ buildClosures tvs [] dc_tys res_ty (liftM2 (,) mk_vect mk_lift)
686
687       worker <- cloneId mkVectOcc orig_worker (exprType clo)
688       hoistBinding worker clo
689       defGlobalVar orig_worker worker
690       return ()
691   where
692     tvs     = tyConTyVars vect_tc
693     arg_tys = mkTyVarTys tvs
694     res_ty  = mkTyConApp vect_tc arg_tys
695
696     orig_worker = dataConWorkId orig_dc
697
698     mk_vect = return . mkConApp vect_dc $ map Type arg_tys
699     mk_lift = do
700                 len     <- newLocalVar FSLIT("n") intPrimTy
701                 arr_tys <- mapM mkPArrayType dc_tys
702                 args    <- mapM (newLocalVar FSLIT("xs")) arr_tys
703                 shapes  <- shapeReplicate shape
704                                           (Var len)
705                                           (mkDataConTag vect_dc)
706
707                 empty_pre  <- mapM emptyPA (concat pre)
708                 empty_post <- mapM emptyPA (concat post)
709
710                 return . mkLams (len : args)
711                        . wrapFamInstBody arr_tc arg_tys
712                        . mkConApp arr_dc
713                        $ map Type arg_tys ++ shapes
714                                           ++ empty_pre
715                                           ++ map Var args
716                                           ++ empty_post
717
718 buildPADict :: Repr -> TyCon -> TyCon -> TyCon -> Var -> VM CoreExpr
719 buildPADict repr vect_tc prepr_tc arr_tc dfun
720   = polyAbstract tvs $ \abstract ->
721     do
722       meth_binds <- mapM (mk_method repr) paMethods
723       let meth_exprs = map (Var . fst) meth_binds
724
725       pa_dc <- builtin paDataCon
726       let dict = mkConApp pa_dc (Type (mkTyConApp vect_tc arg_tys) : meth_exprs)
727           body = Let (Rec meth_binds) dict
728       return . mkInlineMe $ abstract body
729   where
730     tvs = tyConTyVars arr_tc
731     arg_tys = mkTyVarTys tvs
732
733     mk_method repr (name, build)
734       = localV
735       $ do
736           body <- build repr vect_tc prepr_tc arr_tc
737           var  <- newLocalVar name (exprType body)
738           return (var, mkInlineMe body)
739
740 paMethods = [(FSLIT("toPRepr"),      buildToPRepr),
741              (FSLIT("fromPRepr"),    buildFromPRepr),
742              (FSLIT("toArrPRepr"),   buildToArrPRepr),
743              (FSLIT("fromArrPRepr"), buildFromArrPRepr),
744              (FSLIT("dictPRepr"),    buildPRDict)]
745
746 -- | Split the given tycons into two sets depending on whether they have to be
747 -- converted (first list) or not (second list). The first argument contains
748 -- information about the conversion status of external tycons:
749 --
750 --   * tycons which have converted versions are mapped to True
751 --   * tycons which are not changed by vectorisation are mapped to False
752 --   * tycons which can't be converted are not elements of the map
753 --
754 classifyTyCons :: UniqFM Bool -> [TyConGroup] -> ([TyCon], [TyCon])
755 classifyTyCons = classify [] []
756   where
757     classify conv keep cs [] = (conv, keep)
758     classify conv keep cs ((tcs, ds) : rs)
759       | can_convert && must_convert
760         = classify (tcs ++ conv) keep (cs `addListToUFM` [(tc,True) | tc <- tcs]) rs
761       | can_convert
762         = classify conv (tcs ++ keep) (cs `addListToUFM` [(tc,False) | tc <- tcs]) rs
763       | otherwise
764         = classify conv keep cs rs
765       where
766         refs = ds `delListFromUniqSet` tcs
767
768         can_convert  = isNullUFM (refs `minusUFM` cs) && all convertable tcs
769         must_convert = foldUFM (||) False (intersectUFM_C const cs refs)
770
771         convertable tc = isDataTyCon tc && all isVanillaDataCon (tyConDataCons tc)
772
773 -- | Compute mutually recursive groups of tycons in topological order
774 --
775 tyConGroups :: [TyCon] -> [TyConGroup]
776 tyConGroups tcs = map mk_grp (stronglyConnComp edges)
777   where
778     edges = [((tc, ds), tc, uniqSetToList ds) | tc <- tcs
779                                 , let ds = tyConsOfTyCon tc]
780
781     mk_grp (AcyclicSCC (tc, ds)) = ([tc], ds)
782     mk_grp (CyclicSCC els)       = (tcs, unionManyUniqSets dss)
783       where
784         (tcs, dss) = unzip els
785
786 tyConsOfTyCon :: TyCon -> UniqSet TyCon
787 tyConsOfTyCon
788   = tyConsOfTypes . concatMap dataConRepArgTys . tyConDataCons
789
790 tyConsOfType :: Type -> UniqSet TyCon
791 tyConsOfType ty
792   | Just ty' <- coreView ty    = tyConsOfType ty'
793 tyConsOfType (TyVarTy v)       = emptyUniqSet
794 tyConsOfType (TyConApp tc tys) = extend (tyConsOfTypes tys)
795   where
796     extend | isUnLiftedTyCon tc
797            || isTupleTyCon   tc = id
798
799            | otherwise          = (`addOneToUniqSet` tc)
800
801 tyConsOfType (AppTy a b)       = tyConsOfType a `unionUniqSets` tyConsOfType b
802 tyConsOfType (FunTy a b)       = (tyConsOfType a `unionUniqSets` tyConsOfType b)
803                                  `addOneToUniqSet` funTyCon
804 tyConsOfType (ForAllTy _ ty)   = tyConsOfType ty
805 tyConsOfType other             = pprPanic "ClosureConv.tyConsOfType" $ ppr other
806
807 tyConsOfTypes :: [Type] -> UniqSet TyCon
808 tyConsOfTypes = unionManyUniqSets . map tyConsOfType
809