Refactoring
[ghc-hetmet.git] / compiler / vectorise / VectType.hs
1 module VectType ( vectTyCon, vectType, vectTypeEnv,
2                    PAInstance, buildPADict )
3 where
4
5 #include "HsVersions.h"
6
7 import VectMonad
8 import VectUtils
9 import VectCore
10
11 import HscTypes          ( TypeEnv, extendTypeEnvList, typeEnvTyCons )
12 import CoreSyn
13 import CoreUtils
14 import BuildTyCl
15 import DataCon
16 import TyCon
17 import Type
18 import TypeRep
19 import Coercion
20 import FamInstEnv        ( FamInst, mkLocalFamInst )
21 import InstEnv           ( Instance, mkLocalInstance, instanceDFunId )
22 import OccName
23 import MkId
24 import BasicTypes        ( StrictnessMark(..), OverlapFlag(..), boolToRecFlag )
25 import Var               ( Var )
26 import Id                ( mkWildId )
27 import Name              ( Name, getOccName )
28 import NameEnv
29 import TysWiredIn        ( unitTy, unitTyCon, intTy, intDataCon, unitDataConId )
30 import TysPrim           ( intPrimTy )
31
32 import Unique
33 import UniqFM
34 import UniqSet
35 import Digraph           ( SCC(..), stronglyConnComp )
36
37 import Outputable
38
39 import Control.Monad  ( liftM, liftM2, zipWithM, zipWithM_, mapAndUnzipM )
40 import Data.List      ( inits, tails, zipWith4, zipWith5 )
41
42 -- ----------------------------------------------------------------------------
43 -- Types
44
45 vectTyCon :: TyCon -> VM TyCon
46 vectTyCon tc
47   | isFunTyCon tc        = builtin closureTyCon
48   | isBoxedTupleTyCon tc = return tc
49   | isUnLiftedTyCon tc   = return tc
50   | otherwise = do
51                   r <- lookupTyCon tc
52                   case r of
53                     Just tc' -> return tc'
54
55                     -- FIXME: just for now
56                     Nothing  -> pprTrace "ccTyCon:" (ppr tc) $ return tc
57
58 vectType :: Type -> VM Type
59 vectType ty | Just ty' <- coreView ty = vectType ty'
60 vectType (TyVarTy tv) = return $ TyVarTy tv
61 vectType (AppTy ty1 ty2) = liftM2 AppTy (vectType ty1) (vectType ty2)
62 vectType (TyConApp tc tys) = liftM2 TyConApp (vectTyCon tc) (mapM vectType tys)
63 vectType (FunTy ty1 ty2)   = liftM2 TyConApp (builtin closureTyCon)
64                                              (mapM vectType [ty1,ty2])
65 vectType ty@(ForAllTy _ _)
66   = do
67       mdicts   <- mapM paDictArgType tyvars
68       mono_ty' <- vectType mono_ty
69       return $ tyvars `mkForAllTys` ([dict | Just dict <- mdicts] `mkFunTys` mono_ty')
70   where
71     (tyvars, mono_ty) = splitForAllTys ty
72
73 vectType ty = pprPanic "vectType:" (ppr ty)
74
75 -- ----------------------------------------------------------------------------
76 -- Type definitions
77
78 type TyConGroup = ([TyCon], UniqSet TyCon)
79
80 data PAInstance = PAInstance {
81                     painstDFun      :: Var
82                   , painstOrigTyCon :: TyCon
83                   , painstVectTyCon :: TyCon
84                   , painstArrTyCon  :: TyCon
85                   }
86
87 vectTypeEnv :: TypeEnv -> VM (TypeEnv, [FamInst], [(Var, CoreExpr)])
88 vectTypeEnv env
89   = do
90       cs <- readGEnv $ mk_map . global_tycons
91       let (conv_tcs, keep_tcs) = classifyTyCons cs groups
92           keep_dcs             = concatMap tyConDataCons keep_tcs
93       zipWithM_ defTyCon   keep_tcs keep_tcs
94       zipWithM_ defDataCon keep_dcs keep_dcs
95       new_tcs <- vectTyConDecls conv_tcs
96
97       let orig_tcs = keep_tcs ++ conv_tcs
98           vect_tcs  = keep_tcs ++ new_tcs
99
100       repr_tcs <- zipWithM buildPReprTyCon   orig_tcs vect_tcs
101       parr_tcs <- zipWithM buildPArrayTyCon orig_tcs vect_tcs
102       dfuns    <- mapM mkPADFun vect_tcs
103       defTyConPAs (zip vect_tcs dfuns)
104       binds    <- sequence (zipWith5 buildTyConBindings orig_tcs
105                                                         vect_tcs
106                                                         repr_tcs
107                                                         parr_tcs
108                                                         dfuns)
109
110       let all_new_tcs = new_tcs ++ repr_tcs ++ parr_tcs
111
112       let new_env = extendTypeEnvList env
113                        (map ATyCon all_new_tcs
114                         ++ [ADataCon dc | tc <- all_new_tcs
115                                         , dc <- tyConDataCons tc])
116
117       return (new_env, map mkLocalFamInst (repr_tcs ++ parr_tcs), concat binds)
118   where
119     tycons = typeEnvTyCons env
120     groups = tyConGroups tycons
121
122     mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
123
124     keep_tc tc = let dcs = tyConDataCons tc
125                  in
126                  defTyCon tc tc >> zipWithM_ defDataCon dcs dcs
127
128
129 vectTyConDecls :: [TyCon] -> VM [TyCon]
130 vectTyConDecls tcs = fixV $ \tcs' ->
131   do
132     mapM_ (uncurry defTyCon) (lazy_zip tcs tcs')
133     mapM vectTyConDecl tcs
134   where
135     lazy_zip [] _ = []
136     lazy_zip (x:xs) ~(y:ys) = (x,y) : lazy_zip xs ys
137
138 vectTyConDecl :: TyCon -> VM TyCon
139 vectTyConDecl tc
140   = do
141       name' <- cloneName mkVectTyConOcc name
142       rhs'  <- vectAlgTyConRhs (algTyConRhs tc)
143
144       liftDs $ buildAlgTyCon name'
145                              tyvars
146                              []           -- no stupid theta
147                              rhs'
148                              rec_flag     -- FIXME: is this ok?
149                              False        -- FIXME: no generics
150                              False        -- not GADT syntax
151                              Nothing      -- not a family instance
152   where
153     name   = tyConName tc
154     tyvars = tyConTyVars tc
155     rec_flag = boolToRecFlag (isRecursiveTyCon tc)
156
157 vectAlgTyConRhs :: AlgTyConRhs -> VM AlgTyConRhs
158 vectAlgTyConRhs (DataTyCon { data_cons = data_cons
159                            , is_enum   = is_enum
160                            })
161   = do
162       data_cons' <- mapM vectDataCon data_cons
163       zipWithM_ defDataCon data_cons data_cons'
164       return $ DataTyCon { data_cons = data_cons'
165                          , is_enum   = is_enum
166                          }
167
168 vectDataCon :: DataCon -> VM DataCon
169 vectDataCon dc
170   | not . null $ dataConExTyVars dc = pprPanic "vectDataCon: existentials" (ppr dc)
171   | not . null $ dataConEqSpec   dc = pprPanic "vectDataCon: eq spec" (ppr dc)
172   | otherwise
173   = do
174       name'    <- cloneName mkVectDataConOcc name
175       tycon'   <- vectTyCon tycon
176       arg_tys  <- mapM vectType rep_arg_tys
177
178       liftDs $ buildDataCon name'
179                             False           -- not infix
180                             (map (const NotMarkedStrict) arg_tys)
181                             []              -- no labelled fields
182                             univ_tvs
183                             []              -- no existential tvs for now
184                             []              -- no eq spec for now
185                             []              -- no context
186                             arg_tys
187                             tycon'
188   where
189     name        = dataConName dc
190     univ_tvs    = dataConUnivTyVars dc
191     rep_arg_tys = dataConRepArgTys dc
192     tycon       = dataConTyCon dc
193
194 mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
195 mk_fam_inst fam_tc arg_tc
196   = (fam_tc, [mkTyConApp arg_tc . mkTyVarTys $ tyConTyVars arg_tc])
197
198 buildPReprTyCon :: TyCon -> TyCon -> VM TyCon
199 buildPReprTyCon orig_tc vect_tc
200   = do
201       name     <- cloneName mkPReprTyConOcc (tyConName orig_tc)
202       rhs_ty   <- buildPReprType vect_tc
203       prepr_tc <- builtin preprTyCon
204       liftDs $ buildSynTyCon name
205                              tyvars
206                              (SynonymTyCon rhs_ty)
207                              (Just $ mk_fam_inst prepr_tc vect_tc)
208   where
209     tyvars = tyConTyVars vect_tc
210
211
212 data Repr = ProdRepr {
213               prod_components   :: [Type]
214             , prod_tycon        :: TyCon
215             , prod_data_con     :: DataCon
216             , prod_arr_tycon    :: TyCon
217             , prod_arr_data_con :: DataCon
218             }
219
220           | SumRepr {
221               sum_components    :: [Repr]
222             , sum_tycon         :: TyCon
223             , sum_arr_tycon     :: TyCon
224             , sum_arr_data_con  :: DataCon
225             }
226
227           | IdRepr Type
228
229           | VoidRepr {
230               void_tycon        :: TyCon
231             , void_bottom       :: CoreExpr
232             }
233
234 voidRepr :: VM Repr
235 voidRepr
236   = do
237       tycon <- builtin voidTyCon
238       var   <- builtin voidVar
239       return $ VoidRepr {
240                  void_tycon  = tycon
241                , void_bottom = Var var
242                }
243
244 unboxedProductRepr :: [Type] -> VM Repr
245 unboxedProductRepr []   = voidRepr
246 unboxedProductRepr [ty] = return $ IdRepr ty
247 unboxedProductRepr tys  = boxedProductRepr tys
248
249 boxedProductRepr :: [Type] -> VM Repr
250 boxedProductRepr tys
251   = do
252       tycon <- builtin (prodTyCon arity)
253       let [data_con] = tyConDataCons tycon
254
255       (arr_tycon, _) <- parrayReprTyCon $ mkTyConApp tycon tys
256       let [arr_data_con] = tyConDataCons arr_tycon
257
258       return $ ProdRepr {
259                  prod_components   = tys
260                , prod_tycon        = tycon
261                , prod_data_con     = data_con
262                , prod_arr_tycon    = arr_tycon
263                , prod_arr_data_con = arr_data_con
264                }
265   where
266     arity = length tys
267
268 sumRepr :: [Repr] -> VM Repr
269 sumRepr []     = voidRepr
270 sumRepr [repr] = boxRepr repr
271 sumRepr reprs
272   = do
273       tycon <- builtin (sumTyCon arity)
274       (arr_tycon, _) <- parrayReprTyCon
275                       . mkTyConApp tycon
276                       $ map reprType reprs
277
278       let [arr_data_con] = tyConDataCons arr_tycon
279
280       return $ SumRepr {
281                  sum_components   = reprs
282                , sum_tycon        = tycon
283                , sum_arr_tycon    = arr_tycon
284                , sum_arr_data_con = arr_data_con
285                }
286   where
287     arity = length reprs
288
289 boxRepr :: Repr -> VM Repr
290 boxRepr (VoidRepr {}) = boxedProductRepr []
291 boxRepr (IdRepr ty)   = boxedProductRepr [ty]
292 boxRepr repr          = return repr
293
294 reprType :: Repr -> Type
295 reprType (ProdRepr { prod_tycon = tycon, prod_components = tys })
296   = mkTyConApp tycon tys
297 reprType (SumRepr { sum_tycon = tycon, sum_components = reprs })
298   = mkTyConApp tycon (map reprType reprs)
299 reprType (IdRepr ty) = ty
300 reprType (VoidRepr { void_tycon = tycon }) = mkTyConApp tycon []
301
302 arrReprType :: Repr -> VM Type
303 arrReprType = mkPArrayType . reprType
304
305 arrShapeTys :: Repr -> VM [Type]
306 arrShapeTys (SumRepr  {})
307   = do
308       int_arr <- builtin parrayIntPrimTyCon
309       return [intPrimTy, mkTyConApp int_arr [], mkTyConApp int_arr []]
310 arrShapeTys (ProdRepr {}) = return [intPrimTy]
311 arrShapeTys (IdRepr _)    = return []
312 arrShapeTys (VoidRepr {}) = return [intPrimTy]
313
314 arrShapeVars :: Repr -> VM [Var]
315 arrShapeVars repr = mapM (newLocalVar FSLIT("sh")) =<< arrShapeTys repr
316
317 replicateShape :: Repr -> CoreExpr -> CoreExpr -> VM [CoreExpr]
318 replicateShape (ProdRepr {}) len _ = return [len]
319 replicateShape (SumRepr {})  len tag
320   = do
321       rep <- builtin replicatePAIntPrimVar
322       up  <- builtin upToPAIntPrimVar
323       return [len, Var rep `mkApps` [len, tag], Var up `App` len]
324 replicateShape (IdRepr _) _ _ = return []
325 replicateShape (VoidRepr {}) len _ = return [len]
326
327 arrReprElemTys :: Repr -> VM [[Type]]
328 arrReprElemTys (SumRepr { sum_components = prods })
329   = mapM arrProdElemTys prods
330 arrReprElemTys prod@(ProdRepr {})
331   = do
332       tys <- arrProdElemTys prod
333       return [tys]
334 arrReprElemTys (IdRepr ty) = return [[ty]]
335 arrReprElemTys (VoidRepr { void_tycon = tycon })
336   = return [[mkTyConApp tycon []]]
337
338 arrProdElemTys (ProdRepr { prod_components = [] })
339   = do
340       void <- builtin voidTyCon
341       return [mkTyConApp void []]
342 arrProdElemTys (ProdRepr { prod_components = tys })
343   = return tys
344 arrProdElemTys (IdRepr ty) = return [ty]
345 arrProdElemTys (VoidRepr { void_tycon = tycon })
346   = return [mkTyConApp tycon []]
347
348 arrReprTys :: Repr -> VM [[Type]]
349 arrReprTys repr = mapM (mapM mkPArrayType) =<< arrReprElemTys repr
350
351 arrReprVars :: Repr -> VM [[Var]]
352 arrReprVars repr
353   = mapM (mapM (newLocalVar FSLIT("rs"))) =<< arrReprTys repr
354
355 mkRepr :: TyCon -> VM Repr
356 mkRepr vect_tc
357   = sumRepr =<< mapM unboxedProductRepr rep_tys
358   where
359     rep_tys = map dataConRepArgTys $ tyConDataCons vect_tc
360
361 buildPReprType :: TyCon -> VM Type
362 buildPReprType = liftM reprType . mkRepr
363
364 buildToPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
365 buildToPRepr repr vect_tc prepr_tc _
366   = do
367       arg    <- newLocalVar FSLIT("x") arg_ty
368       result <- to_repr repr (Var arg)
369
370       return . Lam arg
371              . wrapFamInstBody prepr_tc var_tys
372              $ result
373   where
374     var_tys = mkTyVarTys $ tyConTyVars vect_tc
375     arg_ty  = mkTyConApp vect_tc var_tys
376     res_ty  = reprType repr
377
378     cons    = tyConDataCons vect_tc
379     [con]   = cons
380
381     to_repr (SumRepr { sum_components = prods
382                      , sum_tycon      = tycon })
383             expr
384       = do
385           (vars, bodies) <- mapAndUnzipM prod_alt prods
386           return . Case expr (mkWildId (exprType expr)) res_ty
387                  $ zipWith4 mk_alt cons vars (tyConDataCons tycon) bodies
388       where
389         mk_alt con vars sum_con body
390           = (DataAlt con, vars, mkConApp sum_con (ty_args ++ [body]))
391
392         ty_args = map (Type . reprType) prods
393
394     to_repr prod expr
395       = do
396           (vars, body) <- prod_alt prod
397           return $ Case expr (mkWildId (exprType expr)) res_ty
398                    [(DataAlt con, vars, body)]
399
400     prod_alt (ProdRepr { prod_components = tys
401                        , prod_data_con   = data_con })
402       = do
403           vars <- mapM (newLocalVar FSLIT("r")) tys
404           return (vars, mkConApp data_con (map Type tys ++ map Var vars))
405
406     prod_alt (IdRepr ty)
407       = do
408           var <- newLocalVar FSLIT("y") ty
409           return ([var], Var var)
410
411     prod_alt (VoidRepr { void_bottom = bottom })
412       = return ([], bottom)
413
414
415 buildFromPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
416 buildFromPRepr repr vect_tc prepr_tc _
417   = do
418       arg_ty <- mkPReprType res_ty
419       arg    <- newLocalVar FSLIT("x") arg_ty
420
421       liftM (Lam arg)
422            . from_repr repr
423            $ unwrapFamInstScrut prepr_tc var_tys (Var arg)
424   where
425     var_tys = mkTyVarTys $ tyConTyVars vect_tc
426     res_ty  = mkTyConApp vect_tc var_tys
427
428     cons    = map (`mkConApp` map Type var_tys) (tyConDataCons vect_tc)
429     [con]   = cons
430
431     from_repr repr@(SumRepr { sum_components = prods
432                             , sum_tycon      = tycon })
433               expr
434       = do
435           vars   <- mapM (newLocalVar FSLIT("x")) (map reprType prods)
436           bodies <- sequence . zipWith3 from_prod prods cons
437                              $ map Var vars
438           return . Case expr (mkWildId (reprType repr)) res_ty
439                  $ zipWith3 sum_alt (tyConDataCons tycon) vars bodies
440       where
441         sum_alt data_con var body = (DataAlt data_con, [var], body)
442
443     from_repr repr expr = from_prod repr con expr
444
445     from_prod prod@(ProdRepr { prod_components = tys
446                              , prod_data_con   = data_con })
447               con
448               expr
449       = do
450           vars <- mapM (newLocalVar FSLIT("y")) tys
451           return $ Case expr (mkWildId (reprType prod)) res_ty
452                    [(DataAlt data_con, vars, con `mkVarApps` vars)]
453
454     from_prod (IdRepr _) con expr
455        = return $ con `App` expr
456
457     from_prod (VoidRepr {}) con expr
458        = return con
459
460 buildToArrPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
461 buildToArrPRepr repr vect_tc prepr_tc arr_tc
462   = do
463       arg_ty     <- mkPArrayType el_ty
464       arg        <- newLocalVar FSLIT("xs") arg_ty
465
466       res_ty     <- mkPArrayType (reprType repr)
467
468       shape_vars <- arrShapeVars repr
469       repr_vars  <- arrReprVars  repr
470
471       parray_co  <- mkBuiltinCo parrayTyCon
472
473       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
474           co           = mkAppCoercion parray_co
475                        . mkSymCoercion
476                        $ mkTyConApp repr_co var_tys
477
478           scrut   = unwrapFamInstScrut arr_tc var_tys (Var arg)
479
480       result <- to_repr shape_vars repr_vars repr
481
482       return . Lam arg
483              . mkCoerce co
484              $ Case scrut (mkWildId (mkTyConApp arr_tc var_tys)) res_ty
485                [(DataAlt arr_dc, shape_vars ++ concat repr_vars, result)]
486   where
487     var_tys = mkTyVarTys $ tyConTyVars vect_tc
488     el_ty   = mkTyConApp vect_tc var_tys
489
490     [arr_dc] = tyConDataCons arr_tc
491
492     to_repr shape_vars@(len_var : _)
493             repr_vars
494             (SumRepr { sum_components   = prods
495                      , sum_arr_tycon    = tycon
496                      , sum_arr_data_con = data_con })
497       = do
498           exprs <- zipWithM to_prod repr_vars prods
499
500           return . wrapFamInstBody tycon tys
501                  . mkConApp data_con
502                  $ map Type tys ++ map Var shape_vars ++ exprs
503       where
504         tys = map reprType prods
505
506     to_repr [len_var]
507             [repr_vars]
508             (ProdRepr { prod_components   = tys
509                       , prod_arr_tycon    = tycon
510                       , prod_arr_data_con = data_con })
511        = return . wrapFamInstBody tycon tys
512                 . mkConApp data_con
513                 $ map Type tys ++ map Var (len_var : repr_vars)
514
515     to_prod repr_vars@(r : _)
516             (ProdRepr { prod_components   = tys
517                       , prod_arr_tycon    = tycon
518                       , prod_arr_data_con = data_con })
519       = do
520           len <- lengthPA (Var r)
521           return . wrapFamInstBody tycon tys
522                  . mkConApp data_con
523                  $ map Type tys ++ len : map Var repr_vars
524
525     to_prod [var] (IdRepr ty)   = return (Var var)
526     to_prod [var] (VoidRepr {}) = return (Var var)
527
528
529 buildFromArrPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
530 buildFromArrPRepr repr vect_tc prepr_tc arr_tc
531   = do
532       arg_ty     <- mkPArrayType =<< mkPReprType el_ty
533       arg        <- newLocalVar FSLIT("xs") arg_ty
534
535       res_ty     <- mkPArrayType el_ty
536
537       shape_vars <- arrShapeVars repr
538       repr_vars  <- arrReprVars  repr
539
540       parray_co  <- mkBuiltinCo parrayTyCon
541
542       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
543           co           = mkAppCoercion parray_co
544                        $ mkTyConApp repr_co var_tys
545
546           scrut  = mkCoerce co (Var arg)
547
548           result = wrapFamInstBody arr_tc var_tys
549                  . mkConApp arr_dc
550                  $ map Type var_tys ++ map Var (shape_vars ++ concat repr_vars)
551
552       liftM (Lam arg)
553             (from_repr repr scrut shape_vars repr_vars res_ty result)
554   where
555     var_tys = mkTyVarTys $ tyConTyVars vect_tc
556     el_ty   = mkTyConApp vect_tc var_tys
557
558     [arr_dc] = tyConDataCons arr_tc
559
560     from_repr (SumRepr { sum_components   = prods
561                        , sum_arr_tycon    = tycon
562                        , sum_arr_data_con = data_con })
563               expr
564               shape_vars
565               repr_vars
566               res_ty
567               body
568       = do
569           vars <- mapM (newLocalVar FSLIT("xs")) =<< mapM arrReprType prods
570           result <- go prods repr_vars vars body
571
572           let scrut = unwrapFamInstScrut tycon ty_args expr
573           return . Case scrut (mkWildId scrut_ty) res_ty
574                  $ [(DataAlt data_con, shape_vars ++ vars, result)]
575       where
576         ty_args  = map reprType prods
577         scrut_ty = mkTyConApp tycon ty_args
578
579         go [] [] [] body = return body
580         go (prod : prods) (repr_vars : rss) (var : vars) body
581           = do
582               shape_vars <- mapM (newLocalVar FSLIT("s")) =<< arrShapeTys prod
583
584               from_prod prod (Var var) shape_vars repr_vars res_ty
585                 =<< go prods rss vars body
586
587     from_repr repr expr shape_vars [repr_vars] res_ty body
588       = from_prod repr expr shape_vars repr_vars res_ty body
589
590     from_prod prod@(ProdRepr { prod_components = tys
591                              , prod_arr_tycon  = tycon
592                              , prod_arr_data_con = data_con })
593               expr
594               shape_vars
595               repr_vars
596               res_ty
597               body
598       = do
599           let scrut    = unwrapFamInstScrut tycon tys expr
600               scrut_ty = mkTyConApp tycon tys
601           ty <- arrReprType prod
602
603           return $ Case scrut (mkWildId scrut_ty) res_ty
604                    [(DataAlt data_con, shape_vars ++ repr_vars, body)]
605
606     from_prod (IdRepr ty)
607               expr
608               shape_vars
609               [repr_var]
610               res_ty
611               body
612       = return $ Let (NonRec repr_var expr) body
613
614     from_prod (VoidRepr {})
615               expr
616               shape_vars
617               [repr_var]
618               res_ty
619               body
620       = return $ Let (NonRec repr_var expr) body
621
622 buildPRDictRepr :: Repr -> VM CoreExpr
623 buildPRDictRepr (VoidRepr { void_tycon = tycon })
624   = prDFunOfTyCon tycon
625 buildPRDictRepr (IdRepr ty) = mkPR ty
626 buildPRDictRepr (ProdRepr {
627                    prod_components = tys
628                  , prod_tycon      = tycon
629                  })
630   = do
631       prs  <- mapM mkPR tys
632       dfun <- prDFunOfTyCon tycon
633       return $ dfun `mkTyApps` tys `mkApps` prs
634
635 buildPRDictRepr (SumRepr {
636                    sum_components = prods
637                  , sum_tycon      = tycon })
638   = do
639       prs  <- mapM buildPRDictRepr prods
640       dfun <- prDFunOfTyCon tycon
641       return $ dfun `mkTyApps` map reprType prods `mkApps` prs
642
643 buildPRDict :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
644 buildPRDict repr vect_tc prepr_tc _
645   = do
646       dict  <- buildPRDictRepr repr
647
648       pr_co <- mkBuiltinCo prTyCon
649       let co = mkAppCoercion pr_co
650              . mkSymCoercion
651              $ mkTyConApp arg_co var_tys
652
653       return $ mkCoerce co dict
654   where
655     var_tys = mkTyVarTys $ tyConTyVars vect_tc
656
657     Just arg_co = tyConFamilyCoercion_maybe prepr_tc
658
659 buildPArrayTyCon :: TyCon -> TyCon -> VM TyCon
660 buildPArrayTyCon orig_tc vect_tc = fixV $ \repr_tc ->
661   do
662     name'  <- cloneName mkPArrayTyConOcc orig_name
663     rhs    <- buildPArrayTyConRhs orig_name vect_tc repr_tc
664     parray <- builtin parrayTyCon
665
666     liftDs $ buildAlgTyCon name'
667                            tyvars
668                            []          -- no stupid theta
669                            rhs
670                            rec_flag    -- FIXME: is this ok?
671                            False       -- FIXME: no generics
672                            False       -- not GADT syntax
673                            (Just $ mk_fam_inst parray vect_tc)
674   where
675     orig_name = tyConName orig_tc
676     tyvars = tyConTyVars vect_tc
677     rec_flag = boolToRecFlag (isRecursiveTyCon vect_tc)
678
679
680 buildPArrayTyConRhs :: Name -> TyCon -> TyCon -> VM AlgTyConRhs
681 buildPArrayTyConRhs orig_name vect_tc repr_tc
682   = do
683       data_con <- buildPArrayDataCon orig_name vect_tc repr_tc
684       return $ DataTyCon { data_cons = [data_con], is_enum = False }
685
686 buildPArrayDataCon :: Name -> TyCon -> TyCon -> VM DataCon
687 buildPArrayDataCon orig_name vect_tc repr_tc
688   = do
689       dc_name  <- cloneName mkPArrayDataConOcc orig_name
690       repr     <- mkRepr vect_tc
691
692       shape_tys <- arrShapeTys repr
693       repr_tys  <- arrReprTys  repr
694
695       let tys = shape_tys ++ concat repr_tys
696
697       liftDs $ buildDataCon dc_name
698                             False                  -- not infix
699                             (map (const NotMarkedStrict) tys)
700                             []                     -- no field labels
701                             (tyConTyVars vect_tc)
702                             []                     -- no existentials
703                             []                     -- no eq spec
704                             []                     -- no context
705                             tys
706                             repr_tc
707
708 mkPADFun :: TyCon -> VM Var
709 mkPADFun vect_tc
710   = newExportedVar (mkPADFunOcc $ getOccName vect_tc) =<< paDFunType vect_tc
711
712 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> Var
713                    -> VM [(Var, CoreExpr)]
714 buildTyConBindings orig_tc vect_tc prepr_tc arr_tc dfun
715   = do
716       repr  <- mkRepr vect_tc
717       vectDataConWorkers repr orig_tc vect_tc arr_tc
718       dict <- buildPADict repr vect_tc prepr_tc arr_tc dfun
719       binds <- takeHoisted
720       return $ (dfun, dict) : binds
721   where
722     orig_dcs = tyConDataCons orig_tc
723     vect_dcs = tyConDataCons vect_tc
724     [arr_dc] = tyConDataCons arr_tc
725
726     repr_tys = map dataConRepArgTys vect_dcs
727
728 vectDataConWorkers :: Repr -> TyCon -> TyCon -> TyCon
729                    -> VM ()
730 vectDataConWorkers repr orig_tc vect_tc arr_tc
731   = do
732       arr_tys <- arrReprElemTys repr
733       bs <- sequence
734           . zipWith3 def_worker  (tyConDataCons orig_tc) rep_tys
735           $ zipWith4 mk_data_con (tyConDataCons vect_tc)
736                                  rep_tys
737                                  (inits arr_tys)
738                                  (tail $ tails arr_tys)
739       mapM_ (uncurry hoistBinding) bs
740   where
741     tyvars   = tyConTyVars vect_tc
742     var_tys  = mkTyVarTys tyvars
743     ty_args  = map Type var_tys
744
745     res_ty   = mkTyConApp vect_tc var_tys
746
747     rep_tys  = map dataConRepArgTys $ tyConDataCons vect_tc
748
749     [arr_dc] = tyConDataCons arr_tc
750
751     mk_data_con con tys pre post
752       = liftM2 (,) (vect_data_con con)
753                    (lift_data_con tys (concat pre)
754                                       (concat post)
755                                       (mkDataConTag con))
756
757     vect_data_con con = return $ mkConApp con ty_args
758     lift_data_con tys pre_tys post_tys tag
759       = do
760           len  <- builtin liftingContext
761           args <- mapM (newLocalVar FSLIT("xs"))
762                   =<< mapM mkPArrayType tys
763
764           shape <- replicateShape repr (Var len) tag
765           repr  <- mk_arr_repr (Var len) (map Var args)
766
767           pre   <- mapM emptyPA pre_tys
768           post  <- mapM emptyPA post_tys
769
770           return . mkLams (len : args)
771                  . wrapFamInstBody arr_tc var_tys
772                  . mkConApp arr_dc
773                  $ ty_args ++ shape ++ pre ++ repr ++ post
774
775     mk_arr_repr len []
776       = do
777           units <- replicatePA len (Var unitDataConId)
778           return [units]
779
780     mk_arr_repr len arrs = return arrs
781
782     def_worker data_con arg_tys mk_body
783       = do
784           body <- closedV
785                 . inBind orig_worker
786                 . polyAbstract tyvars $ \abstract ->
787                   liftM (abstract . vectorised)
788                 $ buildClosures tyvars [] arg_tys res_ty mk_body
789
790           vect_worker <- cloneId mkVectOcc orig_worker (exprType body)
791           defGlobalVar orig_worker vect_worker
792           return (vect_worker, body)
793       where
794         orig_worker = dataConWorkId data_con
795
796 buildPADict :: Repr -> TyCon -> TyCon -> TyCon -> Var -> VM CoreExpr
797 buildPADict repr vect_tc prepr_tc arr_tc dfun
798   = polyAbstract tvs $ \abstract ->
799     do
800       meth_binds <- mapM (mk_method repr) paMethods
801       let meth_exprs = map (Var . fst) meth_binds
802
803       pa_dc <- builtin paDataCon
804       let dict = mkConApp pa_dc (Type (mkTyConApp vect_tc arg_tys) : meth_exprs)
805           body = Let (Rec meth_binds) dict
806       return . mkInlineMe $ abstract body
807   where
808     tvs = tyConTyVars arr_tc
809     arg_tys = mkTyVarTys tvs
810
811     mk_method repr (name, build)
812       = localV
813       $ do
814           body <- build repr vect_tc prepr_tc arr_tc
815           var  <- newLocalVar name (exprType body)
816           return (var, mkInlineMe body)
817
818 paMethods = [(FSLIT("toPRepr"),      buildToPRepr),
819              (FSLIT("fromPRepr"),    buildFromPRepr),
820              (FSLIT("toArrPRepr"),   buildToArrPRepr),
821              (FSLIT("fromArrPRepr"), buildFromArrPRepr),
822              (FSLIT("dictPRepr"),    buildPRDict)]
823
824 -- | Split the given tycons into two sets depending on whether they have to be
825 -- converted (first list) or not (second list). The first argument contains
826 -- information about the conversion status of external tycons:
827 --
828 --   * tycons which have converted versions are mapped to True
829 --   * tycons which are not changed by vectorisation are mapped to False
830 --   * tycons which can't be converted are not elements of the map
831 --
832 classifyTyCons :: UniqFM Bool -> [TyConGroup] -> ([TyCon], [TyCon])
833 classifyTyCons = classify [] []
834   where
835     classify conv keep cs [] = (conv, keep)
836     classify conv keep cs ((tcs, ds) : rs)
837       | can_convert && must_convert
838         = classify (tcs ++ conv) keep (cs `addListToUFM` [(tc,True) | tc <- tcs]) rs
839       | can_convert
840         = classify conv (tcs ++ keep) (cs `addListToUFM` [(tc,False) | tc <- tcs]) rs
841       | otherwise
842         = classify conv keep cs rs
843       where
844         refs = ds `delListFromUniqSet` tcs
845
846         can_convert  = isNullUFM (refs `minusUFM` cs) && all convertable tcs
847         must_convert = foldUFM (||) False (intersectUFM_C const cs refs)
848
849         convertable tc = isDataTyCon tc && all isVanillaDataCon (tyConDataCons tc)
850
851 -- | Compute mutually recursive groups of tycons in topological order
852 --
853 tyConGroups :: [TyCon] -> [TyConGroup]
854 tyConGroups tcs = map mk_grp (stronglyConnComp edges)
855   where
856     edges = [((tc, ds), tc, uniqSetToList ds) | tc <- tcs
857                                 , let ds = tyConsOfTyCon tc]
858
859     mk_grp (AcyclicSCC (tc, ds)) = ([tc], ds)
860     mk_grp (CyclicSCC els)       = (tcs, unionManyUniqSets dss)
861       where
862         (tcs, dss) = unzip els
863
864 tyConsOfTyCon :: TyCon -> UniqSet TyCon
865 tyConsOfTyCon
866   = tyConsOfTypes . concatMap dataConRepArgTys . tyConDataCons
867
868 tyConsOfType :: Type -> UniqSet TyCon
869 tyConsOfType ty
870   | Just ty' <- coreView ty    = tyConsOfType ty'
871 tyConsOfType (TyVarTy v)       = emptyUniqSet
872 tyConsOfType (TyConApp tc tys) = extend (tyConsOfTypes tys)
873   where
874     extend | isUnLiftedTyCon tc
875            || isTupleTyCon   tc = id
876
877            | otherwise          = (`addOneToUniqSet` tc)
878
879 tyConsOfType (AppTy a b)       = tyConsOfType a `unionUniqSets` tyConsOfType b
880 tyConsOfType (FunTy a b)       = (tyConsOfType a `unionUniqSets` tyConsOfType b)
881                                  `addOneToUniqSet` funTyCon
882 tyConsOfType (ForAllTy _ ty)   = tyConsOfType ty
883 tyConsOfType other             = pprPanic "ClosureConv.tyConsOfType" $ ppr other
884
885 tyConsOfTypes :: [Type] -> UniqSet TyCon
886 tyConsOfTypes = unionManyUniqSets . map tyConsOfType
887