Fix warnings
[ghc-hetmet.git] / compiler / vectorise / VectType.hs
1 module VectType ( vectTyCon, vectAndLiftType, vectType, vectTypeEnv,
2                   -- arrSumArity, pdataCompTys, pdataCompVars,
3                   buildPADict,
4                   fromVect )
5 where
6
7 import VectMonad
8 import VectUtils
9 import VectCore
10
11 import HscTypes          ( TypeEnv, extendTypeEnvList, typeEnvTyCons )
12 import CoreSyn
13 import CoreUtils
14 import MkCore            ( mkWildCase )
15 import BuildTyCl
16 import DataCon
17 import TyCon
18 import Type
19 import TypeRep
20 import Coercion
21 import FamInstEnv        ( FamInst, mkLocalFamInst )
22 import OccName
23 import MkId
24 import BasicTypes        ( StrictnessMark(..), boolToRecFlag )
25 import Var               ( Var, TyVar )
26 import Name              ( Name, getOccName )
27 import NameEnv
28
29 import Unique
30 import UniqFM
31 import UniqSet
32 import Util
33 import Digraph           ( SCC(..), stronglyConnCompFromEdgedVertices )
34
35 import Outputable
36 import FastString
37
38 import MonadUtils     ( zipWith3M, foldrM, concatMapM )
39 import Control.Monad  ( liftM, liftM2, zipWithM, zipWithM_, mapAndUnzipM )
40 import Data.List      ( inits, tails, zipWith4, zipWith6 )
41
42 -- ----------------------------------------------------------------------------
43 -- Types
44
45 vectTyCon :: TyCon -> VM TyCon
46 vectTyCon tc
47   | isFunTyCon tc        = builtin closureTyCon
48   | isBoxedTupleTyCon tc = return tc
49   | isUnLiftedTyCon tc   = return tc
50   | otherwise            = maybeCantVectoriseM "Tycon not vectorised:" (ppr tc)
51                          $ lookupTyCon tc
52
53 vectAndLiftType :: Type -> VM (Type, Type)
54 vectAndLiftType ty | Just ty' <- coreView ty = vectAndLiftType ty'
55 vectAndLiftType ty
56   = do
57       mdicts   <- mapM paDictArgType tyvars
58       let dicts = [dict | Just dict <- mdicts]
59       vmono_ty <- vectType mono_ty
60       lmono_ty <- mkPDataType vmono_ty
61       return (abstractType tyvars dicts vmono_ty,
62               abstractType tyvars dicts lmono_ty)
63   where
64     (tyvars, mono_ty) = splitForAllTys ty
65
66
67 vectType :: Type -> VM Type
68 vectType ty | Just ty' <- coreView ty = vectType ty'
69 vectType (TyVarTy tv) = return $ TyVarTy tv
70 vectType (AppTy ty1 ty2) = liftM2 AppTy (vectType ty1) (vectType ty2)
71 vectType (TyConApp tc tys) = liftM2 TyConApp (vectTyCon tc) (mapM vectType tys)
72 vectType (FunTy ty1 ty2)   = liftM2 TyConApp (builtin closureTyCon)
73                                              (mapM vectAndBoxType [ty1,ty2])
74 vectType ty@(ForAllTy _ _)
75   = do
76       mdicts   <- mapM paDictArgType tyvars
77       mono_ty' <- vectType mono_ty
78       return $ abstractType tyvars [dict | Just dict <- mdicts] mono_ty'
79   where
80     (tyvars, mono_ty) = splitForAllTys ty
81
82 vectType ty = cantVectorise "Can't vectorise type" (ppr ty)
83
84 vectAndBoxType :: Type -> VM Type
85 vectAndBoxType ty = vectType ty >>= boxType
86
87 abstractType :: [TyVar] -> [Type] -> Type -> Type
88 abstractType tyvars dicts = mkForAllTys tyvars . mkFunTys dicts
89
90 -- ----------------------------------------------------------------------------
91 -- Boxing
92
93 boxType :: Type -> VM Type
94 boxType ty
95   | Just (tycon, []) <- splitTyConApp_maybe ty
96   , isUnLiftedTyCon tycon
97   = do
98       r <- lookupBoxedTyCon tycon
99       case r of
100         Just tycon' -> return $ mkTyConApp tycon' []
101         Nothing     -> return ty
102 boxType ty = return ty
103
104 -- ----------------------------------------------------------------------------
105 -- Type definitions
106
107 type TyConGroup = ([TyCon], UniqSet TyCon)
108
109 vectTypeEnv :: TypeEnv -> VM (TypeEnv, [FamInst], [(Var, CoreExpr)])
110 vectTypeEnv env
111   = do
112       cs <- readGEnv $ mk_map . global_tycons
113       let (conv_tcs, keep_tcs) = classifyTyCons cs groups
114           keep_dcs             = concatMap tyConDataCons keep_tcs
115       zipWithM_ defTyCon   keep_tcs keep_tcs
116       zipWithM_ defDataCon keep_dcs keep_dcs
117       new_tcs <- vectTyConDecls conv_tcs
118
119       let orig_tcs = keep_tcs ++ conv_tcs
120           vect_tcs = keep_tcs ++ new_tcs
121
122       dfuns <- mapM mkPADFun vect_tcs
123       defTyConPAs (zip vect_tcs dfuns)
124       reprs <- mapM tyConRepr vect_tcs
125       repr_tcs  <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
126       pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
127       binds    <- sequence (zipWith6 buildTyConBindings orig_tcs
128                                                         vect_tcs
129                                                         repr_tcs
130                                                         pdata_tcs
131                                                         dfuns
132                                                         reprs)
133
134       let all_new_tcs = new_tcs ++ repr_tcs ++ pdata_tcs
135
136       let new_env = extendTypeEnvList env
137                        (map ATyCon all_new_tcs
138                         ++ [ADataCon dc | tc <- all_new_tcs
139                                         , dc <- tyConDataCons tc])
140
141       return (new_env, map mkLocalFamInst (repr_tcs ++ pdata_tcs), concat binds)
142   where
143     tycons = typeEnvTyCons env
144     groups = tyConGroups tycons
145
146     mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
147
148
149 vectTyConDecls :: [TyCon] -> VM [TyCon]
150 vectTyConDecls tcs = fixV $ \tcs' ->
151   do
152     mapM_ (uncurry defTyCon) (zipLazy tcs tcs')
153     mapM vectTyConDecl tcs
154
155 vectTyConDecl :: TyCon -> VM TyCon
156 vectTyConDecl tc
157   = do
158       name' <- cloneName mkVectTyConOcc name
159       rhs'  <- vectAlgTyConRhs tc (algTyConRhs tc)
160
161       liftDs $ buildAlgTyCon name'
162                              tyvars
163                              []           -- no stupid theta
164                              rhs'
165                              rec_flag     -- FIXME: is this ok?
166                              False        -- FIXME: no generics
167                              False        -- not GADT syntax
168                              Nothing      -- not a family instance
169   where
170     name   = tyConName tc
171     tyvars = tyConTyVars tc
172     rec_flag = boolToRecFlag (isRecursiveTyCon tc)
173
174 vectAlgTyConRhs :: TyCon -> AlgTyConRhs -> VM AlgTyConRhs
175 vectAlgTyConRhs _ (DataTyCon { data_cons = data_cons
176                              , is_enum   = is_enum
177                              })
178   = do
179       data_cons' <- mapM vectDataCon data_cons
180       zipWithM_ defDataCon data_cons data_cons'
181       return $ DataTyCon { data_cons = data_cons'
182                          , is_enum   = is_enum
183                          }
184 vectAlgTyConRhs tc _ = cantVectorise "Can't vectorise type definition:" (ppr tc)
185
186 vectDataCon :: DataCon -> VM DataCon
187 vectDataCon dc
188   | not . null $ dataConExTyVars dc
189         = cantVectorise "Can't vectorise constructor (existentials):" (ppr dc)
190   | not . null $ dataConEqSpec   dc
191         = cantVectorise "Can't vectorise constructor (eq spec):" (ppr dc)
192   | otherwise
193   = do
194       name'    <- cloneName mkVectDataConOcc name
195       tycon'   <- vectTyCon tycon
196       arg_tys  <- mapM vectType rep_arg_tys
197
198       liftDs $ buildDataCon name'
199                             False           -- not infix
200                             (map (const NotMarkedStrict) arg_tys)
201                             []              -- no labelled fields
202                             univ_tvs
203                             []              -- no existential tvs for now
204                             []              -- no eq spec for now
205                             []              -- no context
206                             arg_tys 
207                             (mkFamilyTyConApp tycon' (mkTyVarTys univ_tvs))
208                             tycon'
209   where
210     name        = dataConName dc
211     univ_tvs    = dataConUnivTyVars dc
212     rep_arg_tys = dataConRepArgTys dc
213     tycon       = dataConTyCon dc
214
215 mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
216 mk_fam_inst fam_tc arg_tc
217   = (fam_tc, [mkTyConApp arg_tc . mkTyVarTys $ tyConTyVars arg_tc])
218
219
220 buildPReprTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
221 buildPReprTyCon orig_tc vect_tc repr
222   = do
223       name     <- cloneName mkPReprTyConOcc (tyConName orig_tc)
224       -- rhs_ty   <- buildPReprType vect_tc
225       rhs_ty   <- sumReprType repr
226       prepr_tc <- builtin preprTyCon
227       liftDs $ buildSynTyCon name
228                              tyvars
229                              (SynonymTyCon rhs_ty)
230                              (typeKind rhs_ty)
231                              (Just $ mk_fam_inst prepr_tc vect_tc)
232   where
233     tyvars = tyConTyVars vect_tc
234
235 data CompRepr = Keep Type
236                      CoreExpr     -- PR dictionary for the type
237               | Wrap Type
238
239 data ProdRepr = EmptyProd
240               | UnaryProd CompRepr
241               | Prod { repr_tup_tc   :: TyCon  -- representation tuple tycon
242                      , repr_ptup_tc  :: TyCon  -- PData representation tycon
243                      , repr_comp_tys :: [Type] -- representation types of
244                      , repr_comps    :: [CompRepr]          -- components
245                      }
246 data ConRepr  = ConRepr DataCon ProdRepr
247
248 data SumRepr  = EmptySum
249               | UnarySum ConRepr
250               | Sum  { repr_sum_tc   :: TyCon  -- representation sum tycon
251                      , repr_psum_tc  :: TyCon  -- PData representation tycon
252                      , repr_sel_ty   :: Type   -- type of selector
253                      , repr_con_tys :: [Type]  -- representation types of
254                      , repr_cons     :: [ConRepr]           -- components
255                      }
256
257 tyConRepr :: TyCon -> VM SumRepr
258 tyConRepr tc = sum_repr (tyConDataCons tc)
259   where
260     sum_repr []    = return EmptySum
261     sum_repr [con] = liftM UnarySum (con_repr con)
262     sum_repr cons  = do
263                        rs     <- mapM con_repr cons
264                        sum_tc <- builtin (sumTyCon arity)
265                        tys    <- mapM conReprType rs
266                        (psum_tc, _) <- pdataReprTyCon (mkTyConApp sum_tc tys)
267                        sel_ty <- builtin (selTy arity)
268                        return $ Sum { repr_sum_tc  = sum_tc
269                                     , repr_psum_tc = psum_tc
270                                     , repr_sel_ty  = sel_ty
271                                     , repr_con_tys = tys
272                                     , repr_cons    = rs
273                                     }
274       where
275         arity = length cons
276
277     con_repr con = liftM (ConRepr con) (prod_repr (dataConRepArgTys con))
278
279     prod_repr []   = return EmptyProd
280     prod_repr [ty] = liftM UnaryProd (comp_repr ty)
281     prod_repr tys  = do
282                        rs <- mapM comp_repr tys
283                        tup_tc <- builtin (prodTyCon arity)
284                        tys'    <- mapM compReprType rs
285                        (ptup_tc, _) <- pdataReprTyCon (mkTyConApp tup_tc tys')
286                        return $ Prod { repr_tup_tc   = tup_tc
287                                      , repr_ptup_tc  = ptup_tc
288                                      , repr_comp_tys = tys'
289                                      , repr_comps    = rs
290                                      }
291       where
292         arity = length tys
293     
294     comp_repr ty = liftM (Keep ty) (prDictOfType ty)
295                    `orElseV` return (Wrap ty)
296
297 sumReprType :: SumRepr -> VM Type
298 sumReprType EmptySum = voidType
299 sumReprType (UnarySum r) = conReprType r
300 sumReprType (Sum { repr_sum_tc  = sum_tc, repr_con_tys = tys })
301   = return $ mkTyConApp sum_tc tys
302
303 conReprType :: ConRepr -> VM Type
304 conReprType (ConRepr _ r) = prodReprType r
305
306 prodReprType :: ProdRepr -> VM Type
307 prodReprType EmptyProd = voidType
308 prodReprType (UnaryProd r) = compReprType r
309 prodReprType (Prod { repr_tup_tc = tup_tc, repr_comp_tys = tys })
310   = return $ mkTyConApp tup_tc tys
311
312 compReprType :: CompRepr -> VM Type
313 compReprType (Keep ty _) = return ty
314 compReprType (Wrap ty) = do
315                              wrap_tc <- builtin wrapTyCon
316                              return $ mkTyConApp wrap_tc [ty]
317
318 compOrigType :: CompRepr -> Type
319 compOrigType (Keep ty _) = ty
320 compOrigType (Wrap ty) = ty
321
322 buildToPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
323 buildToPRepr vect_tc repr_tc _ repr
324   = do
325       let arg_ty = mkTyConApp vect_tc ty_args
326       res_ty <- mkPReprType arg_ty
327       arg    <- newLocalVar (fsLit "x") arg_ty
328       result <- to_sum (Var arg) arg_ty res_ty repr
329       return $ Lam arg result
330   where
331     ty_args = mkTyVarTys (tyConTyVars vect_tc)
332
333     wrap_repr_inst = wrapFamInstBody repr_tc ty_args
334
335     to_sum _ _ _ EmptySum
336       = do
337           void <- builtin voidVar
338           return $ wrap_repr_inst $ Var void
339
340     to_sum arg arg_ty res_ty (UnarySum r)
341       = do
342           (pat, vars, body) <- con_alt r
343           return $ mkWildCase arg arg_ty res_ty
344                    [(pat, vars, wrap_repr_inst body)]
345
346     to_sum arg arg_ty res_ty (Sum { repr_sum_tc  = sum_tc
347                                   , repr_con_tys = tys
348                                   , repr_cons    =  cons })
349       = do
350           alts <- mapM con_alt cons
351           let alts' = [(pat, vars, wrap_repr_inst
352                                    $ mkConApp sum_con (map Type tys ++ [body]))
353                         | ((pat, vars, body), sum_con)
354                             <- zip alts (tyConDataCons sum_tc)]
355           return $ mkWildCase arg arg_ty res_ty alts'
356
357     con_alt (ConRepr con r)
358       = do
359           (vars, body) <- to_prod r
360           return (DataAlt con, vars, body)
361
362     to_prod EmptyProd
363       = do
364           void <- builtin voidVar
365           return ([], Var void)
366
367     to_prod (UnaryProd comp)
368       = do
369           var  <- newLocalVar (fsLit "x") (compOrigType comp)
370           body <- to_comp (Var var) comp
371           return ([var], body)
372
373     to_prod(Prod { repr_tup_tc   = tup_tc
374                  , repr_comp_tys = tys
375                  , repr_comps    = comps })
376       = do
377           vars  <- newLocalVars (fsLit "x") (map compOrigType comps)
378           exprs <- zipWithM to_comp (map Var vars) comps
379           return (vars, mkConApp tup_con (map Type tys ++ exprs))
380       where
381         [tup_con] = tyConDataCons tup_tc
382
383     to_comp expr (Keep _ _) = return expr
384     to_comp expr (Wrap ty)  = do
385                                 wrap_tc <- builtin wrapTyCon
386                                 return $ wrapNewTypeBody wrap_tc [ty] expr
387
388
389 buildFromPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
390 buildFromPRepr vect_tc repr_tc _ repr
391   = do
392       arg_ty <- mkPReprType res_ty
393       arg <- newLocalVar (fsLit "x") arg_ty
394
395       result <- from_sum (unwrapFamInstScrut repr_tc ty_args (Var arg))
396                          repr
397       return $ Lam arg result
398   where
399     ty_args = mkTyVarTys (tyConTyVars vect_tc)
400     res_ty  = mkTyConApp vect_tc ty_args
401
402     from_sum _ EmptySum
403       = do
404           dummy <- builtin fromVoidVar
405           return $ Var dummy `App` Type res_ty
406
407     from_sum expr (UnarySum r) = from_con expr r
408     from_sum expr (Sum { repr_sum_tc  = sum_tc
409                        , repr_con_tys = tys
410                        , repr_cons    = cons })
411       = do
412           vars  <- newLocalVars (fsLit "x") tys
413           es    <- zipWithM from_con (map Var vars) cons
414           return $ mkWildCase expr (exprType expr) res_ty
415                    [(DataAlt con, [var], e)
416                       | (con, var, e) <- zip3 (tyConDataCons sum_tc) vars es]
417
418     from_con expr (ConRepr con r)
419       = from_prod expr (mkConApp con $ map Type ty_args) r
420
421     from_prod _ con EmptyProd = return con
422     from_prod expr con (UnaryProd r)
423       = do
424           e <- from_comp expr r
425           return $ con `App` e
426      
427     from_prod expr con (Prod { repr_tup_tc   = tup_tc
428                              , repr_comp_tys = tys
429                              , repr_comps    = comps
430                              })
431       = do
432           vars <- newLocalVars (fsLit "y") tys
433           es   <- zipWithM from_comp (map Var vars) comps
434           return $ mkWildCase expr (exprType expr) res_ty
435                    [(DataAlt tup_con, vars, con `mkApps` es)]
436       where
437         [tup_con] = tyConDataCons tup_tc  
438
439     from_comp expr (Keep _ _) = return expr
440     from_comp expr (Wrap ty)
441       = do
442           wrap <- builtin wrapTyCon
443           return $ unwrapNewTypeBody wrap [ty] expr
444
445
446 buildToArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
447 buildToArrPRepr vect_tc prepr_tc pdata_tc r
448   = do
449       arg_ty <- mkPDataType el_ty
450       res_ty <- mkPDataType =<< mkPReprType el_ty
451       arg    <- newLocalVar (fsLit "xs") arg_ty
452
453       pdata_co <- mkBuiltinCo pdataTyCon
454       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
455           co           = mkAppCoercion pdata_co
456                        . mkSymCoercion
457                        $ mkTyConApp repr_co ty_args
458
459           scrut   = unwrapFamInstScrut pdata_tc ty_args (Var arg)
460
461       (vars, result) <- to_sum r
462
463       return . Lam arg
464              $ mkWildCase scrut (mkTyConApp pdata_tc ty_args) res_ty
465                [(DataAlt pdata_dc, vars, mkCoerce co result)]
466   where
467     ty_args = mkTyVarTys $ tyConTyVars vect_tc
468     el_ty   = mkTyConApp vect_tc ty_args
469
470     [pdata_dc] = tyConDataCons pdata_tc
471
472
473     to_sum EmptySum = do
474                         pvoid <- builtin pvoidVar
475                         return ([], Var pvoid)
476     to_sum (UnarySum r) = to_con r
477     to_sum (Sum { repr_psum_tc = psum_tc
478                 , repr_sel_ty  = sel_ty
479                 , repr_con_tys = tys
480                 , repr_cons    = cons
481                 })
482       = do
483           (vars, exprs) <- mapAndUnzipM to_con cons
484           sel <- newLocalVar (fsLit "sel") sel_ty
485           return (sel : concat vars, mk_result (Var sel) exprs)
486       where
487         [psum_con] = tyConDataCons psum_tc
488         mk_result sel exprs = wrapFamInstBody psum_tc tys
489                             $ mkConApp psum_con
490                             $ map Type tys ++ (sel : exprs)
491
492     to_con (ConRepr _ r) = to_prod r
493
494     to_prod EmptyProd = do
495                           pvoid <- builtin pvoidVar
496                           return ([], Var pvoid)
497     to_prod (UnaryProd r)
498       = do
499           pty  <- mkPDataType (compOrigType r)
500           var  <- newLocalVar (fsLit "x") pty
501           expr <- to_comp (Var var) r
502           return ([var], expr)
503
504     to_prod (Prod { repr_ptup_tc  = ptup_tc
505                   , repr_comp_tys = tys
506                   , repr_comps    = comps })
507       = do
508           ptys <- mapM (mkPDataType . compOrigType) comps
509           vars <- newLocalVars (fsLit "x") ptys
510           es   <- zipWithM to_comp (map Var vars) comps
511           return (vars, mk_result es)
512       where
513         [ptup_con] = tyConDataCons ptup_tc
514         mk_result exprs = wrapFamInstBody ptup_tc tys
515                         $ mkConApp ptup_con
516                         $ map Type tys ++ exprs
517
518     to_comp expr (Keep _ _) = return expr
519
520     -- FIXME: this is bound to be wrong!
521     to_comp expr (Wrap ty)
522       = do
523           wrap_tc  <- builtin wrapTyCon
524           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
525           return $ wrapNewTypeBody pwrap_tc [ty] expr
526
527
528 buildFromArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
529 buildFromArrPRepr vect_tc prepr_tc pdata_tc r
530   = do
531       arg_ty <- mkPDataType =<< mkPReprType el_ty
532       res_ty <- mkPDataType el_ty
533       arg    <- newLocalVar (fsLit "xs") arg_ty
534
535       pdata_co <- mkBuiltinCo pdataTyCon
536       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
537           co           = mkAppCoercion pdata_co
538                        $ mkTyConApp repr_co var_tys
539
540           scrut  = mkCoerce co (Var arg)
541
542           mk_result args = wrapFamInstBody pdata_tc var_tys
543                          $ mkConApp pdata_con
544                          $ map Type var_tys ++ args
545
546       (expr, _) <- fixV $ \ ~(_, args) ->
547                      from_sum res_ty (mk_result args) scrut r
548
549       return $ Lam arg expr
550     
551       -- (args, mk) <- from_sum res_ty scrut r
552       
553       -- let result = wrapFamInstBody pdata_tc var_tys
554       --           . mkConApp pdata_dc
555       --           $ map Type var_tys ++ args
556
557       -- return $ Lam arg (mk result)
558   where
559     var_tys = mkTyVarTys $ tyConTyVars vect_tc
560     el_ty   = mkTyConApp vect_tc var_tys
561
562     [pdata_con] = tyConDataCons pdata_tc
563
564     from_sum _ res _ EmptySum = return (res, [])
565     from_sum res_ty res expr (UnarySum r) = from_con res_ty res expr r
566     from_sum res_ty res expr (Sum { repr_psum_tc = psum_tc
567                                   , repr_sel_ty  = sel_ty
568                                   , repr_con_tys = tys
569                                   , repr_cons    = cons })
570       = do
571           sel  <- newLocalVar (fsLit "sel") sel_ty
572           ptys <- mapM mkPDataType tys
573           vars <- newLocalVars (fsLit "xs") ptys
574           (res', args) <- fold from_con res_ty res (map Var vars) cons
575           let scrut = unwrapFamInstScrut psum_tc tys expr
576               body  = mkWildCase scrut (exprType scrut) res_ty
577                       [(DataAlt psum_con, sel : vars, res')]
578           return (body, Var sel : args)
579       where
580         [psum_con] = tyConDataCons psum_tc
581
582
583     from_con res_ty res expr (ConRepr _ r) = from_prod res_ty res expr r
584
585     from_prod _ res _ EmptyProd = return (res, [])
586     from_prod res_ty res expr (UnaryProd r)
587       = from_comp res_ty res expr r
588     from_prod res_ty res expr (Prod { repr_ptup_tc  = ptup_tc
589                                     , repr_comp_tys = tys
590                                     , repr_comps    = comps })
591       = do
592           ptys <- mapM mkPDataType tys
593           vars <- newLocalVars (fsLit "ys") ptys
594           (res', args) <- fold from_comp res_ty res (map Var vars) comps
595           let scrut = unwrapFamInstScrut ptup_tc tys expr
596               body  = mkWildCase scrut (exprType scrut) res_ty
597                       [(DataAlt ptup_con, vars, res')]
598           return (body, args)
599       where
600         [ptup_con] = tyConDataCons ptup_tc
601
602     from_comp _ res expr (Keep _ _) = return (res, [expr])
603     from_comp _ res expr (Wrap ty)
604       = do
605           wrap_tc  <- builtin wrapTyCon
606           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
607           return (res, [unwrapNewTypeBody pwrap_tc [ty]
608                         $ unwrapFamInstScrut pwrap_tc [ty] expr])
609
610     fold f res_ty res exprs rs = foldrM f' (res, []) (zip exprs rs)
611       where
612         f' (expr, r) (res, args) = do
613                                      (res', args') <- f res_ty res expr r
614                                      return (res', args' ++ args)
615
616 buildPRDict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
617 buildPRDict vect_tc prepr_tc _ r
618   = do
619       dict <- sum_dict r
620       pr_co <- mkBuiltinCo prTyCon
621       let co = mkAppCoercion pr_co
622              . mkSymCoercion
623              $ mkTyConApp arg_co ty_args
624       return (mkCoerce co dict)
625   where
626     ty_args = mkTyVarTys (tyConTyVars vect_tc)
627     Just arg_co = tyConFamilyCoercion_maybe prepr_tc
628
629     sum_dict EmptySum = prDFunOfTyCon =<< builtin voidTyCon
630     sum_dict (UnarySum r) = con_dict r
631     sum_dict (Sum { repr_sum_tc  = sum_tc
632                   , repr_con_tys = tys
633                   , repr_cons    = cons
634                   })
635       = do
636           dicts <- mapM con_dict cons
637           dfun  <- prDFunOfTyCon sum_tc
638           return $ dfun `mkTyApps` tys `mkApps` dicts
639
640     con_dict (ConRepr _ r) = prod_dict r
641
642     prod_dict EmptyProd = prDFunOfTyCon =<< builtin voidTyCon
643     prod_dict (UnaryProd r) = comp_dict r
644     prod_dict (Prod { repr_tup_tc   = tup_tc
645                     , repr_comp_tys = tys
646                     , repr_comps    = comps })
647       = do
648           dicts <- mapM comp_dict comps
649           dfun <- prDFunOfTyCon tup_tc
650           return $ dfun `mkTyApps` tys `mkApps` dicts
651
652     comp_dict (Keep _ pr) = return pr
653     comp_dict (Wrap ty)   = wrapPR ty
654
655
656 buildPDataTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
657 buildPDataTyCon orig_tc vect_tc repr = fixV $ \repr_tc ->
658   do
659     name' <- cloneName mkPDataTyConOcc orig_name
660     rhs   <- buildPDataTyConRhs orig_name vect_tc repr_tc repr
661     pdata <- builtin pdataTyCon
662
663     liftDs $ buildAlgTyCon name'
664                            tyvars
665                            []          -- no stupid theta
666                            rhs
667                            rec_flag    -- FIXME: is this ok?
668                            False       -- FIXME: no generics
669                            False       -- not GADT syntax
670                            (Just $ mk_fam_inst pdata vect_tc)
671   where
672     orig_name = tyConName orig_tc
673     tyvars = tyConTyVars vect_tc
674     rec_flag = boolToRecFlag (isRecursiveTyCon vect_tc)
675
676
677 buildPDataTyConRhs :: Name -> TyCon -> TyCon -> SumRepr -> VM AlgTyConRhs
678 buildPDataTyConRhs orig_name vect_tc repr_tc repr
679   = do
680       data_con <- buildPDataDataCon orig_name vect_tc repr_tc repr
681       return $ DataTyCon { data_cons = [data_con], is_enum = False }
682
683 buildPDataDataCon :: Name -> TyCon -> TyCon -> SumRepr -> VM DataCon
684 buildPDataDataCon orig_name vect_tc repr_tc repr
685   = do
686       dc_name  <- cloneName mkPDataDataConOcc orig_name
687       comp_tys <- sum_tys repr
688
689       liftDs $ buildDataCon dc_name
690                             False                  -- not infix
691                             (map (const NotMarkedStrict) comp_tys)
692                             []                     -- no field labels
693                             tvs
694                             []                     -- no existentials
695                             []                     -- no eq spec
696                             []                     -- no context
697                             comp_tys
698                             (mkFamilyTyConApp repr_tc (mkTyVarTys tvs))
699                             repr_tc
700   where
701     tvs   = tyConTyVars vect_tc
702
703     sum_tys EmptySum = return []
704     sum_tys (UnarySum r) = con_tys r
705     sum_tys (Sum { repr_sel_ty = sel_ty
706                  , repr_cons   = cons })
707       = liftM (sel_ty :) (concatMapM con_tys cons)
708
709     con_tys (ConRepr _ r) = prod_tys r
710
711     prod_tys EmptyProd = return []
712     prod_tys (UnaryProd r) = liftM singleton (comp_ty r)
713     prod_tys (Prod { repr_comps = comps }) = mapM comp_ty comps
714
715     comp_ty r = mkPDataType (compOrigType r)
716
717
718 mkPADFun :: TyCon -> VM Var
719 mkPADFun vect_tc
720   = newExportedVar (mkPADFunOcc $ getOccName vect_tc) =<< paDFunType vect_tc
721
722 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> Var -> SumRepr 
723                    -> VM [(Var, CoreExpr)]
724 buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc dfun repr
725   = do
726       vectDataConWorkers orig_tc vect_tc pdata_tc
727       dict <- buildPADict vect_tc prepr_tc pdata_tc repr
728       binds <- takeHoisted
729       return $ (dfun, dict) : binds
730
731 vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
732 vectDataConWorkers orig_tc vect_tc arr_tc
733   = do
734       bs <- sequence
735           . zipWith3 def_worker  (tyConDataCons orig_tc) rep_tys
736           $ zipWith4 mk_data_con (tyConDataCons vect_tc)
737                                  rep_tys
738                                  (inits rep_tys)
739                                  (tail $ tails rep_tys)
740       mapM_ (uncurry hoistBinding) bs
741   where
742     tyvars   = tyConTyVars vect_tc
743     var_tys  = mkTyVarTys tyvars
744     ty_args  = map Type var_tys
745     res_ty   = mkTyConApp vect_tc var_tys
746
747     cons     = tyConDataCons vect_tc
748     arity    = length cons
749     [arr_dc] = tyConDataCons arr_tc
750
751     rep_tys  = map dataConRepArgTys $ tyConDataCons vect_tc
752
753
754     mk_data_con con tys pre post
755       = liftM2 (,) (vect_data_con con)
756                    (lift_data_con tys pre post (mkDataConTag con))
757
758     sel_replicate len tag
759       | arity > 1 = do
760                       rep <- builtin (selReplicate arity)
761                       return [rep `mkApps` [len, tag]]
762
763       | otherwise = return []
764
765     vect_data_con con = return $ mkConApp con ty_args
766     lift_data_con tys pre_tys post_tys tag
767       = do
768           len  <- builtin liftingContext
769           args <- mapM (newLocalVar (fsLit "xs"))
770                   =<< mapM mkPDataType tys
771
772           sel  <- sel_replicate (Var len) tag
773
774           pre   <- mapM emptyPD (concat pre_tys)
775           post  <- mapM emptyPD (concat post_tys)
776
777           return . mkLams (len : args)
778                  . wrapFamInstBody arr_tc var_tys
779                  . mkConApp arr_dc
780                  $ ty_args ++ sel ++ pre ++ map Var args ++ post
781
782     def_worker data_con arg_tys mk_body
783       = do
784           body <- closedV
785                 . inBind orig_worker
786                 . polyAbstract tyvars $ \abstract ->
787                   liftM (abstract . vectorised)
788                 $ buildClosures tyvars [] arg_tys res_ty mk_body
789
790           vect_worker <- cloneId mkVectOcc orig_worker (exprType body)
791           defGlobalVar orig_worker vect_worker
792           return (vect_worker, body)
793       where
794         orig_worker = dataConWorkId data_con
795
796 buildPADict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
797 buildPADict vect_tc prepr_tc arr_tc repr
798   = polyAbstract tvs $ \abstract ->
799     do
800       meth_binds <- mapM mk_method paMethods
801       let meth_exprs = map (Var . fst) meth_binds
802
803       pa_dc <- builtin paDataCon
804       let dict = mkConApp pa_dc (Type (mkTyConApp vect_tc arg_tys) : meth_exprs)
805           body = Let (Rec meth_binds) dict
806       return . mkInlineMe $ abstract body
807   where
808     tvs = tyConTyVars arr_tc
809     arg_tys = mkTyVarTys tvs
810
811     mk_method (name, build)
812       = localV
813       $ do
814           body <- build vect_tc prepr_tc arr_tc repr
815           var  <- newLocalVar name (exprType body)
816           return (var, mkInlineMe body)
817
818 paMethods :: [(FastString, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
819 paMethods = [(fsLit "dictPRepr",    buildPRDict),
820              (fsLit "toPRepr",      buildToPRepr),
821              (fsLit "fromPRepr",    buildFromPRepr),
822              (fsLit "toArrPRepr",   buildToArrPRepr),
823              (fsLit "fromArrPRepr", buildFromArrPRepr)]
824
825 -- | Split the given tycons into two sets depending on whether they have to be
826 -- converted (first list) or not (second list). The first argument contains
827 -- information about the conversion status of external tycons:
828 --
829 --   * tycons which have converted versions are mapped to True
830 --   * tycons which are not changed by vectorisation are mapped to False
831 --   * tycons which can't be converted are not elements of the map
832 --
833 classifyTyCons :: UniqFM Bool -> [TyConGroup] -> ([TyCon], [TyCon])
834 classifyTyCons = classify [] []
835   where
836     classify conv keep _  [] = (conv, keep)
837     classify conv keep cs ((tcs, ds) : rs)
838       | can_convert && must_convert
839         = classify (tcs ++ conv) keep (cs `addListToUFM` [(tc,True) | tc <- tcs]) rs
840       | can_convert
841         = classify conv (tcs ++ keep) (cs `addListToUFM` [(tc,False) | tc <- tcs]) rs
842       | otherwise
843         = classify conv keep cs rs
844       where
845         refs = ds `delListFromUniqSet` tcs
846
847         can_convert  = isNullUFM (refs `minusUFM` cs) && all convertable tcs
848         must_convert = foldUFM (||) False (intersectUFM_C const cs refs)
849
850         convertable tc = isDataTyCon tc && all isVanillaDataCon (tyConDataCons tc)
851
852 -- | Compute mutually recursive groups of tycons in topological order
853 --
854 tyConGroups :: [TyCon] -> [TyConGroup]
855 tyConGroups tcs = map mk_grp (stronglyConnCompFromEdgedVertices edges)
856   where
857     edges = [((tc, ds), tc, uniqSetToList ds) | tc <- tcs
858                                 , let ds = tyConsOfTyCon tc]
859
860     mk_grp (AcyclicSCC (tc, ds)) = ([tc], ds)
861     mk_grp (CyclicSCC els)       = (tcs, unionManyUniqSets dss)
862       where
863         (tcs, dss) = unzip els
864
865 tyConsOfTyCon :: TyCon -> UniqSet TyCon
866 tyConsOfTyCon
867   = tyConsOfTypes . concatMap dataConRepArgTys . tyConDataCons
868
869 tyConsOfType :: Type -> UniqSet TyCon
870 tyConsOfType ty
871   | Just ty' <- coreView ty    = tyConsOfType ty'
872 tyConsOfType (TyVarTy _)       = emptyUniqSet
873 tyConsOfType (TyConApp tc tys) = extend (tyConsOfTypes tys)
874   where
875     extend | isUnLiftedTyCon tc
876            || isTupleTyCon   tc = id
877
878            | otherwise          = (`addOneToUniqSet` tc)
879
880 tyConsOfType (AppTy a b)       = tyConsOfType a `unionUniqSets` tyConsOfType b
881 tyConsOfType (FunTy a b)       = (tyConsOfType a `unionUniqSets` tyConsOfType b)
882                                  `addOneToUniqSet` funTyCon
883 tyConsOfType (ForAllTy _ ty)   = tyConsOfType ty
884 tyConsOfType other             = pprPanic "ClosureConv.tyConsOfType" $ ppr other
885
886 tyConsOfTypes :: [Type] -> UniqSet TyCon
887 tyConsOfTypes = unionManyUniqSets . map tyConsOfType
888
889
890 -- ----------------------------------------------------------------------------
891 -- Conversions
892
893 fromVect :: Type -> CoreExpr -> VM CoreExpr
894 fromVect ty expr | Just ty' <- coreView ty = fromVect ty' expr
895 fromVect (FunTy arg_ty res_ty) expr
896   = do
897       arg     <- newLocalVar (fsLit "x") arg_ty
898       varg    <- toVect arg_ty (Var arg)
899       varg_ty <- vectType arg_ty
900       vres_ty <- vectType res_ty
901       apply   <- builtin applyVar
902       body    <- fromVect res_ty
903                $ Var apply `mkTyApps` [varg_ty, vres_ty] `mkApps` [expr, varg]
904       return $ Lam arg body
905 fromVect ty expr
906   = identityConv ty >> return expr
907
908 toVect :: Type -> CoreExpr -> VM CoreExpr
909 toVect ty expr = identityConv ty >> return expr
910
911 identityConv :: Type -> VM ()
912 identityConv ty | Just ty' <- coreView ty = identityConv ty'
913 identityConv (TyConApp tycon tys)
914   = do
915       mapM_ identityConv tys
916       identityConvTyCon tycon
917 identityConv _ = noV
918
919 identityConvTyCon :: TyCon -> VM ()
920 identityConvTyCon tc
921   | isBoxedTupleTyCon tc = return ()
922   | isUnLiftedTyCon tc   = return ()
923   | otherwise            = do
924                              tc' <- maybeV (lookupTyCon tc)
925                              if tc == tc' then return () else noV
926