Complete PA dictionary generation for product types
[ghc-hetmet.git] / compiler / vectorise / VectType.hs
1 module VectType ( vectTyCon, vectType, vectTypeEnv,
2                    PAInstance, buildPADict )
3 where
4
5 #include "HsVersions.h"
6
7 import VectMonad
8 import VectUtils
9 import VectCore
10
11 import HscTypes          ( TypeEnv, extendTypeEnvList, typeEnvTyCons )
12 import CoreSyn
13 import CoreUtils
14 import BuildTyCl
15 import DataCon
16 import TyCon
17 import Type
18 import TypeRep
19 import Coercion
20 import FamInstEnv        ( FamInst, mkLocalFamInst )
21 import InstEnv           ( Instance, mkLocalInstance, instanceDFunId )
22 import OccName
23 import MkId
24 import BasicTypes        ( StrictnessMark(..), OverlapFlag(..), boolToRecFlag )
25 import Var               ( Var )
26 import Id                ( mkWildId )
27 import Name              ( Name, getOccName )
28 import NameEnv
29 import TysWiredIn        ( unitTy, unitTyCon, intTy, intDataCon, unitDataConId )
30 import TysPrim           ( intPrimTy )
31
32 import Unique
33 import UniqFM
34 import UniqSet
35 import Digraph           ( SCC(..), stronglyConnComp )
36
37 import Outputable
38
39 import Control.Monad  ( liftM, liftM2, zipWithM, zipWithM_ )
40 import Data.List      ( inits, tails, zipWith4, zipWith5 )
41
42 -- ----------------------------------------------------------------------------
43 -- Types
44
45 vectTyCon :: TyCon -> VM TyCon
46 vectTyCon tc
47   | isFunTyCon tc        = builtin closureTyCon
48   | isBoxedTupleTyCon tc = return tc
49   | isUnLiftedTyCon tc   = return tc
50   | otherwise = do
51                   r <- lookupTyCon tc
52                   case r of
53                     Just tc' -> return tc'
54
55                     -- FIXME: just for now
56                     Nothing  -> pprTrace "ccTyCon:" (ppr tc) $ return tc
57
58 vectType :: Type -> VM Type
59 vectType ty | Just ty' <- coreView ty = vectType ty'
60 vectType (TyVarTy tv) = return $ TyVarTy tv
61 vectType (AppTy ty1 ty2) = liftM2 AppTy (vectType ty1) (vectType ty2)
62 vectType (TyConApp tc tys) = liftM2 TyConApp (vectTyCon tc) (mapM vectType tys)
63 vectType (FunTy ty1 ty2)   = liftM2 TyConApp (builtin closureTyCon)
64                                              (mapM vectType [ty1,ty2])
65 vectType ty@(ForAllTy _ _)
66   = do
67       mdicts   <- mapM paDictArgType tyvars
68       mono_ty' <- vectType mono_ty
69       return $ tyvars `mkForAllTys` ([dict | Just dict <- mdicts] `mkFunTys` mono_ty')
70   where
71     (tyvars, mono_ty) = splitForAllTys ty
72
73 vectType ty = pprPanic "vectType:" (ppr ty)
74
75 -- ----------------------------------------------------------------------------
76 -- Type definitions
77
78 type TyConGroup = ([TyCon], UniqSet TyCon)
79
80 data PAInstance = PAInstance {
81                     painstDFun      :: Var
82                   , painstOrigTyCon :: TyCon
83                   , painstVectTyCon :: TyCon
84                   , painstArrTyCon  :: TyCon
85                   }
86
87 vectTypeEnv :: TypeEnv -> VM (TypeEnv, [FamInst], [(Var, CoreExpr)])
88 vectTypeEnv env
89   = do
90       cs <- readGEnv $ mk_map . global_tycons
91       let (conv_tcs, keep_tcs) = classifyTyCons cs groups
92           keep_dcs             = concatMap tyConDataCons keep_tcs
93       zipWithM_ defTyCon   keep_tcs keep_tcs
94       zipWithM_ defDataCon keep_dcs keep_dcs
95       new_tcs <- vectTyConDecls conv_tcs
96
97       let orig_tcs = keep_tcs ++ conv_tcs
98           vect_tcs  = keep_tcs ++ new_tcs
99
100       repr_tcs <- zipWithM buildPReprTyCon   orig_tcs vect_tcs
101       parr_tcs <- zipWithM buildPArrayTyCon orig_tcs vect_tcs
102       dfuns    <- mapM mkPADFun vect_tcs
103       defTyConPAs (zip vect_tcs dfuns)
104       binds    <- sequence (zipWith5 buildTyConBindings orig_tcs
105                                                         vect_tcs
106                                                         repr_tcs
107                                                         parr_tcs
108                                                         dfuns)
109
110       let all_new_tcs = new_tcs ++ repr_tcs ++ parr_tcs
111
112       let new_env = extendTypeEnvList env
113                        (map ATyCon all_new_tcs
114                         ++ [ADataCon dc | tc <- all_new_tcs
115                                         , dc <- tyConDataCons tc])
116
117       return (new_env, map mkLocalFamInst (repr_tcs ++ parr_tcs), concat binds)
118   where
119     tycons = typeEnvTyCons env
120     groups = tyConGroups tycons
121
122     mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
123
124     keep_tc tc = let dcs = tyConDataCons tc
125                  in
126                  defTyCon tc tc >> zipWithM_ defDataCon dcs dcs
127
128
129 vectTyConDecls :: [TyCon] -> VM [TyCon]
130 vectTyConDecls tcs = fixV $ \tcs' ->
131   do
132     mapM_ (uncurry defTyCon) (lazy_zip tcs tcs')
133     mapM vectTyConDecl tcs
134   where
135     lazy_zip [] _ = []
136     lazy_zip (x:xs) ~(y:ys) = (x,y) : lazy_zip xs ys
137
138 vectTyConDecl :: TyCon -> VM TyCon
139 vectTyConDecl tc
140   = do
141       name' <- cloneName mkVectTyConOcc name
142       rhs'  <- vectAlgTyConRhs (algTyConRhs tc)
143
144       liftDs $ buildAlgTyCon name'
145                              tyvars
146                              []           -- no stupid theta
147                              rhs'
148                              rec_flag     -- FIXME: is this ok?
149                              False        -- FIXME: no generics
150                              False        -- not GADT syntax
151                              Nothing      -- not a family instance
152   where
153     name   = tyConName tc
154     tyvars = tyConTyVars tc
155     rec_flag = boolToRecFlag (isRecursiveTyCon tc)
156
157 vectAlgTyConRhs :: AlgTyConRhs -> VM AlgTyConRhs
158 vectAlgTyConRhs (DataTyCon { data_cons = data_cons
159                            , is_enum   = is_enum
160                            })
161   = do
162       data_cons' <- mapM vectDataCon data_cons
163       zipWithM_ defDataCon data_cons data_cons'
164       return $ DataTyCon { data_cons = data_cons'
165                          , is_enum   = is_enum
166                          }
167
168 vectDataCon :: DataCon -> VM DataCon
169 vectDataCon dc
170   | not . null $ dataConExTyVars dc = pprPanic "vectDataCon: existentials" (ppr dc)
171   | not . null $ dataConEqSpec   dc = pprPanic "vectDataCon: eq spec" (ppr dc)
172   | otherwise
173   = do
174       name'    <- cloneName mkVectDataConOcc name
175       tycon'   <- vectTyCon tycon
176       arg_tys  <- mapM vectType rep_arg_tys
177
178       liftDs $ buildDataCon name'
179                             False           -- not infix
180                             (map (const NotMarkedStrict) arg_tys)
181                             []              -- no labelled fields
182                             univ_tvs
183                             []              -- no existential tvs for now
184                             []              -- no eq spec for now
185                             []              -- no context
186                             arg_tys
187                             tycon'
188   where
189     name        = dataConName dc
190     univ_tvs    = dataConUnivTyVars dc
191     rep_arg_tys = dataConRepArgTys dc
192     tycon       = dataConTyCon dc
193
194 mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
195 mk_fam_inst fam_tc arg_tc
196   = (fam_tc, [mkTyConApp arg_tc . mkTyVarTys $ tyConTyVars arg_tc])
197
198 buildPReprTyCon :: TyCon -> TyCon -> VM TyCon
199 buildPReprTyCon orig_tc vect_tc
200   = do
201       name     <- cloneName mkPReprTyConOcc (tyConName orig_tc)
202       rhs_ty   <- buildPReprType vect_tc
203       prepr_tc <- builtin preprTyCon
204       liftDs $ buildSynTyCon name
205                              tyvars
206                              (SynonymTyCon rhs_ty)
207                              (Just $ mk_fam_inst prepr_tc vect_tc)
208   where
209     tyvars = tyConTyVars vect_tc
210
211 data TyConRepr = ProdRepr {
212                    repr_prod_arg_tys      :: [Type]
213                  , repr_prod_tycon        :: TyCon
214                  , repr_prod_data_con     :: DataCon
215                  , repr_prod_arr_tycon    :: TyCon
216                  , repr_prod_arr_data_con :: DataCon
217                  , repr_type              :: Type
218                  }
219                | SumRepr {
220                    repr_tys            :: [[Type]]
221                  , repr_prod_tycons    :: [Maybe TyCon]
222                  , repr_prod_data_cons :: [Maybe DataCon]
223                  , repr_prod_tys       :: [Type]
224                  , repr_sum_tycon      :: TyCon
225                  , repr_type           :: Type
226                  }
227
228 arrShapeTys :: TyConRepr -> VM [Type]
229 arrShapeTys (ProdRepr {}) = return [intPrimTy]
230 arrShapeTys (SumRepr  {})
231   = do
232       uarr <- builtin uarrTyCon
233       return [intPrimTy, mkTyConApp uarr [intTy]]
234
235 arrReprTys :: TyConRepr -> VM [Type]
236 arrReprTys (ProdRepr { repr_prod_arg_tys = tys })
237   = mapM mkPArrayType tys
238 arrReprTys (SumRepr { repr_tys = tys })
239   = concat `liftM` mapM (mapM mkPArrayType) (map mk_prod tys)
240   where
241     mk_prod []  = [unitTy]
242     mk_prod tys = tys
243       
244
245 mkTyConRepr :: TyCon -> VM TyConRepr
246 mkTyConRepr vect_tc
247   | is_product
248   = let
249       [prod_arg_tys] = repr_tys
250       arity          = length prod_arg_tys
251     in
252     do
253       prod_tycon <- builtin (prodTyCon arity)
254       let [prod_data_con] = tyConDataCons prod_tycon
255
256       (arr_tycon, _) <- parrayReprTyCon
257                       . mkTyConApp prod_tycon
258                       $ replicate arity unitTy
259
260       let [arr_data_con] = tyConDataCons arr_tycon
261
262       return $ ProdRepr {
263                  repr_prod_arg_tys      = prod_arg_tys
264                , repr_prod_tycon        = prod_tycon
265                , repr_prod_data_con     = prod_data_con
266                , repr_prod_arr_tycon    = arr_tycon
267                , repr_prod_arr_data_con = arr_data_con
268                , repr_type              = mkTyConApp prod_tycon prod_arg_tys
269                }
270
271   | otherwise
272   = do
273       uarr <- builtin uarrTyCon
274       prod_tycons  <- mapM (mk_tycon prodTyCon) repr_tys
275       let prod_tys = zipWith mk_tc_app_maybe prod_tycons repr_tys
276       sum_tycon    <- builtin (sumTyCon $ length repr_tys)
277       arr_repr_tys <- mapM (mapM mkPArrayType . arr_repr_elem_tys) repr_tys
278
279       return $ SumRepr {
280                  repr_tys            = repr_tys
281                , repr_prod_tycons    = prod_tycons
282                , repr_prod_data_cons = map (fmap mk_single_datacon) prod_tycons
283                , repr_prod_tys       = prod_tys
284                , repr_sum_tycon      = sum_tycon
285                , repr_type           = mkTyConApp sum_tycon prod_tys
286                }
287   where
288     tyvars    = tyConTyVars vect_tc
289     data_cons = tyConDataCons vect_tc
290     repr_tys  = map dataConRepArgTys data_cons
291
292     is_product | [_] <- data_cons = True
293                | otherwise        = False
294
295     mk_shape uarr = intPrimTy : mk_sel uarr
296
297     mk_sel uarr | is_product = []
298                 | otherwise  = [uarr_int, uarr_int]
299       where
300         uarr_int = mkTyConApp uarr [intTy]
301
302     mk_tycon get_tc tys
303       | n > 1     = builtin (Just . get_tc n)
304       | otherwise = return Nothing
305       where n = length tys
306
307     mk_single_datacon tc | [dc] <- tyConDataCons tc = dc
308
309     mk_tc_app_maybe Nothing   []   = unitTy
310     mk_tc_app_maybe Nothing   [ty] = ty
311     mk_tc_app_maybe (Just tc) tys  = mkTyConApp tc tys
312
313     arr_repr_elem_tys []  = [unitTy]
314     arr_repr_elem_tys tys = tys
315
316 buildPReprType :: TyCon -> VM Type
317 buildPReprType = liftM repr_type . mkTyConRepr
318
319 buildToPRepr :: TyConRepr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
320 buildToPRepr (ProdRepr {
321                 repr_prod_arg_tys  = prod_arg_tys
322               , repr_prod_data_con = prod_data_con
323               , repr_type          = repr_type
324               })
325              vect_tc prepr_tc _
326   = do
327       arg  <- newLocalVar FSLIT("x") arg_ty
328       vars <- mapM (newLocalVar FSLIT("x")) prod_arg_tys
329
330       return . Lam arg
331              . wrapFamInstBody prepr_tc var_tys
332              $ Case (Var arg) (mkWildId arg_ty) repr_type
333                [(DataAlt data_con, vars,
334                  mkConApp prod_data_con (map Type prod_arg_tys ++ map Var vars))]
335   where
336     var_tys    = mkTyVarTys $ tyConTyVars vect_tc
337     arg_ty     = mkTyConApp vect_tc var_tys
338     [data_con] = tyConDataCons vect_tc
339
340 buildToPRepr (SumRepr {
341                 repr_tys            = repr_tys
342               , repr_prod_data_cons = prod_data_cons
343               , repr_prod_tys       = prod_tys
344               , repr_sum_tycon      = sum_tycon
345               , repr_type           = repr_type
346               })
347               vect_tc prepr_tc _
348   = do
349       arg  <- newLocalVar FSLIT("x") arg_ty
350       vars <- mapM (mapM (newLocalVar FSLIT("x"))) repr_tys
351
352       return . Lam arg
353              . wrapFamInstBody prepr_tc var_tys
354              . Case (Var arg) (mkWildId arg_ty) repr_type
355              . zipWith4 mk_alt data_cons vars sum_data_cons
356              . zipWith3 mk_prod prod_data_cons repr_tys $ map (map Var) vars
357   where
358     var_tys   = mkTyVarTys $ tyConTyVars vect_tc
359     arg_ty    = mkTyConApp vect_tc var_tys
360     data_cons = tyConDataCons vect_tc
361
362     sum_data_cons = tyConDataCons sum_tycon
363
364     mk_alt dc vars sum_dc expr = (DataAlt dc, vars,
365                                   mkConApp sum_dc (map Type prod_tys ++ [expr]))
366
367     mk_prod _         _   []     = Var unitDataConId
368     mk_prod _         _   [expr] = expr
369     mk_prod (Just dc) tys exprs  = mkConApp dc (map Type tys ++ exprs)
370
371 buildFromPRepr :: TyConRepr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
372 buildFromPRepr (ProdRepr {
373                   repr_prod_arg_tys  = prod_arg_tys
374                 , repr_prod_data_con = prod_data_con
375                 , repr_type          = repr_type
376                 })
377              vect_tc prepr_tc _
378   = do
379       arg_ty <- mkPReprType res_ty
380       arg    <- newLocalVar FSLIT("x") arg_ty
381       vars   <- mapM (newLocalVar FSLIT("x")) prod_arg_tys
382
383       return . Lam arg
384              $ Case (unwrapFamInstScrut prepr_tc var_tys (Var arg))
385                     (mkWildId repr_type)
386                     res_ty
387                [(DataAlt prod_data_con, vars,
388                  mkConApp data_con (map Type var_tys ++ map Var vars))]
389   where
390     var_tys    = mkTyVarTys $ tyConTyVars vect_tc
391     ty_args    = map Type var_tys
392     res_ty     = mkTyConApp vect_tc var_tys
393     [data_con] = tyConDataCons vect_tc
394
395 buildFromPRepr (SumRepr {
396                 repr_tys            = repr_tys
397               , repr_prod_data_cons = prod_data_cons
398               , repr_prod_tys       = prod_tys
399               , repr_sum_tycon      = sum_tycon
400               , repr_type           = repr_type
401               })
402               vect_tc prepr_tc _
403   = do
404       arg_ty <- mkPReprType res_ty
405       arg    <- newLocalVar FSLIT("x") arg_ty
406
407       liftM (Lam arg
408              . Case (unwrapFamInstScrut prepr_tc var_tys (Var arg))
409                     (mkWildId repr_type)
410                     res_ty
411              . zipWith mk_alt sum_data_cons)
412             (sequence
413              $ zipWith4 un_prod data_cons prod_data_cons prod_tys repr_tys)
414   where
415     var_tys   = mkTyVarTys $ tyConTyVars vect_tc
416     ty_args   = map Type var_tys
417     res_ty    = mkTyConApp vect_tc var_tys
418     data_cons = tyConDataCons vect_tc
419
420     sum_data_cons = tyConDataCons sum_tycon
421
422     un_prod dc _ _ []
423       = do
424           var <- newLocalVar FSLIT("u") unitTy
425           return (var, mkConApp dc ty_args)
426     un_prod dc _ _ [ty]
427       = do
428           var <- newLocalVar FSLIT("x") ty
429           return (var, mkConApp dc (ty_args ++ [Var var]))
430
431     un_prod dc (Just prod_dc) prod_ty tys
432       = do
433           vars  <- mapM (newLocalVar FSLIT("x")) tys
434           pv    <- newLocalVar FSLIT("p") prod_ty
435
436           let res  = mkConApp dc (ty_args ++ map Var vars)
437               expr = Case (Var pv) (mkWildId prod_ty) res_ty
438                         [(DataAlt prod_dc, vars, res)]
439
440           return (pv, expr)
441
442     mk_alt sum_dc (var, expr) = (DataAlt sum_dc, [var], expr)
443
444
445 buildToArrPRepr :: TyConRepr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
446 buildToArrPRepr repr@(ProdRepr {
447                    repr_prod_arg_tys      = prod_arg_tys
448                  , repr_prod_arr_tycon    = prod_arr_tycon
449                  , repr_prod_arr_data_con = prod_arr_data_con
450                  , repr_type              = repr_type
451                  })
452                 vect_tc prepr_tc arr_tc
453   = do
454       arg_ty     <- mkPArrayType el_ty
455       shape_tys  <- arrShapeTys repr
456       arr_tys    <- arrReprTys repr
457       res_ty     <- mkPArrayType repr_type
458       rep_el_ty  <- mkPReprType el_ty
459
460       arg        <- newLocalVar FSLIT("xs") arg_ty
461       shape_vars <- mapM (newLocalVar FSLIT("sh")) shape_tys
462       rep_vars   <- mapM (newLocalVar FSLIT("ys")) arr_tys
463       
464       let vars = shape_vars ++ rep_vars
465
466       parray_co  <- mkBuiltinCo parrayTyCon
467
468       let res = wrapFamInstBody prod_arr_tycon prod_arg_tys
469               . mkConApp prod_arr_data_con
470               $ map Type prod_arg_tys ++ map Var vars
471
472           Just repr_co = tyConFamilyCoercion_maybe prepr_tc
473           co  = mkAppCoercion parray_co
474               . mkSymCoercion
475               $ mkTyConApp repr_co var_tys
476
477       return . Lam arg
478              . mkCoerce co
479              $ Case (unwrapFamInstScrut arr_tc var_tys (Var arg))
480                     (mkWildId (mkTyConApp arr_tc var_tys))
481                     res_ty
482                [(DataAlt arr_dc, vars, res)]
483   where
484     var_tys = mkTyVarTys $ tyConTyVars vect_tc
485     el_ty   = mkTyConApp vect_tc var_tys
486
487     [arr_dc] = tyConDataCons arr_tc
488
489
490 buildToArrPRepr _ _ _ _ = return (Var unitDataConId)
491 {-
492 buildToArrPRepr _ vect_tc prepr_tc arr_tc
493   = do
494       arg_ty  <- mkPArrayType el_ty
495       rep_tys <- mapM (mapM mkPArrayType) rep_el_tys
496
497       arg     <- newLocalVar FSLIT("xs") arg_ty
498       bndrss  <- mapM (mapM (newLocalVar FSLIT("ys"))) rep_tys
499       len     <- newLocalVar FSLIT("len") intPrimTy
500       sel     <- newLocalVar FSLIT("sel") =<< mkPArrayType intTy
501
502       let add_sel xs | has_selector = sel : xs
503                      | otherwise    = xs
504
505           all_bndrs = len : add_sel (concat bndrss)
506
507       res      <- parrayCoerce prepr_tc var_tys
508                 =<< mkToArrPRepr (Var len) (Var sel) (map (map Var) bndrss)
509       res_ty   <- mkPArrayType =<< mkPReprType el_ty
510
511       return . Lam arg
512              $ Case (unwrapFamInstScrut arr_tc var_tys (Var arg))
513                     (mkWildId (mkTyConApp arr_tc var_tys))
514                     res_ty
515                     [(DataAlt arr_dc, all_bndrs, res)]
516   where
517     var_tys    = mkTyVarTys $ tyConTyVars vect_tc
518     el_ty      = mkTyConApp vect_tc var_tys
519     data_cons  = tyConDataCons vect_tc
520     rep_el_tys = map dataConRepArgTys data_cons
521
522     [arr_dc]   = tyConDataCons arr_tc
523
524     has_selector | [_] <- data_cons = False
525                  | otherwise        = True
526 -}
527
528 buildFromArrPRepr :: TyConRepr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
529 buildFromArrPRepr repr@(ProdRepr {
530                           repr_prod_arg_tys      = prod_arg_tys
531                         , repr_prod_arr_tycon    = prod_arr_tycon
532                         , repr_prod_arr_data_con = prod_arr_data_con
533                         , repr_type              = repr_type
534                         })
535                        vect_tc prepr_tc arr_tc
536   = do
537       rep_el_ty  <- mkPReprType el_ty
538       arg_ty     <- mkPArrayType rep_el_ty
539       shape_tys  <- arrShapeTys repr
540       arr_tys    <- arrReprTys repr
541       res_ty     <- mkPArrayType el_ty
542
543       arg        <- newLocalVar FSLIT("xs") arg_ty
544       shape_vars <- mapM (newLocalVar FSLIT("sh")) shape_tys
545       rep_vars   <- mapM (newLocalVar FSLIT("ys")) arr_tys
546
547       let vars = shape_vars ++ rep_vars
548
549       parray_co  <- mkBuiltinCo parrayTyCon
550
551       let res = wrapFamInstBody arr_tc var_tys
552               . mkConApp arr_dc
553               $ map Type var_tys ++ map Var vars
554
555           Just repr_co = tyConFamilyCoercion_maybe prepr_tc
556           co  = mkAppCoercion parray_co
557               $ mkTyConApp repr_co var_tys
558
559           scrut = unwrapFamInstScrut prod_arr_tycon prod_arg_tys
560                 $ mkCoerce co (Var arg)
561
562       return . Lam arg
563              $ Case (scrut)
564                     (mkWildId (mkTyConApp prod_arr_tycon prod_arg_tys))
565                     res_ty
566                [(DataAlt prod_arr_data_con, vars, res)]
567   where
568     var_tys = mkTyVarTys $ tyConTyVars vect_tc
569     el_ty   = mkTyConApp vect_tc var_tys
570
571     [arr_dc] = tyConDataCons arr_tc
572 buildFromArrPRepr _ _ _ _ = return (Var unitDataConId)
573
574 buildPRDictRepr :: TyConRepr -> VM CoreExpr
575 buildPRDictRepr (ProdRepr {
576                    repr_prod_arg_tys = prod_arg_tys
577                  , repr_prod_tycon   = prod_tycon
578                  })
579   = do
580       prs  <- mapM mkPR prod_arg_tys
581       dfun <- prDFunOfTyCon prod_tycon
582       return $ dfun `mkTyApps` prod_arg_tys `mkApps` prs
583
584 buildPRDictRepr (SumRepr {
585                    repr_tys         = repr_tys
586                  , repr_prod_tycons = prod_tycons
587                  , repr_prod_tys    = prod_tys
588                  , repr_sum_tycon   = sum_tycon
589                  })
590   = do
591       prs      <- mapM (mapM mkPR) repr_tys
592       prod_prs <- sequence $ zipWith3 mk_prod_pr prod_tycons repr_tys prs
593       sum_dfun <- prDFunOfTyCon sum_tycon
594       return $ sum_dfun `mkTyApps` prod_tys `mkApps` prod_prs
595   where
596     mk_prod_pr _         _   []   = prDFunOfTyCon unitTyCon
597     mk_prod_pr _         _   [pr] = return pr
598     mk_prod_pr (Just tc) tys prs
599       = do
600           dfun <- prDFunOfTyCon tc
601           return $ dfun `mkTyApps` tys `mkApps` prs
602
603 buildPRDict :: TyConRepr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
604 buildPRDict repr vect_tc prepr_tc _
605   = do
606       dict  <- buildPRDictRepr repr
607
608       pr_co <- mkBuiltinCo prTyCon
609       let co = mkAppCoercion pr_co
610              . mkSymCoercion
611              $ mkTyConApp arg_co var_tys
612
613       return $ mkCoerce co dict
614   where
615     var_tys = mkTyVarTys $ tyConTyVars vect_tc
616
617     Just arg_co = tyConFamilyCoercion_maybe prepr_tc
618
619 buildPArrayTyCon :: TyCon -> TyCon -> VM TyCon
620 buildPArrayTyCon orig_tc vect_tc = fixV $ \repr_tc ->
621   do
622     name'  <- cloneName mkPArrayTyConOcc orig_name
623     rhs    <- buildPArrayTyConRhs orig_name vect_tc repr_tc
624     parray <- builtin parrayTyCon
625
626     liftDs $ buildAlgTyCon name'
627                            tyvars
628                            []          -- no stupid theta
629                            rhs
630                            rec_flag    -- FIXME: is this ok?
631                            False       -- FIXME: no generics
632                            False       -- not GADT syntax
633                            (Just $ mk_fam_inst parray vect_tc)
634   where
635     orig_name = tyConName orig_tc
636     tyvars = tyConTyVars vect_tc
637     rec_flag = boolToRecFlag (isRecursiveTyCon vect_tc)
638     
639
640 buildPArrayTyConRhs :: Name -> TyCon -> TyCon -> VM AlgTyConRhs
641 buildPArrayTyConRhs orig_name vect_tc repr_tc
642   = do
643       data_con <- buildPArrayDataCon orig_name vect_tc repr_tc
644       return $ DataTyCon { data_cons = [data_con], is_enum = False }
645
646 buildPArrayDataCon :: Name -> TyCon -> TyCon -> VM DataCon
647 buildPArrayDataCon orig_name vect_tc repr_tc
648   = do
649       dc_name  <- cloneName mkPArrayDataConOcc orig_name
650       repr     <- mkTyConRepr vect_tc
651
652       shape_tys <- arrShapeTys repr
653       repr_tys  <- arrReprTys  repr
654
655       let tys = shape_tys ++ repr_tys
656
657       liftDs $ buildDataCon dc_name
658                             False                  -- not infix
659                             (map (const NotMarkedStrict) tys)
660                             []                     -- no field labels
661                             (tyConTyVars vect_tc)
662                             []                     -- no existentials
663                             []                     -- no eq spec
664                             []                     -- no context
665                             tys
666                             repr_tc
667
668 mkPADFun :: TyCon -> VM Var
669 mkPADFun vect_tc
670   = newExportedVar (mkPADFunOcc $ getOccName vect_tc) =<< paDFunType vect_tc
671
672 data Shape = Shape {
673                shapeReprTys    :: [Type]
674              , shapeStrictness :: [StrictnessMark]
675              , shapeLength     :: [CoreExpr] -> VM CoreExpr
676              , shapeReplicate  :: CoreExpr -> CoreExpr -> VM [CoreExpr]
677              }
678
679 tyConShape :: TyCon -> VM Shape
680 tyConShape vect_tc
681   | isProductTyCon vect_tc
682   = return $ Shape {
683                 shapeReprTys    = [intPrimTy]
684               , shapeStrictness = [NotMarkedStrict]
685               , shapeLength     = \[len] -> return len
686               , shapeReplicate  = \len _ -> return [len]
687               }
688
689   | otherwise
690   = do
691       repr_ty <- mkPArrayType intTy   -- FIXME: we want to unbox this
692       return $ Shape {
693                  shapeReprTys    = [repr_ty]
694                , shapeStrictness = [MarkedStrict]
695                , shapeLength     = \[sel] -> lengthPA sel
696                , shapeReplicate  = \len n -> do
697                                                e <- replicatePA len n
698                                                return [e]
699                }
700
701 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> Var
702                    -> VM [(Var, CoreExpr)]
703 buildTyConBindings orig_tc vect_tc prepr_tc arr_tc dfun
704   = do
705       shape <- tyConShape vect_tc
706       repr  <- mkTyConRepr vect_tc
707       sequence_ (zipWith4 (vectDataConWorker shape vect_tc arr_tc arr_dc)
708                           orig_dcs
709                           vect_dcs
710                           (inits repr_tys)
711                           (tails repr_tys))
712       dict <- buildPADict repr vect_tc prepr_tc arr_tc dfun
713       binds <- takeHoisted
714       return $ (dfun, dict) : binds
715   where
716     orig_dcs = tyConDataCons orig_tc
717     vect_dcs = tyConDataCons vect_tc
718     [arr_dc] = tyConDataCons arr_tc
719
720     repr_tys = map dataConRepArgTys vect_dcs
721
722 vectDataConWorker :: Shape -> TyCon -> TyCon -> DataCon
723                   -> DataCon -> DataCon -> [[Type]] -> [[Type]]
724                   -> VM ()
725 vectDataConWorker shape vect_tc arr_tc arr_dc orig_dc vect_dc pre (dc_tys : post)
726   = do
727       clo <- closedV
728            . inBind orig_worker
729            . polyAbstract tvs $ \abstract ->
730              liftM (abstract . vectorised)
731            $ buildClosures tvs [] dc_tys res_ty (liftM2 (,) mk_vect mk_lift)
732
733       worker <- cloneId mkVectOcc orig_worker (exprType clo)
734       hoistBinding worker clo
735       defGlobalVar orig_worker worker
736       return ()
737   where
738     tvs     = tyConTyVars vect_tc
739     arg_tys = mkTyVarTys tvs
740     res_ty  = mkTyConApp vect_tc arg_tys
741
742     orig_worker = dataConWorkId orig_dc
743
744     mk_vect = return . mkConApp vect_dc $ map Type arg_tys
745     mk_lift = do
746                 len     <- newLocalVar FSLIT("n") intPrimTy
747                 arr_tys <- mapM mkPArrayType dc_tys
748                 args    <- mapM (newLocalVar FSLIT("xs")) arr_tys
749                 shapes  <- shapeReplicate shape
750                                           (Var len)
751                                           (mkDataConTag vect_dc)
752                 
753                 empty_pre  <- mapM emptyPA (concat pre)
754                 empty_post <- mapM emptyPA (concat post)
755
756                 return . mkLams (len : args)
757                        . wrapFamInstBody arr_tc arg_tys
758                        . mkConApp arr_dc
759                        $ map Type arg_tys ++ shapes
760                                           ++ empty_pre
761                                           ++ map Var args
762                                           ++ empty_post
763
764 buildPADict :: TyConRepr -> TyCon -> TyCon -> TyCon -> Var -> VM CoreExpr
765 buildPADict repr vect_tc prepr_tc arr_tc dfun
766   = polyAbstract tvs $ \abstract ->
767     do
768       meth_binds <- mapM (mk_method repr) paMethods
769       let meth_exprs = map (Var . fst) meth_binds
770
771       pa_dc <- builtin paDataCon
772       let dict = mkConApp pa_dc (Type (mkTyConApp vect_tc arg_tys) : meth_exprs)
773           body = Let (Rec meth_binds) dict
774       return . mkInlineMe $ abstract body
775   where
776     tvs = tyConTyVars arr_tc
777     arg_tys = mkTyVarTys tvs
778
779     mk_method repr (name, build)
780       = localV
781       $ do
782           body <- build repr vect_tc prepr_tc arr_tc
783           var  <- newLocalVar name (exprType body)
784           return (var, mkInlineMe body)
785           
786 paMethods = [(FSLIT("toPRepr"),      buildToPRepr),
787              (FSLIT("fromPRepr"),    buildFromPRepr),
788              (FSLIT("toArrPRepr"),   buildToArrPRepr),
789              (FSLIT("fromArrPRepr"), buildFromArrPRepr),
790              (FSLIT("dictPRepr"),    buildPRDict)]
791
792 -- | Split the given tycons into two sets depending on whether they have to be
793 -- converted (first list) or not (second list). The first argument contains
794 -- information about the conversion status of external tycons:
795 -- 
796 --   * tycons which have converted versions are mapped to True
797 --   * tycons which are not changed by vectorisation are mapped to False
798 --   * tycons which can't be converted are not elements of the map
799 --
800 classifyTyCons :: UniqFM Bool -> [TyConGroup] -> ([TyCon], [TyCon])
801 classifyTyCons = classify [] []
802   where
803     classify conv keep cs [] = (conv, keep)
804     classify conv keep cs ((tcs, ds) : rs)
805       | can_convert && must_convert
806         = classify (tcs ++ conv) keep (cs `addListToUFM` [(tc,True) | tc <- tcs]) rs
807       | can_convert
808         = classify conv (tcs ++ keep) (cs `addListToUFM` [(tc,False) | tc <- tcs]) rs
809       | otherwise
810         = classify conv keep cs rs
811       where
812         refs = ds `delListFromUniqSet` tcs
813
814         can_convert  = isNullUFM (refs `minusUFM` cs) && all convertable tcs
815         must_convert = foldUFM (||) False (intersectUFM_C const cs refs)
816
817         convertable tc = isDataTyCon tc && all isVanillaDataCon (tyConDataCons tc)
818     
819 -- | Compute mutually recursive groups of tycons in topological order
820 --
821 tyConGroups :: [TyCon] -> [TyConGroup]
822 tyConGroups tcs = map mk_grp (stronglyConnComp edges)
823   where
824     edges = [((tc, ds), tc, uniqSetToList ds) | tc <- tcs
825                                 , let ds = tyConsOfTyCon tc]
826
827     mk_grp (AcyclicSCC (tc, ds)) = ([tc], ds)
828     mk_grp (CyclicSCC els)       = (tcs, unionManyUniqSets dss)
829       where
830         (tcs, dss) = unzip els
831
832 tyConsOfTyCon :: TyCon -> UniqSet TyCon
833 tyConsOfTyCon 
834   = tyConsOfTypes . concatMap dataConRepArgTys . tyConDataCons
835
836 tyConsOfType :: Type -> UniqSet TyCon
837 tyConsOfType ty
838   | Just ty' <- coreView ty    = tyConsOfType ty'
839 tyConsOfType (TyVarTy v)       = emptyUniqSet
840 tyConsOfType (TyConApp tc tys) = extend (tyConsOfTypes tys)
841   where
842     extend | isUnLiftedTyCon tc
843            || isTupleTyCon   tc = id
844
845            | otherwise          = (`addOneToUniqSet` tc)
846
847 tyConsOfType (AppTy a b)       = tyConsOfType a `unionUniqSets` tyConsOfType b
848 tyConsOfType (FunTy a b)       = (tyConsOfType a `unionUniqSets` tyConsOfType b)
849                                  `addOneToUniqSet` funTyCon
850 tyConsOfType (ForAllTy _ ty)   = tyConsOfType ty
851 tyConsOfType other             = pprPanic "ClosureConv.tyConsOfType" $ ppr other
852
853 tyConsOfTypes :: [Type] -> UniqSet TyCon
854 tyConsOfTypes = unionManyUniqSets . map tyConsOfType
855