Simplify generation of PR dictionaries for products
[ghc-hetmet.git] / compiler / vectorise / VectType.hs
1 module VectType ( vectTyCon, vectType, vectTypeEnv,
2                    PAInstance, buildPADict )
3 where
4
5 #include "HsVersions.h"
6
7 import VectMonad
8 import VectUtils
9 import VectCore
10
11 import HscTypes          ( TypeEnv, extendTypeEnvList, typeEnvTyCons )
12 import CoreSyn
13 import CoreUtils
14 import BuildTyCl
15 import DataCon
16 import TyCon
17 import Type
18 import TypeRep
19 import Coercion
20 import FamInstEnv        ( FamInst, mkLocalFamInst )
21 import InstEnv           ( Instance, mkLocalInstance, instanceDFunId )
22 import OccName
23 import MkId
24 import BasicTypes        ( StrictnessMark(..), OverlapFlag(..), boolToRecFlag )
25 import Var               ( Var )
26 import Id                ( mkWildId )
27 import Name              ( Name, getOccName )
28 import NameEnv
29 import TysWiredIn        ( unitTy, unitTyCon, intTy, intDataCon, unitDataConId )
30 import TysPrim           ( intPrimTy )
31
32 import Unique
33 import UniqFM
34 import UniqSet
35 import Digraph           ( SCC(..), stronglyConnComp )
36
37 import Outputable
38
39 import Control.Monad  ( liftM, liftM2, zipWithM, zipWithM_ )
40 import Data.List      ( inits, tails, zipWith4, zipWith5 )
41
42 -- ----------------------------------------------------------------------------
43 -- Types
44
45 vectTyCon :: TyCon -> VM TyCon
46 vectTyCon tc
47   | isFunTyCon tc        = builtin closureTyCon
48   | isBoxedTupleTyCon tc = return tc
49   | isUnLiftedTyCon tc   = return tc
50   | otherwise = do
51                   r <- lookupTyCon tc
52                   case r of
53                     Just tc' -> return tc'
54
55                     -- FIXME: just for now
56                     Nothing  -> pprTrace "ccTyCon:" (ppr tc) $ return tc
57
58 vectType :: Type -> VM Type
59 vectType ty | Just ty' <- coreView ty = vectType ty'
60 vectType (TyVarTy tv) = return $ TyVarTy tv
61 vectType (AppTy ty1 ty2) = liftM2 AppTy (vectType ty1) (vectType ty2)
62 vectType (TyConApp tc tys) = liftM2 TyConApp (vectTyCon tc) (mapM vectType tys)
63 vectType (FunTy ty1 ty2)   = liftM2 TyConApp (builtin closureTyCon)
64                                              (mapM vectType [ty1,ty2])
65 vectType ty@(ForAllTy _ _)
66   = do
67       mdicts   <- mapM paDictArgType tyvars
68       mono_ty' <- vectType mono_ty
69       return $ tyvars `mkForAllTys` ([dict | Just dict <- mdicts] `mkFunTys` mono_ty')
70   where
71     (tyvars, mono_ty) = splitForAllTys ty
72
73 vectType ty = pprPanic "vectType:" (ppr ty)
74
75 -- ----------------------------------------------------------------------------
76 -- Type definitions
77
78 type TyConGroup = ([TyCon], UniqSet TyCon)
79
80 data PAInstance = PAInstance {
81                     painstDFun      :: Var
82                   , painstOrigTyCon :: TyCon
83                   , painstVectTyCon :: TyCon
84                   , painstArrTyCon  :: TyCon
85                   }
86
87 vectTypeEnv :: TypeEnv -> VM (TypeEnv, [FamInst], [(Var, CoreExpr)])
88 vectTypeEnv env
89   = do
90       cs <- readGEnv $ mk_map . global_tycons
91       let (conv_tcs, keep_tcs) = classifyTyCons cs groups
92           keep_dcs             = concatMap tyConDataCons keep_tcs
93       zipWithM_ defTyCon   keep_tcs keep_tcs
94       zipWithM_ defDataCon keep_dcs keep_dcs
95       new_tcs <- vectTyConDecls conv_tcs
96
97       let orig_tcs = keep_tcs ++ conv_tcs
98           vect_tcs  = keep_tcs ++ new_tcs
99
100       repr_tcs <- zipWithM buildPReprTyCon   orig_tcs vect_tcs
101       parr_tcs <- zipWithM buildPArrayTyCon orig_tcs vect_tcs
102       dfuns    <- mapM mkPADFun vect_tcs
103       defTyConPAs (zip vect_tcs dfuns)
104       binds    <- sequence (zipWith5 buildTyConBindings orig_tcs
105                                                         vect_tcs
106                                                         repr_tcs
107                                                         parr_tcs
108                                                         dfuns)
109
110       let all_new_tcs = new_tcs ++ repr_tcs ++ parr_tcs
111
112       let new_env = extendTypeEnvList env
113                        (map ATyCon all_new_tcs
114                         ++ [ADataCon dc | tc <- all_new_tcs
115                                         , dc <- tyConDataCons tc])
116
117       return (new_env, map mkLocalFamInst (repr_tcs ++ parr_tcs), concat binds)
118   where
119     tycons = typeEnvTyCons env
120     groups = tyConGroups tycons
121
122     mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
123
124     keep_tc tc = let dcs = tyConDataCons tc
125                  in
126                  defTyCon tc tc >> zipWithM_ defDataCon dcs dcs
127
128
129 vectTyConDecls :: [TyCon] -> VM [TyCon]
130 vectTyConDecls tcs = fixV $ \tcs' ->
131   do
132     mapM_ (uncurry defTyCon) (lazy_zip tcs tcs')
133     mapM vectTyConDecl tcs
134   where
135     lazy_zip [] _ = []
136     lazy_zip (x:xs) ~(y:ys) = (x,y) : lazy_zip xs ys
137
138 vectTyConDecl :: TyCon -> VM TyCon
139 vectTyConDecl tc
140   = do
141       name' <- cloneName mkVectTyConOcc name
142       rhs'  <- vectAlgTyConRhs (algTyConRhs tc)
143
144       liftDs $ buildAlgTyCon name'
145                              tyvars
146                              []           -- no stupid theta
147                              rhs'
148                              rec_flag     -- FIXME: is this ok?
149                              False        -- FIXME: no generics
150                              False        -- not GADT syntax
151                              Nothing      -- not a family instance
152   where
153     name   = tyConName tc
154     tyvars = tyConTyVars tc
155     rec_flag = boolToRecFlag (isRecursiveTyCon tc)
156
157 vectAlgTyConRhs :: AlgTyConRhs -> VM AlgTyConRhs
158 vectAlgTyConRhs (DataTyCon { data_cons = data_cons
159                            , is_enum   = is_enum
160                            })
161   = do
162       data_cons' <- mapM vectDataCon data_cons
163       zipWithM_ defDataCon data_cons data_cons'
164       return $ DataTyCon { data_cons = data_cons'
165                          , is_enum   = is_enum
166                          }
167
168 vectDataCon :: DataCon -> VM DataCon
169 vectDataCon dc
170   | not . null $ dataConExTyVars dc = pprPanic "vectDataCon: existentials" (ppr dc)
171   | not . null $ dataConEqSpec   dc = pprPanic "vectDataCon: eq spec" (ppr dc)
172   | otherwise
173   = do
174       name'    <- cloneName mkVectDataConOcc name
175       tycon'   <- vectTyCon tycon
176       arg_tys  <- mapM vectType rep_arg_tys
177
178       liftDs $ buildDataCon name'
179                             False           -- not infix
180                             (map (const NotMarkedStrict) arg_tys)
181                             []              -- no labelled fields
182                             univ_tvs
183                             []              -- no existential tvs for now
184                             []              -- no eq spec for now
185                             []              -- no context
186                             arg_tys
187                             tycon'
188   where
189     name        = dataConName dc
190     univ_tvs    = dataConUnivTyVars dc
191     rep_arg_tys = dataConRepArgTys dc
192     tycon       = dataConTyCon dc
193
194 mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
195 mk_fam_inst fam_tc arg_tc
196   = (fam_tc, [mkTyConApp arg_tc . mkTyVarTys $ tyConTyVars arg_tc])
197
198 buildPReprTyCon :: TyCon -> TyCon -> VM TyCon
199 buildPReprTyCon orig_tc vect_tc
200   = do
201       name     <- cloneName mkPReprTyConOcc (tyConName orig_tc)
202       rhs_ty   <- buildPReprType vect_tc
203       prepr_tc <- builtin preprTyCon
204       liftDs $ buildSynTyCon name
205                              tyvars
206                              (SynonymTyCon rhs_ty)
207                              (Just $ mk_fam_inst prepr_tc vect_tc)
208   where
209     tyvars = tyConTyVars vect_tc
210
211 data TyConRepr = ProdRepr {
212                    repr_prod_arg_tys   :: [Type]
213                  , repr_prod_tycon     :: TyCon
214                  , repr_prod_data_con  :: DataCon
215                  , repr_type           :: Type
216                  }
217                | SumRepr {
218                    repr_tys            :: [[Type]]
219                  , repr_prod_tycons    :: [Maybe TyCon]
220                  , repr_prod_data_cons :: [Maybe DataCon]
221                  , repr_prod_tys       :: [Type]
222                  , repr_sum_tycon      :: TyCon
223                  , repr_type           :: Type
224                  }
225
226 arrShapeTys :: TyConRepr -> VM [Type]
227 arrShapeTys (ProdRepr {}) = return [intPrimTy]
228 arrShapeTys (SumRepr  {})
229   = do
230       uarr <- builtin uarrTyCon
231       return [intPrimTy, mkTyConApp uarr [intTy]]
232
233 arrReprTys :: TyConRepr -> VM [Type]
234 arrReprTys (ProdRepr { repr_prod_arg_tys = tys })
235   = mapM mkPArrayType tys
236 arrReprTys (SumRepr { repr_tys = tys })
237   = concat `liftM` mapM (mapM mkPArrayType) (map mk_prod tys)
238   where
239     mk_prod []  = [unitTy]
240     mk_prod tys = tys
241       
242
243 mkTyConRepr :: TyCon -> VM TyConRepr
244 mkTyConRepr vect_tc
245   | is_product
246   = let
247       [prod_arg_tys] = repr_tys
248     in
249     do
250       prod_tycon <- builtin (prodTyCon $ length prod_arg_tys)
251       let [prod_data_con] = tyConDataCons prod_tycon
252
253       return $ ProdRepr {
254                  repr_prod_arg_tys  = prod_arg_tys
255                , repr_prod_tycon    = prod_tycon
256                , repr_prod_data_con = prod_data_con
257                , repr_type          = mkTyConApp prod_tycon prod_arg_tys
258                }
259
260   | otherwise
261   = do
262       uarr <- builtin uarrTyCon
263       prod_tycons  <- mapM (mk_tycon prodTyCon) repr_tys
264       let prod_tys = zipWith mk_tc_app_maybe prod_tycons repr_tys
265       sum_tycon    <- builtin (sumTyCon $ length repr_tys)
266       arr_repr_tys <- mapM (mapM mkPArrayType . arr_repr_elem_tys) repr_tys
267
268       return $ SumRepr {
269                  repr_tys            = repr_tys
270                , repr_prod_tycons    = prod_tycons
271                , repr_prod_data_cons = map (fmap mk_single_datacon) prod_tycons
272                , repr_prod_tys       = prod_tys
273                , repr_sum_tycon      = sum_tycon
274                , repr_type           = mkTyConApp sum_tycon prod_tys
275                }
276   where
277     tyvars    = tyConTyVars vect_tc
278     data_cons = tyConDataCons vect_tc
279     repr_tys  = map dataConRepArgTys data_cons
280
281     is_product | [_] <- data_cons = True
282                | otherwise        = False
283
284     mk_shape uarr = intPrimTy : mk_sel uarr
285
286     mk_sel uarr | is_product = []
287                 | otherwise  = [uarr_int, uarr_int]
288       where
289         uarr_int = mkTyConApp uarr [intTy]
290
291     mk_tycon get_tc tys
292       | n > 1     = builtin (Just . get_tc n)
293       | otherwise = return Nothing
294       where n = length tys
295
296     mk_single_datacon tc | [dc] <- tyConDataCons tc = dc
297
298     mk_tc_app_maybe Nothing   []   = unitTy
299     mk_tc_app_maybe Nothing   [ty] = ty
300     mk_tc_app_maybe (Just tc) tys  = mkTyConApp tc tys
301
302     arr_repr_elem_tys []  = [unitTy]
303     arr_repr_elem_tys tys = tys
304
305 buildPReprType :: TyCon -> VM Type
306 buildPReprType = liftM repr_type . mkTyConRepr
307
308 buildToPRepr :: TyConRepr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
309 buildToPRepr (ProdRepr {
310                 repr_prod_arg_tys  = prod_arg_tys
311               , repr_prod_data_con = prod_data_con
312               , repr_type          = repr_type
313               })
314              vect_tc prepr_tc _
315   = do
316       arg  <- newLocalVar FSLIT("x") arg_ty
317       vars <- mapM (newLocalVar FSLIT("x")) prod_arg_tys
318
319       return . Lam arg
320              . wrapFamInstBody prepr_tc var_tys
321              $ Case (Var arg) (mkWildId arg_ty) repr_type
322                [(DataAlt data_con, vars,
323                  mkConApp prod_data_con (map Type prod_arg_tys ++ map Var vars))]
324   where
325     var_tys    = mkTyVarTys $ tyConTyVars vect_tc
326     arg_ty     = mkTyConApp vect_tc var_tys
327     [data_con] = tyConDataCons vect_tc
328
329 buildToPRepr (SumRepr {
330                 repr_tys            = repr_tys
331               , repr_prod_data_cons = prod_data_cons
332               , repr_prod_tys       = prod_tys
333               , repr_sum_tycon      = sum_tycon
334               , repr_type           = repr_type
335               })
336               vect_tc prepr_tc _
337   = do
338       arg  <- newLocalVar FSLIT("x") arg_ty
339       vars <- mapM (mapM (newLocalVar FSLIT("x"))) repr_tys
340
341       return . Lam arg
342              . wrapFamInstBody prepr_tc var_tys
343              . Case (Var arg) (mkWildId arg_ty) repr_type
344              . zipWith4 mk_alt data_cons vars sum_data_cons
345              . zipWith3 mk_prod prod_data_cons repr_tys $ map (map Var) vars
346   where
347     var_tys   = mkTyVarTys $ tyConTyVars vect_tc
348     arg_ty    = mkTyConApp vect_tc var_tys
349     data_cons = tyConDataCons vect_tc
350
351     sum_data_cons = tyConDataCons sum_tycon
352
353     mk_alt dc vars sum_dc expr = (DataAlt dc, vars,
354                                   mkConApp sum_dc (map Type prod_tys ++ [expr]))
355
356     mk_prod _         _   []     = Var unitDataConId
357     mk_prod _         _   [expr] = expr
358     mk_prod (Just dc) tys exprs  = mkConApp dc (map Type tys ++ exprs)
359
360 buildFromPRepr :: TyConRepr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
361 buildFromPRepr (ProdRepr {
362                   repr_prod_arg_tys  = prod_arg_tys
363                 , repr_prod_data_con = prod_data_con
364                 , repr_type          = repr_type
365                 })
366              vect_tc prepr_tc _
367   = do
368       arg_ty <- mkPReprType res_ty
369       arg    <- newLocalVar FSLIT("x") arg_ty
370       vars   <- mapM (newLocalVar FSLIT("x")) prod_arg_tys
371
372       return . Lam arg
373              $ Case (unwrapFamInstScrut prepr_tc var_tys (Var arg))
374                     (mkWildId repr_type)
375                     res_ty
376                [(DataAlt prod_data_con, vars,
377                  mkConApp data_con (map Type var_tys ++ map Var vars))]
378   where
379     var_tys    = mkTyVarTys $ tyConTyVars vect_tc
380     ty_args    = map Type var_tys
381     res_ty     = mkTyConApp vect_tc var_tys
382     [data_con] = tyConDataCons vect_tc
383
384 buildFromPRepr (SumRepr {
385                 repr_tys            = repr_tys
386               , repr_prod_data_cons = prod_data_cons
387               , repr_prod_tys       = prod_tys
388               , repr_sum_tycon      = sum_tycon
389               , repr_type           = repr_type
390               })
391               vect_tc prepr_tc _
392   = do
393       arg_ty <- mkPReprType res_ty
394       arg    <- newLocalVar FSLIT("x") arg_ty
395
396       liftM (Lam arg
397              . Case (unwrapFamInstScrut prepr_tc var_tys (Var arg))
398                     (mkWildId repr_type)
399                     res_ty
400              . zipWith mk_alt sum_data_cons)
401             (sequence
402              $ zipWith4 un_prod data_cons prod_data_cons prod_tys repr_tys)
403   where
404     var_tys   = mkTyVarTys $ tyConTyVars vect_tc
405     ty_args   = map Type var_tys
406     res_ty    = mkTyConApp vect_tc var_tys
407     data_cons = tyConDataCons vect_tc
408
409     sum_data_cons = tyConDataCons sum_tycon
410
411     un_prod dc _ _ []
412       = do
413           var <- newLocalVar FSLIT("u") unitTy
414           return (var, mkConApp dc ty_args)
415     un_prod dc _ _ [ty]
416       = do
417           var <- newLocalVar FSLIT("x") ty
418           return (var, mkConApp dc (ty_args ++ [Var var]))
419
420     un_prod dc (Just prod_dc) prod_ty tys
421       = do
422           vars  <- mapM (newLocalVar FSLIT("x")) tys
423           pv    <- newLocalVar FSLIT("p") prod_ty
424
425           let res  = mkConApp dc (ty_args ++ map Var vars)
426               expr = Case (Var pv) (mkWildId prod_ty) res_ty
427                         [(DataAlt prod_dc, vars, res)]
428
429           return (pv, expr)
430
431     mk_alt sum_dc (var, expr) = (DataAlt sum_dc, [var], expr)
432
433
434 buildToArrPRepr :: TyConRepr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
435 {-
436 buildToArrPRepr (ProdRepr {
437                    repr_prod_arg_tys  = prod_arg_tys
438                  , repr_prod_data_con = prod_data_con
439                  , repr_type          = repr_type
440                  })
441                 vect_tc prepr_tc _
442   = do
443       arg_ty  <- mkPArratType el_ty
444       rep_tys <- mapM mkPArrayType prod_arg_tys
445
446       
447   where
448     var_tys = mkTyVarTys $ tyConTyVars vect_tc
449     el_ty   = mkTyConApp vect_tc var_tys
450 -}
451 buildToArrPRepr _ _ _ _ = return (Var unitDataConId)
452 {-
453 buildToArrPRepr _ vect_tc prepr_tc arr_tc
454   = do
455       arg_ty  <- mkPArrayType el_ty
456       rep_tys <- mapM (mapM mkPArrayType) rep_el_tys
457
458       arg     <- newLocalVar FSLIT("xs") arg_ty
459       bndrss  <- mapM (mapM (newLocalVar FSLIT("ys"))) rep_tys
460       len     <- newLocalVar FSLIT("len") intPrimTy
461       sel     <- newLocalVar FSLIT("sel") =<< mkPArrayType intTy
462
463       let add_sel xs | has_selector = sel : xs
464                      | otherwise    = xs
465
466           all_bndrs = len : add_sel (concat bndrss)
467
468       res      <- parrayCoerce prepr_tc var_tys
469                 =<< mkToArrPRepr (Var len) (Var sel) (map (map Var) bndrss)
470       res_ty   <- mkPArrayType =<< mkPReprType el_ty
471
472       return . Lam arg
473              $ Case (unwrapFamInstScrut arr_tc var_tys (Var arg))
474                     (mkWildId (mkTyConApp arr_tc var_tys))
475                     res_ty
476                     [(DataAlt arr_dc, all_bndrs, res)]
477   where
478     var_tys    = mkTyVarTys $ tyConTyVars vect_tc
479     el_ty      = mkTyConApp vect_tc var_tys
480     data_cons  = tyConDataCons vect_tc
481     rep_el_tys = map dataConRepArgTys data_cons
482
483     [arr_dc]   = tyConDataCons arr_tc
484
485     has_selector | [_] <- data_cons = False
486                  | otherwise        = True
487 -}
488
489 buildFromArrPRepr :: TyConRepr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
490 buildFromArrPRepr _ _ _ _ = return (Var unitDataConId)
491
492 buildPRDict :: TyConRepr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
493 buildPRDict (ProdRepr {
494                repr_prod_arg_tys = prod_arg_tys
495              , repr_prod_tycon   = prod_tycon
496              })
497             vect_tc prepr_tc _
498   = do
499       prs  <- mapM mkPR prod_arg_tys
500       dfun <- prDFunOfTyCon prod_tycon
501       return $ dfun `mkTyApps` prod_arg_tys `mkApps` prs
502
503 buildPRDict (SumRepr {
504                 repr_tys         = repr_tys
505               , repr_prod_tycons = prod_tycons
506               , repr_prod_tys    = prod_tys
507               , repr_sum_tycon   = sum_tycon
508               })
509             vect_tc prepr_tc _
510   = do
511       prs      <- mapM (mapM mkPR) repr_tys
512       prod_prs <- sequence $ zipWith3 mk_prod_pr prod_tycons repr_tys prs
513       sum_dfun <- prDFunOfTyCon sum_tycon
514       prCoerce prepr_tc var_tys
515         $ sum_dfun `mkTyApps` prod_tys `mkApps` prod_prs
516   where
517     var_tys = mkTyVarTys $ tyConTyVars vect_tc
518
519     mk_prod_pr _         _   []   = prDFunOfTyCon unitTyCon
520     mk_prod_pr _         _   [pr] = return pr
521     mk_prod_pr (Just tc) tys prs
522       = do
523           dfun <- prDFunOfTyCon tc
524           return $ dfun `mkTyApps` tys `mkApps` prs
525
526 buildPArrayTyCon :: TyCon -> TyCon -> VM TyCon
527 buildPArrayTyCon orig_tc vect_tc = fixV $ \repr_tc ->
528   do
529     name'  <- cloneName mkPArrayTyConOcc orig_name
530     rhs    <- buildPArrayTyConRhs orig_name vect_tc repr_tc
531     parray <- builtin parrayTyCon
532
533     liftDs $ buildAlgTyCon name'
534                            tyvars
535                            []          -- no stupid theta
536                            rhs
537                            rec_flag    -- FIXME: is this ok?
538                            False       -- FIXME: no generics
539                            False       -- not GADT syntax
540                            (Just $ mk_fam_inst parray vect_tc)
541   where
542     orig_name = tyConName orig_tc
543     tyvars = tyConTyVars vect_tc
544     rec_flag = boolToRecFlag (isRecursiveTyCon vect_tc)
545     
546
547 buildPArrayTyConRhs :: Name -> TyCon -> TyCon -> VM AlgTyConRhs
548 buildPArrayTyConRhs orig_name vect_tc repr_tc
549   = do
550       data_con <- buildPArrayDataCon orig_name vect_tc repr_tc
551       return $ DataTyCon { data_cons = [data_con], is_enum = False }
552
553 buildPArrayDataCon :: Name -> TyCon -> TyCon -> VM DataCon
554 buildPArrayDataCon orig_name vect_tc repr_tc
555   = do
556       dc_name  <- cloneName mkPArrayDataConOcc orig_name
557       repr     <- mkTyConRepr vect_tc
558
559       shape_tys <- arrShapeTys repr
560       repr_tys  <- arrReprTys  repr
561
562       let tys = shape_tys ++ repr_tys
563
564       liftDs $ buildDataCon dc_name
565                             False                  -- not infix
566                             (map (const NotMarkedStrict) tys)
567                             []                     -- no field labels
568                             (tyConTyVars vect_tc)
569                             []                     -- no existentials
570                             []                     -- no eq spec
571                             []                     -- no context
572                             tys
573                             repr_tc
574
575 mkPADFun :: TyCon -> VM Var
576 mkPADFun vect_tc
577   = newExportedVar (mkPADFunOcc $ getOccName vect_tc) =<< paDFunType vect_tc
578
579 data Shape = Shape {
580                shapeReprTys    :: [Type]
581              , shapeStrictness :: [StrictnessMark]
582              , shapeLength     :: [CoreExpr] -> VM CoreExpr
583              , shapeReplicate  :: CoreExpr -> CoreExpr -> VM [CoreExpr]
584              }
585
586 tyConShape :: TyCon -> VM Shape
587 tyConShape vect_tc
588   | isProductTyCon vect_tc
589   = return $ Shape {
590                 shapeReprTys    = [intPrimTy]
591               , shapeStrictness = [NotMarkedStrict]
592               , shapeLength     = \[len] -> return len
593               , shapeReplicate  = \len _ -> return [len]
594               }
595
596   | otherwise
597   = do
598       repr_ty <- mkPArrayType intTy   -- FIXME: we want to unbox this
599       return $ Shape {
600                  shapeReprTys    = [repr_ty]
601                , shapeStrictness = [MarkedStrict]
602                , shapeLength     = \[sel] -> lengthPA sel
603                , shapeReplicate  = \len n -> do
604                                                e <- replicatePA len n
605                                                return [e]
606                }
607
608 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> Var
609                    -> VM [(Var, CoreExpr)]
610 buildTyConBindings orig_tc vect_tc prepr_tc arr_tc dfun
611   = do
612       shape <- tyConShape vect_tc
613       repr  <- mkTyConRepr vect_tc
614       sequence_ (zipWith4 (vectDataConWorker shape vect_tc arr_tc arr_dc)
615                           orig_dcs
616                           vect_dcs
617                           (inits repr_tys)
618                           (tails repr_tys))
619       dict <- buildPADict repr vect_tc prepr_tc arr_tc dfun
620       binds <- takeHoisted
621       return $ (dfun, dict) : binds
622   where
623     orig_dcs = tyConDataCons orig_tc
624     vect_dcs = tyConDataCons vect_tc
625     [arr_dc] = tyConDataCons arr_tc
626
627     repr_tys = map dataConRepArgTys vect_dcs
628
629 vectDataConWorker :: Shape -> TyCon -> TyCon -> DataCon
630                   -> DataCon -> DataCon -> [[Type]] -> [[Type]]
631                   -> VM ()
632 vectDataConWorker shape vect_tc arr_tc arr_dc orig_dc vect_dc pre (dc_tys : post)
633   = do
634       clo <- closedV
635            . inBind orig_worker
636            . polyAbstract tvs $ \abstract ->
637              liftM (abstract . vectorised)
638            $ buildClosures tvs [] dc_tys res_ty (liftM2 (,) mk_vect mk_lift)
639
640       worker <- cloneId mkVectOcc orig_worker (exprType clo)
641       hoistBinding worker clo
642       defGlobalVar orig_worker worker
643       return ()
644   where
645     tvs     = tyConTyVars vect_tc
646     arg_tys = mkTyVarTys tvs
647     res_ty  = mkTyConApp vect_tc arg_tys
648
649     orig_worker = dataConWorkId orig_dc
650
651     mk_vect = return . mkConApp vect_dc $ map Type arg_tys
652     mk_lift = do
653                 len     <- newLocalVar FSLIT("n") intPrimTy
654                 arr_tys <- mapM mkPArrayType dc_tys
655                 args    <- mapM (newLocalVar FSLIT("xs")) arr_tys
656                 shapes  <- shapeReplicate shape
657                                           (Var len)
658                                           (mkDataConTag vect_dc)
659                 
660                 empty_pre  <- mapM emptyPA (concat pre)
661                 empty_post <- mapM emptyPA (concat post)
662
663                 return . mkLams (len : args)
664                        . wrapFamInstBody arr_tc arg_tys
665                        . mkConApp arr_dc
666                        $ map Type arg_tys ++ shapes
667                                           ++ empty_pre
668                                           ++ map Var args
669                                           ++ empty_post
670
671 buildPADict :: TyConRepr -> TyCon -> TyCon -> TyCon -> Var -> VM CoreExpr
672 buildPADict repr vect_tc prepr_tc arr_tc dfun
673   = polyAbstract tvs $ \abstract ->
674     do
675       meth_binds <- mapM (mk_method repr) paMethods
676       let meth_exprs = map (Var . fst) meth_binds
677
678       pa_dc <- builtin paDataCon
679       let dict = mkConApp pa_dc (Type (mkTyConApp vect_tc arg_tys) : meth_exprs)
680           body = Let (Rec meth_binds) dict
681       return . mkInlineMe $ abstract body
682   where
683     tvs = tyConTyVars arr_tc
684     arg_tys = mkTyVarTys tvs
685
686     mk_method repr (name, build)
687       = localV
688       $ do
689           body <- build repr vect_tc prepr_tc arr_tc
690           var  <- newLocalVar name (exprType body)
691           return (var, mkInlineMe body)
692           
693 paMethods = [(FSLIT("toPRepr"),      buildToPRepr),
694              (FSLIT("fromPRepr"),    buildFromPRepr),
695              (FSLIT("toArrPRepr"),   buildToArrPRepr),
696              (FSLIT("fromArrPRepr"), buildFromArrPRepr),
697              (FSLIT("dictPRepr"),    buildPRDict)]
698
699 -- | Split the given tycons into two sets depending on whether they have to be
700 -- converted (first list) or not (second list). The first argument contains
701 -- information about the conversion status of external tycons:
702 -- 
703 --   * tycons which have converted versions are mapped to True
704 --   * tycons which are not changed by vectorisation are mapped to False
705 --   * tycons which can't be converted are not elements of the map
706 --
707 classifyTyCons :: UniqFM Bool -> [TyConGroup] -> ([TyCon], [TyCon])
708 classifyTyCons = classify [] []
709   where
710     classify conv keep cs [] = (conv, keep)
711     classify conv keep cs ((tcs, ds) : rs)
712       | can_convert && must_convert
713         = classify (tcs ++ conv) keep (cs `addListToUFM` [(tc,True) | tc <- tcs]) rs
714       | can_convert
715         = classify conv (tcs ++ keep) (cs `addListToUFM` [(tc,False) | tc <- tcs]) rs
716       | otherwise
717         = classify conv keep cs rs
718       where
719         refs = ds `delListFromUniqSet` tcs
720
721         can_convert  = isNullUFM (refs `minusUFM` cs) && all convertable tcs
722         must_convert = foldUFM (||) False (intersectUFM_C const cs refs)
723
724         convertable tc = isDataTyCon tc && all isVanillaDataCon (tyConDataCons tc)
725     
726 -- | Compute mutually recursive groups of tycons in topological order
727 --
728 tyConGroups :: [TyCon] -> [TyConGroup]
729 tyConGroups tcs = map mk_grp (stronglyConnComp edges)
730   where
731     edges = [((tc, ds), tc, uniqSetToList ds) | tc <- tcs
732                                 , let ds = tyConsOfTyCon tc]
733
734     mk_grp (AcyclicSCC (tc, ds)) = ([tc], ds)
735     mk_grp (CyclicSCC els)       = (tcs, unionManyUniqSets dss)
736       where
737         (tcs, dss) = unzip els
738
739 tyConsOfTyCon :: TyCon -> UniqSet TyCon
740 tyConsOfTyCon 
741   = tyConsOfTypes . concatMap dataConRepArgTys . tyConDataCons
742
743 tyConsOfType :: Type -> UniqSet TyCon
744 tyConsOfType ty
745   | Just ty' <- coreView ty    = tyConsOfType ty'
746 tyConsOfType (TyVarTy v)       = emptyUniqSet
747 tyConsOfType (TyConApp tc tys) = extend (tyConsOfTypes tys)
748   where
749     extend | isUnLiftedTyCon tc
750            || isTupleTyCon   tc = id
751
752            | otherwise          = (`addOneToUniqSet` tc)
753
754 tyConsOfType (AppTy a b)       = tyConsOfType a `unionUniqSets` tyConsOfType b
755 tyConsOfType (FunTy a b)       = (tyConsOfType a `unionUniqSets` tyConsOfType b)
756                                  `addOneToUniqSet` funTyCon
757 tyConsOfType (ForAllTy _ ty)   = tyConsOfType ty
758 tyConsOfType other             = pprPanic "ClosureConv.tyConsOfType" $ ppr other
759
760 tyConsOfTypes :: [Type] -> UniqSet TyCon
761 tyConsOfTypes = unionManyUniqSets . map tyConsOfType
762