878dfab14b42dd03e5c538d628a75ec458d02da5
[ghc-hetmet.git] / compiler / vectorise / VectType.hs
1 module VectType ( vectTyCon, vectAndLiftType, vectType, vectTypeEnv,
2                   -- arrSumArity, pdataCompTys, pdataCompVars,
3                   buildPADict,
4                   fromVect )
5 where
6
7 import VectMonad
8 import VectUtils
9 import VectCore
10
11 import HscTypes          ( TypeEnv, extendTypeEnvList, typeEnvTyCons )
12 import CoreSyn
13 import CoreUtils
14 import CoreUnfold
15 import MkCore            ( mkWildCase )
16 import BuildTyCl
17 import DataCon
18 import TyCon
19 import Type
20 import TypeRep
21 import Coercion
22 import FamInstEnv        ( FamInst, mkLocalFamInst )
23 import OccName
24 import Id
25 import MkId
26 import BasicTypes        ( HsBang(..), boolToRecFlag,
27                            alwaysInlinePragma, dfunInlinePragma )
28 import Var               ( Var, TyVar, varType )
29 import Name              ( Name, getOccName )
30 import NameEnv
31
32 import Unique
33 import UniqFM
34 import UniqSet
35 import Util
36 import Digraph           ( SCC(..), stronglyConnCompFromEdgedVertices )
37
38 import Outputable
39 import FastString
40
41 import MonadUtils     ( zipWith3M, foldrM, concatMapM )
42 import Control.Monad  ( liftM, liftM2, zipWithM, zipWithM_, mapAndUnzipM )
43 import Data.List      ( inits, tails, zipWith4, zipWith5 )
44
45 -- ----------------------------------------------------------------------------
46 -- Types
47
48 -- | Vectorise a type constructor.
49 vectTyCon :: TyCon -> VM TyCon
50 vectTyCon tc
51   | isFunTyCon tc        = builtin closureTyCon
52   | isBoxedTupleTyCon tc = return tc
53   | isUnLiftedTyCon tc   = return tc
54   | otherwise            
55   = maybeCantVectoriseM "Tycon not vectorised: " (ppr tc)
56         $ lookupTyCon tc
57
58
59 vectAndLiftType :: Type -> VM (Type, Type)
60 vectAndLiftType ty | Just ty' <- coreView ty = vectAndLiftType ty'
61 vectAndLiftType ty
62   = do
63       mdicts   <- mapM paDictArgType tyvars
64       let dicts = [dict | Just dict <- mdicts]
65       vmono_ty <- vectType mono_ty
66       lmono_ty <- mkPDataType vmono_ty
67       return (abstractType tyvars dicts vmono_ty,
68               abstractType tyvars dicts lmono_ty)
69   where
70     (tyvars, mono_ty) = splitForAllTys ty
71
72
73 -- | Vectorise a type.
74 vectType :: Type -> VM Type
75 vectType ty | Just ty' <- coreView ty = vectType ty'
76 vectType (TyVarTy tv) = return $ TyVarTy tv
77 vectType (AppTy ty1 ty2) = liftM2 AppTy (vectType ty1) (vectType ty2)
78 vectType (TyConApp tc tys) = liftM2 TyConApp (vectTyCon tc) (mapM vectType tys)
79 vectType (FunTy ty1 ty2)   = liftM2 TyConApp (builtin closureTyCon)
80                                              (mapM vectAndBoxType [ty1,ty2])
81 vectType ty@(ForAllTy _ _)
82   = do
83       mdicts   <- mapM paDictArgType tyvars
84       mono_ty' <- vectType mono_ty
85       return $ abstractType tyvars [dict | Just dict <- mdicts] mono_ty'
86   where
87     (tyvars, mono_ty) = splitForAllTys ty
88
89 vectType ty = cantVectorise "Can't vectorise type" (ppr ty)
90
91 vectAndBoxType :: Type -> VM Type
92 vectAndBoxType ty = vectType ty >>= boxType
93
94 -- | Add quantified vars and dictionary parameters to the front of a type.
95 abstractType :: [TyVar] -> [Type] -> Type -> Type
96 abstractType tyvars dicts = mkForAllTys tyvars . mkFunTys dicts
97
98 -- ----------------------------------------------------------------------------
99 -- Boxing
100
101 boxType :: Type -> VM Type
102 boxType ty
103   | Just (tycon, []) <- splitTyConApp_maybe ty
104   , isUnLiftedTyCon tycon
105   = do
106       r <- lookupBoxedTyCon tycon
107       case r of
108         Just tycon' -> return $ mkTyConApp tycon' []
109         Nothing     -> return ty
110
111 boxType ty = return ty
112
113 -- ----------------------------------------------------------------------------
114 -- Type definitions
115
116 type TyConGroup = ([TyCon], UniqSet TyCon)
117
118 -- | Vectorise a type environment.
119 --   The type environment contains all the type things defined in a module.
120 vectTypeEnv :: TypeEnv -> VM (TypeEnv, [FamInst], [(Var, CoreExpr)])
121 vectTypeEnv env
122   = do
123       cs <- readGEnv $ mk_map . global_tycons
124
125       -- Split the list of TyCons into the ones we have to vectorise vs the
126       -- ones we can pass through unchanged. We also pass through algebraic 
127       -- types that use non Haskell98 features, as we don't handle those.
128       let (conv_tcs, keep_tcs) = classifyTyCons cs groups
129           keep_dcs             = concatMap tyConDataCons keep_tcs
130       zipWithM_ defTyCon   keep_tcs keep_tcs
131       zipWithM_ defDataCon keep_dcs keep_dcs
132
133       new_tcs <- vectTyConDecls conv_tcs
134
135       let orig_tcs = keep_tcs ++ conv_tcs
136           vect_tcs = keep_tcs ++ new_tcs
137
138       (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
139         do
140           defTyConPAs (zipLazy vect_tcs dfuns')
141           reprs <- mapM tyConRepr vect_tcs
142           repr_tcs  <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
143           pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
144           dfuns <- sequence $ zipWith5 buildTyConBindings orig_tcs
145                                                           vect_tcs
146                                                           repr_tcs
147                                                           pdata_tcs
148                                                           reprs
149           binds <- takeHoisted
150           return (dfuns, binds, repr_tcs ++ pdata_tcs)
151
152       let all_new_tcs = new_tcs ++ inst_tcs
153
154       let new_env = extendTypeEnvList env
155                        (map ATyCon all_new_tcs
156                         ++ [ADataCon dc | tc <- all_new_tcs
157                                         , dc <- tyConDataCons tc])
158
159       return (new_env, map mkLocalFamInst inst_tcs, binds)
160   where
161     tycons = typeEnvTyCons env
162     groups = tyConGroups tycons
163
164     mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
165
166
167 -- | Vectorise some (possibly recursively defined) type constructors.
168 vectTyConDecls :: [TyCon] -> VM [TyCon]
169 vectTyConDecls tcs = fixV $ \tcs' ->
170   do
171     mapM_ (uncurry defTyCon) (zipLazy tcs tcs')
172     mapM vectTyConDecl tcs
173
174 vectTyConDecl :: TyCon -> VM TyCon
175 vectTyConDecl tc
176   = do
177       name' <- cloneName mkVectTyConOcc name
178       rhs'  <- vectAlgTyConRhs tc (algTyConRhs tc)
179
180       liftDs $ buildAlgTyCon name'
181                              tyvars
182                              []           -- no stupid theta
183                              rhs'
184                              rec_flag     -- FIXME: is this ok?
185                              False        -- FIXME: no generics
186                              False        -- not GADT syntax
187                              Nothing      -- not a family instance
188   where
189     name   = tyConName tc
190     tyvars = tyConTyVars tc
191     rec_flag = boolToRecFlag (isRecursiveTyCon tc)
192
193 vectAlgTyConRhs :: TyCon -> AlgTyConRhs -> VM AlgTyConRhs
194 vectAlgTyConRhs _ (DataTyCon { data_cons = data_cons
195                              , is_enum   = is_enum
196                              })
197   = do
198       data_cons' <- mapM vectDataCon data_cons
199       zipWithM_ defDataCon data_cons data_cons'
200       return $ DataTyCon { data_cons = data_cons'
201                          , is_enum   = is_enum
202                          }
203 vectAlgTyConRhs tc _ = cantVectorise "Can't vectorise type definition:" (ppr tc)
204
205 vectDataCon :: DataCon -> VM DataCon
206 vectDataCon dc
207   | not . null $ dataConExTyVars dc
208         = cantVectorise "Can't vectorise constructor (existentials):" (ppr dc)
209   | not . null $ dataConEqSpec   dc
210         = cantVectorise "Can't vectorise constructor (eq spec):" (ppr dc)
211   | otherwise
212   = do
213       name'    <- cloneName mkVectDataConOcc name
214       tycon'   <- vectTyCon tycon
215       arg_tys  <- mapM vectType rep_arg_tys
216
217       liftDs $ buildDataCon name'
218                             False           -- not infix
219                             (map (const HsNoBang) arg_tys)
220                             []              -- no labelled fields
221                             univ_tvs
222                             []              -- no existential tvs for now
223                             []              -- no eq spec for now
224                             []              -- no context
225                             arg_tys 
226                             (mkFamilyTyConApp tycon' (mkTyVarTys univ_tvs))
227                             tycon'
228   where
229     name        = dataConName dc
230     univ_tvs    = dataConUnivTyVars dc
231     rep_arg_tys = dataConRepArgTys dc
232     tycon       = dataConTyCon dc
233
234 mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
235 mk_fam_inst fam_tc arg_tc
236   = (fam_tc, [mkTyConApp arg_tc . mkTyVarTys $ tyConTyVars arg_tc])
237
238
239 buildPReprTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
240 buildPReprTyCon orig_tc vect_tc repr
241   = do
242       name     <- cloneName mkPReprTyConOcc (tyConName orig_tc)
243       -- rhs_ty   <- buildPReprType vect_tc
244       rhs_ty   <- sumReprType repr
245       prepr_tc <- builtin preprTyCon
246       liftDs $ buildSynTyCon name
247                              tyvars
248                              (SynonymTyCon rhs_ty)
249                              (typeKind rhs_ty)
250                              (Just $ mk_fam_inst prepr_tc vect_tc)
251   where
252     tyvars = tyConTyVars vect_tc
253
254 data CompRepr = Keep Type
255                      CoreExpr     -- PR dictionary for the type
256               | Wrap Type
257
258 data ProdRepr = EmptyProd
259               | UnaryProd CompRepr
260               | Prod { repr_tup_tc   :: TyCon  -- representation tuple tycon
261                      , repr_ptup_tc  :: TyCon  -- PData representation tycon
262                      , repr_comp_tys :: [Type] -- representation types of
263                      , repr_comps    :: [CompRepr]          -- components
264                      }
265 data ConRepr  = ConRepr DataCon ProdRepr
266
267 data SumRepr  = EmptySum
268               | UnarySum ConRepr
269               | Sum  { repr_sum_tc   :: TyCon  -- representation sum tycon
270                      , repr_psum_tc  :: TyCon  -- PData representation tycon
271                      , repr_sel_ty   :: Type   -- type of selector
272                      , repr_con_tys :: [Type]  -- representation types of
273                      , repr_cons     :: [ConRepr]           -- components
274                      }
275
276 tyConRepr :: TyCon -> VM SumRepr
277 tyConRepr tc = sum_repr (tyConDataCons tc)
278   where
279     sum_repr []    = return EmptySum
280     sum_repr [con] = liftM UnarySum (con_repr con)
281     sum_repr cons  = do
282                        rs     <- mapM con_repr cons
283                        sum_tc <- builtin (sumTyCon arity)
284                        tys    <- mapM conReprType rs
285                        (psum_tc, _) <- pdataReprTyCon (mkTyConApp sum_tc tys)
286                        sel_ty <- builtin (selTy arity)
287                        return $ Sum { repr_sum_tc  = sum_tc
288                                     , repr_psum_tc = psum_tc
289                                     , repr_sel_ty  = sel_ty
290                                     , repr_con_tys = tys
291                                     , repr_cons    = rs
292                                     }
293       where
294         arity = length cons
295
296     con_repr con = liftM (ConRepr con) (prod_repr (dataConRepArgTys con))
297
298     prod_repr []   = return EmptyProd
299     prod_repr [ty] = liftM UnaryProd (comp_repr ty)
300     prod_repr tys  = do
301                        rs <- mapM comp_repr tys
302                        tup_tc <- builtin (prodTyCon arity)
303                        tys'    <- mapM compReprType rs
304                        (ptup_tc, _) <- pdataReprTyCon (mkTyConApp tup_tc tys')
305                        return $ Prod { repr_tup_tc   = tup_tc
306                                      , repr_ptup_tc  = ptup_tc
307                                      , repr_comp_tys = tys'
308                                      , repr_comps    = rs
309                                      }
310       where
311         arity = length tys
312     
313     comp_repr ty = liftM (Keep ty) (prDictOfType ty)
314                    `orElseV` return (Wrap ty)
315
316 sumReprType :: SumRepr -> VM Type
317 sumReprType EmptySum = voidType
318 sumReprType (UnarySum r) = conReprType r
319 sumReprType (Sum { repr_sum_tc  = sum_tc, repr_con_tys = tys })
320   = return $ mkTyConApp sum_tc tys
321
322 conReprType :: ConRepr -> VM Type
323 conReprType (ConRepr _ r) = prodReprType r
324
325 prodReprType :: ProdRepr -> VM Type
326 prodReprType EmptyProd = voidType
327 prodReprType (UnaryProd r) = compReprType r
328 prodReprType (Prod { repr_tup_tc = tup_tc, repr_comp_tys = tys })
329   = return $ mkTyConApp tup_tc tys
330
331 compReprType :: CompRepr -> VM Type
332 compReprType (Keep ty _) = return ty
333 compReprType (Wrap ty) = do
334                              wrap_tc <- builtin wrapTyCon
335                              return $ mkTyConApp wrap_tc [ty]
336
337 compOrigType :: CompRepr -> Type
338 compOrigType (Keep ty _) = ty
339 compOrigType (Wrap ty) = ty
340
341 buildToPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
342 buildToPRepr vect_tc repr_tc _ repr
343   = do
344       let arg_ty = mkTyConApp vect_tc ty_args
345       res_ty <- mkPReprType arg_ty
346       arg    <- newLocalVar (fsLit "x") arg_ty
347       result <- to_sum (Var arg) arg_ty res_ty repr
348       return $ Lam arg result
349   where
350     ty_args = mkTyVarTys (tyConTyVars vect_tc)
351
352     wrap_repr_inst = wrapFamInstBody repr_tc ty_args
353
354     to_sum _ _ _ EmptySum
355       = do
356           void <- builtin voidVar
357           return $ wrap_repr_inst $ Var void
358
359     to_sum arg arg_ty res_ty (UnarySum r)
360       = do
361           (pat, vars, body) <- con_alt r
362           return $ mkWildCase arg arg_ty res_ty
363                    [(pat, vars, wrap_repr_inst body)]
364
365     to_sum arg arg_ty res_ty (Sum { repr_sum_tc  = sum_tc
366                                   , repr_con_tys = tys
367                                   , repr_cons    =  cons })
368       = do
369           alts <- mapM con_alt cons
370           let alts' = [(pat, vars, wrap_repr_inst
371                                    $ mkConApp sum_con (map Type tys ++ [body]))
372                         | ((pat, vars, body), sum_con)
373                             <- zip alts (tyConDataCons sum_tc)]
374           return $ mkWildCase arg arg_ty res_ty alts'
375
376     con_alt (ConRepr con r)
377       = do
378           (vars, body) <- to_prod r
379           return (DataAlt con, vars, body)
380
381     to_prod EmptyProd
382       = do
383           void <- builtin voidVar
384           return ([], Var void)
385
386     to_prod (UnaryProd comp)
387       = do
388           var  <- newLocalVar (fsLit "x") (compOrigType comp)
389           body <- to_comp (Var var) comp
390           return ([var], body)
391
392     to_prod(Prod { repr_tup_tc   = tup_tc
393                  , repr_comp_tys = tys
394                  , repr_comps    = comps })
395       = do
396           vars  <- newLocalVars (fsLit "x") (map compOrigType comps)
397           exprs <- zipWithM to_comp (map Var vars) comps
398           return (vars, mkConApp tup_con (map Type tys ++ exprs))
399       where
400         [tup_con] = tyConDataCons tup_tc
401
402     to_comp expr (Keep _ _) = return expr
403     to_comp expr (Wrap ty)  = do
404                                 wrap_tc <- builtin wrapTyCon
405                                 return $ wrapNewTypeBody wrap_tc [ty] expr
406
407
408 buildFromPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
409 buildFromPRepr vect_tc repr_tc _ repr
410   = do
411       arg_ty <- mkPReprType res_ty
412       arg <- newLocalVar (fsLit "x") arg_ty
413
414       result <- from_sum (unwrapFamInstScrut repr_tc ty_args (Var arg))
415                          repr
416       return $ Lam arg result
417   where
418     ty_args = mkTyVarTys (tyConTyVars vect_tc)
419     res_ty  = mkTyConApp vect_tc ty_args
420
421     from_sum _ EmptySum
422       = do
423           dummy <- builtin fromVoidVar
424           return $ Var dummy `App` Type res_ty
425
426     from_sum expr (UnarySum r) = from_con expr r
427     from_sum expr (Sum { repr_sum_tc  = sum_tc
428                        , repr_con_tys = tys
429                        , repr_cons    = cons })
430       = do
431           vars  <- newLocalVars (fsLit "x") tys
432           es    <- zipWithM from_con (map Var vars) cons
433           return $ mkWildCase expr (exprType expr) res_ty
434                    [(DataAlt con, [var], e)
435                       | (con, var, e) <- zip3 (tyConDataCons sum_tc) vars es]
436
437     from_con expr (ConRepr con r)
438       = from_prod expr (mkConApp con $ map Type ty_args) r
439
440     from_prod _ con EmptyProd = return con
441     from_prod expr con (UnaryProd r)
442       = do
443           e <- from_comp expr r
444           return $ con `App` e
445      
446     from_prod expr con (Prod { repr_tup_tc   = tup_tc
447                              , repr_comp_tys = tys
448                              , repr_comps    = comps
449                              })
450       = do
451           vars <- newLocalVars (fsLit "y") tys
452           es   <- zipWithM from_comp (map Var vars) comps
453           return $ mkWildCase expr (exprType expr) res_ty
454                    [(DataAlt tup_con, vars, con `mkApps` es)]
455       where
456         [tup_con] = tyConDataCons tup_tc  
457
458     from_comp expr (Keep _ _) = return expr
459     from_comp expr (Wrap ty)
460       = do
461           wrap <- builtin wrapTyCon
462           return $ unwrapNewTypeBody wrap [ty] expr
463
464
465 buildToArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
466 buildToArrPRepr vect_tc prepr_tc pdata_tc r
467   = do
468       arg_ty <- mkPDataType el_ty
469       res_ty <- mkPDataType =<< mkPReprType el_ty
470       arg    <- newLocalVar (fsLit "xs") arg_ty
471
472       pdata_co <- mkBuiltinCo pdataTyCon
473       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
474           co           = mkAppCoercion pdata_co
475                        . mkSymCoercion
476                        $ mkTyConApp repr_co ty_args
477
478           scrut   = unwrapFamInstScrut pdata_tc ty_args (Var arg)
479
480       (vars, result) <- to_sum r
481
482       return . Lam arg
483              $ mkWildCase scrut (mkTyConApp pdata_tc ty_args) res_ty
484                [(DataAlt pdata_dc, vars, mkCoerce co result)]
485   where
486     ty_args = mkTyVarTys $ tyConTyVars vect_tc
487     el_ty   = mkTyConApp vect_tc ty_args
488
489     [pdata_dc] = tyConDataCons pdata_tc
490
491
492     to_sum EmptySum = do
493                         pvoid <- builtin pvoidVar
494                         return ([], Var pvoid)
495     to_sum (UnarySum r) = to_con r
496     to_sum (Sum { repr_psum_tc = psum_tc
497                 , repr_sel_ty  = sel_ty
498                 , repr_con_tys = tys
499                 , repr_cons    = cons
500                 })
501       = do
502           (vars, exprs) <- mapAndUnzipM to_con cons
503           sel <- newLocalVar (fsLit "sel") sel_ty
504           return (sel : concat vars, mk_result (Var sel) exprs)
505       where
506         [psum_con] = tyConDataCons psum_tc
507         mk_result sel exprs = wrapFamInstBody psum_tc tys
508                             $ mkConApp psum_con
509                             $ map Type tys ++ (sel : exprs)
510
511     to_con (ConRepr _ r) = to_prod r
512
513     to_prod EmptyProd = do
514                           pvoid <- builtin pvoidVar
515                           return ([], Var pvoid)
516     to_prod (UnaryProd r)
517       = do
518           pty  <- mkPDataType (compOrigType r)
519           var  <- newLocalVar (fsLit "x") pty
520           expr <- to_comp (Var var) r
521           return ([var], expr)
522
523     to_prod (Prod { repr_ptup_tc  = ptup_tc
524                   , repr_comp_tys = tys
525                   , repr_comps    = comps })
526       = do
527           ptys <- mapM (mkPDataType . compOrigType) comps
528           vars <- newLocalVars (fsLit "x") ptys
529           es   <- zipWithM to_comp (map Var vars) comps
530           return (vars, mk_result es)
531       where
532         [ptup_con] = tyConDataCons ptup_tc
533         mk_result exprs = wrapFamInstBody ptup_tc tys
534                         $ mkConApp ptup_con
535                         $ map Type tys ++ exprs
536
537     to_comp expr (Keep _ _) = return expr
538
539     -- FIXME: this is bound to be wrong!
540     to_comp expr (Wrap ty)
541       = do
542           wrap_tc  <- builtin wrapTyCon
543           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
544           return $ wrapNewTypeBody pwrap_tc [ty] expr
545
546
547 buildFromArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
548 buildFromArrPRepr vect_tc prepr_tc pdata_tc r
549   = do
550       arg_ty <- mkPDataType =<< mkPReprType el_ty
551       res_ty <- mkPDataType el_ty
552       arg    <- newLocalVar (fsLit "xs") arg_ty
553
554       pdata_co <- mkBuiltinCo pdataTyCon
555       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
556           co           = mkAppCoercion pdata_co
557                        $ mkTyConApp repr_co var_tys
558
559           scrut  = mkCoerce co (Var arg)
560
561           mk_result args = wrapFamInstBody pdata_tc var_tys
562                          $ mkConApp pdata_con
563                          $ map Type var_tys ++ args
564
565       (expr, _) <- fixV $ \ ~(_, args) ->
566                      from_sum res_ty (mk_result args) scrut r
567
568       return $ Lam arg expr
569     
570       -- (args, mk) <- from_sum res_ty scrut r
571       
572       -- let result = wrapFamInstBody pdata_tc var_tys
573       --           . mkConApp pdata_dc
574       --           $ map Type var_tys ++ args
575
576       -- return $ Lam arg (mk result)
577   where
578     var_tys = mkTyVarTys $ tyConTyVars vect_tc
579     el_ty   = mkTyConApp vect_tc var_tys
580
581     [pdata_con] = tyConDataCons pdata_tc
582
583     from_sum _ res _ EmptySum = return (res, [])
584     from_sum res_ty res expr (UnarySum r) = from_con res_ty res expr r
585     from_sum res_ty res expr (Sum { repr_psum_tc = psum_tc
586                                   , repr_sel_ty  = sel_ty
587                                   , repr_con_tys = tys
588                                   , repr_cons    = cons })
589       = do
590           sel  <- newLocalVar (fsLit "sel") sel_ty
591           ptys <- mapM mkPDataType tys
592           vars <- newLocalVars (fsLit "xs") ptys
593           (res', args) <- fold from_con res_ty res (map Var vars) cons
594           let scrut = unwrapFamInstScrut psum_tc tys expr
595               body  = mkWildCase scrut (exprType scrut) res_ty
596                       [(DataAlt psum_con, sel : vars, res')]
597           return (body, Var sel : args)
598       where
599         [psum_con] = tyConDataCons psum_tc
600
601
602     from_con res_ty res expr (ConRepr _ r) = from_prod res_ty res expr r
603
604     from_prod _ res _ EmptyProd = return (res, [])
605     from_prod res_ty res expr (UnaryProd r)
606       = from_comp res_ty res expr r
607     from_prod res_ty res expr (Prod { repr_ptup_tc  = ptup_tc
608                                     , repr_comp_tys = tys
609                                     , repr_comps    = comps })
610       = do
611           ptys <- mapM mkPDataType tys
612           vars <- newLocalVars (fsLit "ys") ptys
613           (res', args) <- fold from_comp res_ty res (map Var vars) comps
614           let scrut = unwrapFamInstScrut ptup_tc tys expr
615               body  = mkWildCase scrut (exprType scrut) res_ty
616                       [(DataAlt ptup_con, vars, res')]
617           return (body, args)
618       where
619         [ptup_con] = tyConDataCons ptup_tc
620
621     from_comp _ res expr (Keep _ _) = return (res, [expr])
622     from_comp _ res expr (Wrap ty)
623       = do
624           wrap_tc  <- builtin wrapTyCon
625           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
626           return (res, [unwrapNewTypeBody pwrap_tc [ty]
627                         $ unwrapFamInstScrut pwrap_tc [ty] expr])
628
629     fold f res_ty res exprs rs = foldrM f' (res, []) (zip exprs rs)
630       where
631         f' (expr, r) (res, args) = do
632                                      (res', args') <- f res_ty res expr r
633                                      return (res', args' ++ args)
634
635 buildPRDict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
636 buildPRDict vect_tc prepr_tc _ r
637   = do
638       dict <- sum_dict r
639       pr_co <- mkBuiltinCo prTyCon
640       let co = mkAppCoercion pr_co
641              . mkSymCoercion
642              $ mkTyConApp arg_co ty_args
643       return (mkCoerce co dict)
644   where
645     ty_args = mkTyVarTys (tyConTyVars vect_tc)
646     Just arg_co = tyConFamilyCoercion_maybe prepr_tc
647
648     sum_dict EmptySum = prDFunOfTyCon =<< builtin voidTyCon
649     sum_dict (UnarySum r) = con_dict r
650     sum_dict (Sum { repr_sum_tc  = sum_tc
651                   , repr_con_tys = tys
652                   , repr_cons    = cons
653                   })
654       = do
655           dicts <- mapM con_dict cons
656           dfun  <- prDFunOfTyCon sum_tc
657           return $ dfun `mkTyApps` tys `mkApps` dicts
658
659     con_dict (ConRepr _ r) = prod_dict r
660
661     prod_dict EmptyProd = prDFunOfTyCon =<< builtin voidTyCon
662     prod_dict (UnaryProd r) = comp_dict r
663     prod_dict (Prod { repr_tup_tc   = tup_tc
664                     , repr_comp_tys = tys
665                     , repr_comps    = comps })
666       = do
667           dicts <- mapM comp_dict comps
668           dfun <- prDFunOfTyCon tup_tc
669           return $ dfun `mkTyApps` tys `mkApps` dicts
670
671     comp_dict (Keep _ pr) = return pr
672     comp_dict (Wrap ty)   = wrapPR ty
673
674
675 buildPDataTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
676 buildPDataTyCon orig_tc vect_tc repr = fixV $ \repr_tc ->
677   do
678     name' <- cloneName mkPDataTyConOcc orig_name
679     rhs   <- buildPDataTyConRhs orig_name vect_tc repr_tc repr
680     pdata <- builtin pdataTyCon
681
682     liftDs $ buildAlgTyCon name'
683                            tyvars
684                            []          -- no stupid theta
685                            rhs
686                            rec_flag    -- FIXME: is this ok?
687                            False       -- FIXME: no generics
688                            False       -- not GADT syntax
689                            (Just $ mk_fam_inst pdata vect_tc)
690   where
691     orig_name = tyConName orig_tc
692     tyvars = tyConTyVars vect_tc
693     rec_flag = boolToRecFlag (isRecursiveTyCon vect_tc)
694
695
696 buildPDataTyConRhs :: Name -> TyCon -> TyCon -> SumRepr -> VM AlgTyConRhs
697 buildPDataTyConRhs orig_name vect_tc repr_tc repr
698   = do
699       data_con <- buildPDataDataCon orig_name vect_tc repr_tc repr
700       return $ DataTyCon { data_cons = [data_con], is_enum = False }
701
702 buildPDataDataCon :: Name -> TyCon -> TyCon -> SumRepr -> VM DataCon
703 buildPDataDataCon orig_name vect_tc repr_tc repr
704   = do
705       dc_name  <- cloneName mkPDataDataConOcc orig_name
706       comp_tys <- sum_tys repr
707
708       liftDs $ buildDataCon dc_name
709                             False                  -- not infix
710                             (map (const HsNoBang) comp_tys)
711                             []                     -- no field labels
712                             tvs
713                             []                     -- no existentials
714                             []                     -- no eq spec
715                             []                     -- no context
716                             comp_tys
717                             (mkFamilyTyConApp repr_tc (mkTyVarTys tvs))
718                             repr_tc
719   where
720     tvs   = tyConTyVars vect_tc
721
722     sum_tys EmptySum = return []
723     sum_tys (UnarySum r) = con_tys r
724     sum_tys (Sum { repr_sel_ty = sel_ty
725                  , repr_cons   = cons })
726       = liftM (sel_ty :) (concatMapM con_tys cons)
727
728     con_tys (ConRepr _ r) = prod_tys r
729
730     prod_tys EmptyProd = return []
731     prod_tys (UnaryProd r) = liftM singleton (comp_ty r)
732     prod_tys (Prod { repr_comps = comps }) = mapM comp_ty comps
733
734     comp_ty r = mkPDataType (compOrigType r)
735
736
737 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr 
738                    -> VM Var
739 buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr
740   = do
741       vectDataConWorkers orig_tc vect_tc pdata_tc
742       buildPADict vect_tc prepr_tc pdata_tc repr
743
744 vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
745 vectDataConWorkers orig_tc vect_tc arr_tc
746   = do
747       bs <- sequence
748           . zipWith3 def_worker  (tyConDataCons orig_tc) rep_tys
749           $ zipWith4 mk_data_con (tyConDataCons vect_tc)
750                                  rep_tys
751                                  (inits rep_tys)
752                                  (tail $ tails rep_tys)
753       mapM_ (uncurry hoistBinding) bs
754   where
755     tyvars   = tyConTyVars vect_tc
756     var_tys  = mkTyVarTys tyvars
757     ty_args  = map Type var_tys
758     res_ty   = mkTyConApp vect_tc var_tys
759
760     cons     = tyConDataCons vect_tc
761     arity    = length cons
762     [arr_dc] = tyConDataCons arr_tc
763
764     rep_tys  = map dataConRepArgTys $ tyConDataCons vect_tc
765
766
767     mk_data_con con tys pre post
768       = liftM2 (,) (vect_data_con con)
769                    (lift_data_con tys pre post (mkDataConTag con))
770
771     sel_replicate len tag
772       | arity > 1 = do
773                       rep <- builtin (selReplicate arity)
774                       return [rep `mkApps` [len, tag]]
775
776       | otherwise = return []
777
778     vect_data_con con = return $ mkConApp con ty_args
779     lift_data_con tys pre_tys post_tys tag
780       = do
781           len  <- builtin liftingContext
782           args <- mapM (newLocalVar (fsLit "xs"))
783                   =<< mapM mkPDataType tys
784
785           sel  <- sel_replicate (Var len) tag
786
787           pre   <- mapM emptyPD (concat pre_tys)
788           post  <- mapM emptyPD (concat post_tys)
789
790           return . mkLams (len : args)
791                  . wrapFamInstBody arr_tc var_tys
792                  . mkConApp arr_dc
793                  $ ty_args ++ sel ++ pre ++ map Var args ++ post
794
795     def_worker data_con arg_tys mk_body
796       = do
797           arity <- polyArity tyvars
798           body <- closedV
799                 . inBind orig_worker
800                 . polyAbstract tyvars $ \args ->
801                   liftM (mkLams (tyvars ++ args) . vectorised)
802                 $ buildClosures tyvars [] arg_tys res_ty mk_body
803
804           raw_worker <- cloneId mkVectOcc orig_worker (exprType body)
805           let vect_worker = raw_worker `setIdUnfolding`
806                               mkInlineRule body (Just arity)
807           defGlobalVar orig_worker vect_worker
808           return (vect_worker, body)
809       where
810         orig_worker = dataConWorkId data_con
811
812 buildPADict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
813 buildPADict vect_tc prepr_tc arr_tc repr
814   = polyAbstract tvs $ \args ->
815     do
816       method_ids <- mapM (method args) paMethods
817
818       pa_tc  <- builtin paTyCon
819       pa_dc  <- builtin paDataCon
820       let dict = mkLams (tvs ++ args)
821                $ mkConApp pa_dc
822                $ Type inst_ty : map (method_call args) method_ids
823
824           dfun_ty = mkForAllTys tvs
825                   $ mkFunTys (map varType args) (mkTyConApp pa_tc [inst_ty])
826
827       raw_dfun <- newExportedVar dfun_name dfun_ty
828       let dfun = raw_dfun `setIdUnfolding` mkDFunUnfolding dfun_ty (map Var method_ids)
829                           `setInlinePragma` dfunInlinePragma
830
831       hoistBinding dfun dict
832       return dfun
833   where
834     tvs = tyConTyVars vect_tc
835     arg_tys = mkTyVarTys tvs
836     inst_ty = mkTyConApp vect_tc arg_tys
837
838     dfun_name = mkPADFunOcc (getOccName vect_tc)
839
840     method args (name, build)
841       = localV
842       $ do
843           expr <- build vect_tc prepr_tc arr_tc repr
844           let body = mkLams (tvs ++ args) expr
845           raw_var <- newExportedVar (method_name name) (exprType body)
846           let var = raw_var
847                       `setIdUnfolding` mkInlineRule body (Just (length args))
848                       `setInlinePragma` alwaysInlinePragma
849           hoistBinding var body
850           return var
851
852     method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args)
853
854     method_name name = mkVarOcc $ occNameString dfun_name ++ ('$' : name)
855
856
857 paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
858 paMethods = [("dictPRepr",    buildPRDict),
859              ("toPRepr",      buildToPRepr),
860              ("fromPRepr",    buildFromPRepr),
861              ("toArrPRepr",   buildToArrPRepr),
862              ("fromArrPRepr", buildFromArrPRepr)]
863
864 -- | Split the given tycons into two sets depending on whether they have to be
865 --   converted (first list) or not (second list). The first argument contains
866 --   information about the conversion status of external tycons:
867 --
868 --   * tycons which have converted versions are mapped to True
869 --   * tycons which are not changed by vectorisation are mapped to False
870 --   * tycons which can't be converted are not elements of the map
871 --
872 classifyTyCons :: UniqFM Bool -> [TyConGroup] -> ([TyCon], [TyCon])
873 classifyTyCons = classify [] []
874   where
875     classify conv keep _  [] = (conv, keep)
876     classify conv keep cs ((tcs, ds) : rs)
877       | can_convert && must_convert
878         = classify (tcs ++ conv) keep (cs `addListToUFM` [(tc,True) | tc <- tcs]) rs
879       | can_convert
880         = classify conv (tcs ++ keep) (cs `addListToUFM` [(tc,False) | tc <- tcs]) rs
881       | otherwise
882         = classify conv keep cs rs
883       where
884         refs = ds `delListFromUniqSet` tcs
885
886         can_convert  = isNullUFM (refs `minusUFM` cs) && all convertable tcs
887         must_convert = foldUFM (||) False (intersectUFM_C const cs refs)
888
889         convertable tc = isDataTyCon tc && all isVanillaDataCon (tyConDataCons tc)
890
891 -- | Compute mutually recursive groups of tycons in topological order
892 --
893 tyConGroups :: [TyCon] -> [TyConGroup]
894 tyConGroups tcs = map mk_grp (stronglyConnCompFromEdgedVertices edges)
895   where
896     edges = [((tc, ds), tc, uniqSetToList ds) | tc <- tcs
897                                 , let ds = tyConsOfTyCon tc]
898
899     mk_grp (AcyclicSCC (tc, ds)) = ([tc], ds)
900     mk_grp (CyclicSCC els)       = (tcs, unionManyUniqSets dss)
901       where
902         (tcs, dss) = unzip els
903
904 tyConsOfTyCon :: TyCon -> UniqSet TyCon
905 tyConsOfTyCon
906   = tyConsOfTypes . concatMap dataConRepArgTys . tyConDataCons
907
908 tyConsOfType :: Type -> UniqSet TyCon
909 tyConsOfType ty
910   | Just ty' <- coreView ty    = tyConsOfType ty'
911 tyConsOfType (TyVarTy _)       = emptyUniqSet
912 tyConsOfType (TyConApp tc tys) = extend (tyConsOfTypes tys)
913   where
914     extend | isUnLiftedTyCon tc
915            || isTupleTyCon   tc = id
916
917            | otherwise          = (`addOneToUniqSet` tc)
918
919 tyConsOfType (AppTy a b)       = tyConsOfType a `unionUniqSets` tyConsOfType b
920 tyConsOfType (FunTy a b)       = (tyConsOfType a `unionUniqSets` tyConsOfType b)
921                                  `addOneToUniqSet` funTyCon
922 tyConsOfType (ForAllTy _ ty)   = tyConsOfType ty
923 tyConsOfType other             = pprPanic "ClosureConv.tyConsOfType" $ ppr other
924
925 tyConsOfTypes :: [Type] -> UniqSet TyCon
926 tyConsOfTypes = unionManyUniqSets . map tyConsOfType
927
928
929 -- ----------------------------------------------------------------------------
930 -- Conversions
931
932 fromVect :: Type -> CoreExpr -> VM CoreExpr
933 fromVect ty expr | Just ty' <- coreView ty = fromVect ty' expr
934 fromVect (FunTy arg_ty res_ty) expr
935   = do
936       arg     <- newLocalVar (fsLit "x") arg_ty
937       varg    <- toVect arg_ty (Var arg)
938       varg_ty <- vectType arg_ty
939       vres_ty <- vectType res_ty
940       apply   <- builtin applyVar
941       body    <- fromVect res_ty
942                $ Var apply `mkTyApps` [varg_ty, vres_ty] `mkApps` [expr, varg]
943       return $ Lam arg body
944 fromVect ty expr
945   = identityConv ty >> return expr
946
947 toVect :: Type -> CoreExpr -> VM CoreExpr
948 toVect ty expr = identityConv ty >> return expr
949
950 identityConv :: Type -> VM ()
951 identityConv ty | Just ty' <- coreView ty = identityConv ty'
952 identityConv (TyConApp tycon tys)
953   = do
954       mapM_ identityConv tys
955       identityConvTyCon tycon
956 identityConv _ = noV
957
958 identityConvTyCon :: TyCon -> VM ()
959 identityConvTyCon tc
960   | isBoxedTupleTyCon tc = return ()
961   | isUnLiftedTyCon tc   = return ()
962   | otherwise            = do
963                              tc' <- maybeV (lookupTyCon tc)
964                              if tc == tc' then return () else noV
965