Break out vectorisation of TyConDecls into own module
[ghc-hetmet.git] / compiler / vectorise / VectType.hs
1 {-# OPTIONS -fno-warn-missing-signatures #-}
2
3 module VectType ( vectTyCon, vectAndLiftType, vectType, vectTypeEnv,
4                   -- arrSumArity, pdataCompTys, pdataCompVars,
5                   buildPADict,
6                   fromVect )
7 where
8
9 import VectUtils
10 import Vectorise.Env
11 import Vectorise.Vect
12 import Vectorise.Monad
13 import Vectorise.Builtins
14 import Vectorise.Type.Type
15 import Vectorise.Type.TyConDecl
16
17 import HscTypes          ( TypeEnv, extendTypeEnvList, typeEnvTyCons )
18 import BasicTypes
19 import CoreSyn
20 import CoreUtils
21 import CoreUnfold
22 import MkCore            ( mkWildCase )
23 import BuildTyCl
24 import DataCon
25 import TyCon
26 import Type
27 import TypeRep
28 import Coercion
29 import FamInstEnv        ( FamInst, mkLocalFamInst )
30 import OccName
31 import Id
32 import MkId
33 import Var
34 import Name              ( Name, getOccName )
35 import NameEnv
36
37 import Unique
38 import UniqFM
39 import UniqSet
40 import Util
41 import Digraph           ( SCC(..), stronglyConnCompFromEdgedVertices )
42
43 import Outputable
44 import FastString
45
46 import MonadUtils     ( zipWith3M, foldrM, concatMapM )
47 import Control.Monad  ( liftM, liftM2, zipWithM, zipWithM_, mapAndUnzipM )
48 import Data.List
49
50 debug           = False
51 dtrace s x      = if debug then pprTrace "VectType" s x else x
52
53 -- ----------------------------------------------------------------------------
54 -- Types
55
56
57 -- ----------------------------------------------------------------------------
58 -- Type definitions
59
60 type TyConGroup = ([TyCon], UniqSet TyCon)
61
62 -- | Vectorise a type environment.
63 --   The type environment contains all the type things defined in a module.
64 vectTypeEnv :: TypeEnv -> VM (TypeEnv, [FamInst], [(Var, CoreExpr)])
65 vectTypeEnv env
66  = dtrace (ppr env)
67  $ do
68       cs <- readGEnv $ mk_map . global_tycons
69
70       -- Split the list of TyCons into the ones we have to vectorise vs the
71       -- ones we can pass through unchanged. We also pass through algebraic 
72       -- types that use non Haskell98 features, as we don't handle those.
73       let (conv_tcs, keep_tcs) = classifyTyCons cs groups
74           keep_dcs             = concatMap tyConDataCons keep_tcs
75
76       zipWithM_ defTyCon   keep_tcs keep_tcs
77       zipWithM_ defDataCon keep_dcs keep_dcs
78
79       new_tcs <- vectTyConDecls conv_tcs
80
81       let orig_tcs = keep_tcs ++ conv_tcs
82
83       -- We don't need to make new representation types for dictionary
84       -- constructors. The constructors are always fully applied, and we don't 
85       -- need to lift them to arrays as a dictionary of a particular type
86       -- always has the same value.
87       let vect_tcs = filter (not . isClassTyCon) 
88                    $ keep_tcs ++ new_tcs
89
90       (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
91         do
92           defTyConPAs (zipLazy vect_tcs dfuns')
93           reprs     <- mapM tyConRepr vect_tcs
94           repr_tcs  <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
95           pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
96
97           dfuns     <- sequence 
98                     $  zipWith5 buildTyConBindings
99                                orig_tcs
100                                vect_tcs
101                                repr_tcs
102                                pdata_tcs
103                                reprs
104
105           binds     <- takeHoisted
106           return (dfuns, binds, repr_tcs ++ pdata_tcs)
107
108       let all_new_tcs = new_tcs ++ inst_tcs
109
110       let new_env = extendTypeEnvList env
111                        (map ATyCon all_new_tcs
112                         ++ [ADataCon dc | tc <- all_new_tcs
113                                         , dc <- tyConDataCons tc])
114
115       return (new_env, map mkLocalFamInst inst_tcs, binds)
116   where
117     tycons = typeEnvTyCons env
118     groups = tyConGroups tycons
119
120     mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
121
122
123 mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
124 mk_fam_inst fam_tc arg_tc
125   = (fam_tc, [mkTyConApp arg_tc . mkTyVarTys $ tyConTyVars arg_tc])
126
127
128 buildPReprTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
129 buildPReprTyCon orig_tc vect_tc repr
130   = do
131       name     <- cloneName mkPReprTyConOcc (tyConName orig_tc)
132       -- rhs_ty   <- buildPReprType vect_tc
133       rhs_ty   <- sumReprType repr
134       prepr_tc <- builtin preprTyCon
135       liftDs $ buildSynTyCon name
136                              tyvars
137                              (SynonymTyCon rhs_ty)
138                              (typeKind rhs_ty)
139                              (Just $ mk_fam_inst prepr_tc vect_tc)
140   where
141     tyvars = tyConTyVars vect_tc
142
143 data CompRepr = Keep Type
144                      CoreExpr     -- PR dictionary for the type
145               | Wrap Type
146
147 data ProdRepr = EmptyProd
148               | UnaryProd CompRepr
149               | Prod { repr_tup_tc   :: TyCon  -- representation tuple tycon
150                      , repr_ptup_tc  :: TyCon  -- PData representation tycon
151                      , repr_comp_tys :: [Type] -- representation types of
152                      , repr_comps    :: [CompRepr]          -- components
153                      }
154 data ConRepr  = ConRepr DataCon ProdRepr
155
156 data SumRepr  = EmptySum
157               | UnarySum ConRepr
158               | Sum  { repr_sum_tc   :: TyCon  -- representation sum tycon
159                      , repr_psum_tc  :: TyCon  -- PData representation tycon
160                      , repr_sel_ty   :: Type   -- type of selector
161                      , repr_con_tys :: [Type]  -- representation types of
162                      , repr_cons     :: [ConRepr]           -- components
163                      }
164
165 tyConRepr :: TyCon -> VM SumRepr
166 tyConRepr tc = sum_repr (tyConDataCons tc)
167   where
168     sum_repr []    = return EmptySum
169     sum_repr [con] = liftM UnarySum (con_repr con)
170     sum_repr cons  = do
171                        rs     <- mapM con_repr cons
172                        sum_tc <- builtin (sumTyCon arity)
173                        tys    <- mapM conReprType rs
174                        (psum_tc, _) <- pdataReprTyCon (mkTyConApp sum_tc tys)
175                        sel_ty <- builtin (selTy arity)
176                        return $ Sum { repr_sum_tc  = sum_tc
177                                     , repr_psum_tc = psum_tc
178                                     , repr_sel_ty  = sel_ty
179                                     , repr_con_tys = tys
180                                     , repr_cons    = rs
181                                     }
182       where
183         arity = length cons
184
185     con_repr con = liftM (ConRepr con) (prod_repr (dataConRepArgTys con))
186
187     prod_repr []   = return EmptyProd
188     prod_repr [ty] = liftM UnaryProd (comp_repr ty)
189     prod_repr tys  = do
190                        rs <- mapM comp_repr tys
191                        tup_tc <- builtin (prodTyCon arity)
192                        tys'    <- mapM compReprType rs
193                        (ptup_tc, _) <- pdataReprTyCon (mkTyConApp tup_tc tys')
194                        return $ Prod { repr_tup_tc   = tup_tc
195                                      , repr_ptup_tc  = ptup_tc
196                                      , repr_comp_tys = tys'
197                                      , repr_comps    = rs
198                                      }
199       where
200         arity = length tys
201     
202     comp_repr ty = liftM (Keep ty) (prDictOfType ty)
203                    `orElseV` return (Wrap ty)
204
205 sumReprType :: SumRepr -> VM Type
206 sumReprType EmptySum = voidType
207 sumReprType (UnarySum r) = conReprType r
208 sumReprType (Sum { repr_sum_tc  = sum_tc, repr_con_tys = tys })
209   = return $ mkTyConApp sum_tc tys
210
211 conReprType :: ConRepr -> VM Type
212 conReprType (ConRepr _ r) = prodReprType r
213
214 prodReprType :: ProdRepr -> VM Type
215 prodReprType EmptyProd = voidType
216 prodReprType (UnaryProd r) = compReprType r
217 prodReprType (Prod { repr_tup_tc = tup_tc, repr_comp_tys = tys })
218   = return $ mkTyConApp tup_tc tys
219
220 compReprType :: CompRepr -> VM Type
221 compReprType (Keep ty _) = return ty
222 compReprType (Wrap ty) = do
223                              wrap_tc <- builtin wrapTyCon
224                              return $ mkTyConApp wrap_tc [ty]
225
226 compOrigType :: CompRepr -> Type
227 compOrigType (Keep ty _) = ty
228 compOrigType (Wrap ty) = ty
229
230 buildToPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
231 buildToPRepr vect_tc repr_tc _ repr
232   = do
233       let arg_ty = mkTyConApp vect_tc ty_args
234       res_ty <- mkPReprType arg_ty
235       arg    <- newLocalVar (fsLit "x") arg_ty
236       result <- to_sum (Var arg) arg_ty res_ty repr
237       return $ Lam arg result
238   where
239     ty_args = mkTyVarTys (tyConTyVars vect_tc)
240
241     wrap_repr_inst = wrapFamInstBody repr_tc ty_args
242
243     to_sum _ _ _ EmptySum
244       = do
245           void <- builtin voidVar
246           return $ wrap_repr_inst $ Var void
247
248     to_sum arg arg_ty res_ty (UnarySum r)
249       = do
250           (pat, vars, body) <- con_alt r
251           return $ mkWildCase arg arg_ty res_ty
252                    [(pat, vars, wrap_repr_inst body)]
253
254     to_sum arg arg_ty res_ty (Sum { repr_sum_tc  = sum_tc
255                                   , repr_con_tys = tys
256                                   , repr_cons    =  cons })
257       = do
258           alts <- mapM con_alt cons
259           let alts' = [(pat, vars, wrap_repr_inst
260                                    $ mkConApp sum_con (map Type tys ++ [body]))
261                         | ((pat, vars, body), sum_con)
262                             <- zip alts (tyConDataCons sum_tc)]
263           return $ mkWildCase arg arg_ty res_ty alts'
264
265     con_alt (ConRepr con r)
266       = do
267           (vars, body) <- to_prod r
268           return (DataAlt con, vars, body)
269
270     to_prod EmptyProd
271       = do
272           void <- builtin voidVar
273           return ([], Var void)
274
275     to_prod (UnaryProd comp)
276       = do
277           var  <- newLocalVar (fsLit "x") (compOrigType comp)
278           body <- to_comp (Var var) comp
279           return ([var], body)
280
281     to_prod(Prod { repr_tup_tc   = tup_tc
282                  , repr_comp_tys = tys
283                  , repr_comps    = comps })
284       = do
285           vars  <- newLocalVars (fsLit "x") (map compOrigType comps)
286           exprs <- zipWithM to_comp (map Var vars) comps
287           return (vars, mkConApp tup_con (map Type tys ++ exprs))
288       where
289         [tup_con] = tyConDataCons tup_tc
290
291     to_comp expr (Keep _ _) = return expr
292     to_comp expr (Wrap ty)  = do
293                                 wrap_tc <- builtin wrapTyCon
294                                 return $ wrapNewTypeBody wrap_tc [ty] expr
295
296
297 buildFromPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
298 buildFromPRepr vect_tc repr_tc _ repr
299   = do
300       arg_ty <- mkPReprType res_ty
301       arg <- newLocalVar (fsLit "x") arg_ty
302
303       result <- from_sum (unwrapFamInstScrut repr_tc ty_args (Var arg))
304                          repr
305       return $ Lam arg result
306   where
307     ty_args = mkTyVarTys (tyConTyVars vect_tc)
308     res_ty  = mkTyConApp vect_tc ty_args
309
310     from_sum _ EmptySum
311       = do
312           dummy <- builtin fromVoidVar
313           return $ Var dummy `App` Type res_ty
314
315     from_sum expr (UnarySum r) = from_con expr r
316     from_sum expr (Sum { repr_sum_tc  = sum_tc
317                        , repr_con_tys = tys
318                        , repr_cons    = cons })
319       = do
320           vars  <- newLocalVars (fsLit "x") tys
321           es    <- zipWithM from_con (map Var vars) cons
322           return $ mkWildCase expr (exprType expr) res_ty
323                    [(DataAlt con, [var], e)
324                       | (con, var, e) <- zip3 (tyConDataCons sum_tc) vars es]
325
326     from_con expr (ConRepr con r)
327       = from_prod expr (mkConApp con $ map Type ty_args) r
328
329     from_prod _ con EmptyProd = return con
330     from_prod expr con (UnaryProd r)
331       = do
332           e <- from_comp expr r
333           return $ con `App` e
334      
335     from_prod expr con (Prod { repr_tup_tc   = tup_tc
336                              , repr_comp_tys = tys
337                              , repr_comps    = comps
338                              })
339       = do
340           vars <- newLocalVars (fsLit "y") tys
341           es   <- zipWithM from_comp (map Var vars) comps
342           return $ mkWildCase expr (exprType expr) res_ty
343                    [(DataAlt tup_con, vars, con `mkApps` es)]
344       where
345         [tup_con] = tyConDataCons tup_tc  
346
347     from_comp expr (Keep _ _) = return expr
348     from_comp expr (Wrap ty)
349       = do
350           wrap <- builtin wrapTyCon
351           return $ unwrapNewTypeBody wrap [ty] expr
352
353
354 buildToArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
355 buildToArrPRepr vect_tc prepr_tc pdata_tc r
356   = do
357       arg_ty <- mkPDataType el_ty
358       res_ty <- mkPDataType =<< mkPReprType el_ty
359       arg    <- newLocalVar (fsLit "xs") arg_ty
360
361       pdata_co <- mkBuiltinCo pdataTyCon
362       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
363           co           = mkAppCoercion pdata_co
364                        . mkSymCoercion
365                        $ mkTyConApp repr_co ty_args
366
367           scrut   = unwrapFamInstScrut pdata_tc ty_args (Var arg)
368
369       (vars, result) <- to_sum r
370
371       return . Lam arg
372              $ mkWildCase scrut (mkTyConApp pdata_tc ty_args) res_ty
373                [(DataAlt pdata_dc, vars, mkCoerce co result)]
374   where
375     ty_args = mkTyVarTys $ tyConTyVars vect_tc
376     el_ty   = mkTyConApp vect_tc ty_args
377
378     [pdata_dc] = tyConDataCons pdata_tc
379
380
381     to_sum EmptySum = do
382                         pvoid <- builtin pvoidVar
383                         return ([], Var pvoid)
384     to_sum (UnarySum r) = to_con r
385     to_sum (Sum { repr_psum_tc = psum_tc
386                 , repr_sel_ty  = sel_ty
387                 , repr_con_tys = tys
388                 , repr_cons    = cons
389                 })
390       = do
391           (vars, exprs) <- mapAndUnzipM to_con cons
392           sel <- newLocalVar (fsLit "sel") sel_ty
393           return (sel : concat vars, mk_result (Var sel) exprs)
394       where
395         [psum_con] = tyConDataCons psum_tc
396         mk_result sel exprs = wrapFamInstBody psum_tc tys
397                             $ mkConApp psum_con
398                             $ map Type tys ++ (sel : exprs)
399
400     to_con (ConRepr _ r) = to_prod r
401
402     to_prod EmptyProd = do
403                           pvoid <- builtin pvoidVar
404                           return ([], Var pvoid)
405     to_prod (UnaryProd r)
406       = do
407           pty  <- mkPDataType (compOrigType r)
408           var  <- newLocalVar (fsLit "x") pty
409           expr <- to_comp (Var var) r
410           return ([var], expr)
411
412     to_prod (Prod { repr_ptup_tc  = ptup_tc
413                   , repr_comp_tys = tys
414                   , repr_comps    = comps })
415       = do
416           ptys <- mapM (mkPDataType . compOrigType) comps
417           vars <- newLocalVars (fsLit "x") ptys
418           es   <- zipWithM to_comp (map Var vars) comps
419           return (vars, mk_result es)
420       where
421         [ptup_con] = tyConDataCons ptup_tc
422         mk_result exprs = wrapFamInstBody ptup_tc tys
423                         $ mkConApp ptup_con
424                         $ map Type tys ++ exprs
425
426     to_comp expr (Keep _ _) = return expr
427
428     -- FIXME: this is bound to be wrong!
429     to_comp expr (Wrap ty)
430       = do
431           wrap_tc  <- builtin wrapTyCon
432           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
433           return $ wrapNewTypeBody pwrap_tc [ty] expr
434
435
436 buildFromArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
437 buildFromArrPRepr vect_tc prepr_tc pdata_tc r
438   = do
439       arg_ty <- mkPDataType =<< mkPReprType el_ty
440       res_ty <- mkPDataType el_ty
441       arg    <- newLocalVar (fsLit "xs") arg_ty
442
443       pdata_co <- mkBuiltinCo pdataTyCon
444       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
445           co           = mkAppCoercion pdata_co
446                        $ mkTyConApp repr_co var_tys
447
448           scrut  = mkCoerce co (Var arg)
449
450           mk_result args = wrapFamInstBody pdata_tc var_tys
451                          $ mkConApp pdata_con
452                          $ map Type var_tys ++ args
453
454       (expr, _) <- fixV $ \ ~(_, args) ->
455                      from_sum res_ty (mk_result args) scrut r
456
457       return $ Lam arg expr
458     
459       -- (args, mk) <- from_sum res_ty scrut r
460       
461       -- let result = wrapFamInstBody pdata_tc var_tys
462       --           . mkConApp pdata_dc
463       --           $ map Type var_tys ++ args
464
465       -- return $ Lam arg (mk result)
466   where
467     var_tys = mkTyVarTys $ tyConTyVars vect_tc
468     el_ty   = mkTyConApp vect_tc var_tys
469
470     [pdata_con] = tyConDataCons pdata_tc
471
472     from_sum _ res _ EmptySum = return (res, [])
473     from_sum res_ty res expr (UnarySum r) = from_con res_ty res expr r
474     from_sum res_ty res expr (Sum { repr_psum_tc = psum_tc
475                                   , repr_sel_ty  = sel_ty
476                                   , repr_con_tys = tys
477                                   , repr_cons    = cons })
478       = do
479           sel  <- newLocalVar (fsLit "sel") sel_ty
480           ptys <- mapM mkPDataType tys
481           vars <- newLocalVars (fsLit "xs") ptys
482           (res', args) <- fold from_con res_ty res (map Var vars) cons
483           let scrut = unwrapFamInstScrut psum_tc tys expr
484               body  = mkWildCase scrut (exprType scrut) res_ty
485                       [(DataAlt psum_con, sel : vars, res')]
486           return (body, Var sel : args)
487       where
488         [psum_con] = tyConDataCons psum_tc
489
490
491     from_con res_ty res expr (ConRepr _ r) = from_prod res_ty res expr r
492
493     from_prod _ res _ EmptyProd = return (res, [])
494     from_prod res_ty res expr (UnaryProd r)
495       = from_comp res_ty res expr r
496     from_prod res_ty res expr (Prod { repr_ptup_tc  = ptup_tc
497                                     , repr_comp_tys = tys
498                                     , repr_comps    = comps })
499       = do
500           ptys <- mapM mkPDataType tys
501           vars <- newLocalVars (fsLit "ys") ptys
502           (res', args) <- fold from_comp res_ty res (map Var vars) comps
503           let scrut = unwrapFamInstScrut ptup_tc tys expr
504               body  = mkWildCase scrut (exprType scrut) res_ty
505                       [(DataAlt ptup_con, vars, res')]
506           return (body, args)
507       where
508         [ptup_con] = tyConDataCons ptup_tc
509
510     from_comp _ res expr (Keep _ _) = return (res, [expr])
511     from_comp _ res expr (Wrap ty)
512       = do
513           wrap_tc  <- builtin wrapTyCon
514           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
515           return (res, [unwrapNewTypeBody pwrap_tc [ty]
516                         $ unwrapFamInstScrut pwrap_tc [ty] expr])
517
518     fold f res_ty res exprs rs = foldrM f' (res, []) (zip exprs rs)
519       where
520         f' (expr, r) (res, args) = do
521                                      (res', args') <- f res_ty res expr r
522                                      return (res', args' ++ args)
523
524 buildPRDict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
525 buildPRDict vect_tc prepr_tc _ r
526   = do
527       dict <- sum_dict r
528       pr_co <- mkBuiltinCo prTyCon
529       let co = mkAppCoercion pr_co
530              . mkSymCoercion
531              $ mkTyConApp arg_co ty_args
532       return (mkCoerce co dict)
533   where
534     ty_args = mkTyVarTys (tyConTyVars vect_tc)
535     Just arg_co = tyConFamilyCoercion_maybe prepr_tc
536
537     sum_dict EmptySum = prDFunOfTyCon =<< builtin voidTyCon
538     sum_dict (UnarySum r) = con_dict r
539     sum_dict (Sum { repr_sum_tc  = sum_tc
540                   , repr_con_tys = tys
541                   , repr_cons    = cons
542                   })
543       = do
544           dicts <- mapM con_dict cons
545           dfun  <- prDFunOfTyCon sum_tc
546           return $ dfun `mkTyApps` tys `mkApps` dicts
547
548     con_dict (ConRepr _ r) = prod_dict r
549
550     prod_dict EmptyProd = prDFunOfTyCon =<< builtin voidTyCon
551     prod_dict (UnaryProd r) = comp_dict r
552     prod_dict (Prod { repr_tup_tc   = tup_tc
553                     , repr_comp_tys = tys
554                     , repr_comps    = comps })
555       = do
556           dicts <- mapM comp_dict comps
557           dfun <- prDFunOfTyCon tup_tc
558           return $ dfun `mkTyApps` tys `mkApps` dicts
559
560     comp_dict (Keep _ pr) = return pr
561     comp_dict (Wrap ty)   = wrapPR ty
562
563
564 buildPDataTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
565 buildPDataTyCon orig_tc vect_tc repr = fixV $ \repr_tc ->
566   do
567     name' <- cloneName mkPDataTyConOcc orig_name
568     rhs   <- buildPDataTyConRhs orig_name vect_tc repr_tc repr
569     pdata <- builtin pdataTyCon
570
571     liftDs $ buildAlgTyCon name'
572                            tyvars
573                            []          -- no stupid theta
574                            rhs
575                            rec_flag    -- FIXME: is this ok?
576                            False       -- FIXME: no generics
577                            False       -- not GADT syntax
578                            (Just $ mk_fam_inst pdata vect_tc)
579   where
580     orig_name = tyConName orig_tc
581     tyvars = tyConTyVars vect_tc
582     rec_flag = boolToRecFlag (isRecursiveTyCon vect_tc)
583
584
585 buildPDataTyConRhs :: Name -> TyCon -> TyCon -> SumRepr -> VM AlgTyConRhs
586 buildPDataTyConRhs orig_name vect_tc repr_tc repr
587   = do
588       data_con <- buildPDataDataCon orig_name vect_tc repr_tc repr
589       return $ DataTyCon { data_cons = [data_con], is_enum = False }
590
591 buildPDataDataCon :: Name -> TyCon -> TyCon -> SumRepr -> VM DataCon
592 buildPDataDataCon orig_name vect_tc repr_tc repr
593   = do
594       dc_name  <- cloneName mkPDataDataConOcc orig_name
595       comp_tys <- sum_tys repr
596
597       liftDs $ buildDataCon dc_name
598                             False                  -- not infix
599                             (map (const HsNoBang) comp_tys)
600                             []                     -- no field labels
601                             tvs
602                             []                     -- no existentials
603                             []                     -- no eq spec
604                             []                     -- no context
605                             comp_tys
606                             (mkFamilyTyConApp repr_tc (mkTyVarTys tvs))
607                             repr_tc
608   where
609     tvs   = tyConTyVars vect_tc
610
611     sum_tys EmptySum = return []
612     sum_tys (UnarySum r) = con_tys r
613     sum_tys (Sum { repr_sel_ty = sel_ty
614                  , repr_cons   = cons })
615       = liftM (sel_ty :) (concatMapM con_tys cons)
616
617     con_tys (ConRepr _ r) = prod_tys r
618
619     prod_tys EmptyProd = return []
620     prod_tys (UnaryProd r) = liftM singleton (comp_ty r)
621     prod_tys (Prod { repr_comps = comps }) = mapM comp_ty comps
622
623     comp_ty r = mkPDataType (compOrigType r)
624
625
626 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr 
627                    -> VM Var
628 buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr
629   = do
630       vectDataConWorkers orig_tc vect_tc pdata_tc
631       buildPADict vect_tc prepr_tc pdata_tc repr
632
633 vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
634 vectDataConWorkers orig_tc vect_tc arr_tc
635   = do
636       bs <- sequence
637           . zipWith3 def_worker  (tyConDataCons orig_tc) rep_tys
638           $ zipWith4 mk_data_con (tyConDataCons vect_tc)
639                                  rep_tys
640                                  (inits rep_tys)
641                                  (tail $ tails rep_tys)
642       mapM_ (uncurry hoistBinding) bs
643   where
644     tyvars   = tyConTyVars vect_tc
645     var_tys  = mkTyVarTys tyvars
646     ty_args  = map Type var_tys
647     res_ty   = mkTyConApp vect_tc var_tys
648
649     cons     = tyConDataCons vect_tc
650     arity    = length cons
651     [arr_dc] = tyConDataCons arr_tc
652
653     rep_tys  = map dataConRepArgTys $ tyConDataCons vect_tc
654
655
656     mk_data_con con tys pre post
657       = liftM2 (,) (vect_data_con con)
658                    (lift_data_con tys pre post (mkDataConTag con))
659
660     sel_replicate len tag
661       | arity > 1 = do
662                       rep <- builtin (selReplicate arity)
663                       return [rep `mkApps` [len, tag]]
664
665       | otherwise = return []
666
667     vect_data_con con = return $ mkConApp con ty_args
668     lift_data_con tys pre_tys post_tys tag
669       = do
670           len  <- builtin liftingContext
671           args <- mapM (newLocalVar (fsLit "xs"))
672                   =<< mapM mkPDataType tys
673
674           sel  <- sel_replicate (Var len) tag
675
676           pre   <- mapM emptyPD (concat pre_tys)
677           post  <- mapM emptyPD (concat post_tys)
678
679           return . mkLams (len : args)
680                  . wrapFamInstBody arr_tc var_tys
681                  . mkConApp arr_dc
682                  $ ty_args ++ sel ++ pre ++ map Var args ++ post
683
684     def_worker data_con arg_tys mk_body
685       = do
686           arity <- polyArity tyvars
687           body <- closedV
688                 . inBind orig_worker
689                 . polyAbstract tyvars $ \args ->
690                   liftM (mkLams (tyvars ++ args) . vectorised)
691                 $ buildClosures tyvars [] arg_tys res_ty mk_body
692
693           raw_worker <- cloneId mkVectOcc orig_worker (exprType body)
694           let vect_worker = raw_worker `setIdUnfolding`
695                               mkInlineRule body (Just arity)
696           defGlobalVar orig_worker vect_worker
697           return (vect_worker, body)
698       where
699         orig_worker = dataConWorkId data_con
700
701 buildPADict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
702 buildPADict vect_tc prepr_tc arr_tc repr
703   = polyAbstract tvs $ \args ->
704     do
705       method_ids <- mapM (method args) paMethods
706
707       pa_tc  <- builtin paTyCon
708       pa_dc  <- builtin paDataCon
709       let dict = mkLams (tvs ++ args)
710                $ mkConApp pa_dc
711                $ Type inst_ty : map (method_call args) method_ids
712
713           dfun_ty = mkForAllTys tvs
714                   $ mkFunTys (map varType args) (mkTyConApp pa_tc [inst_ty])
715
716       raw_dfun <- newExportedVar dfun_name dfun_ty
717       let dfun = raw_dfun `setIdUnfolding` mkDFunUnfolding dfun_ty (map Var method_ids)
718                           `setInlinePragma` dfunInlinePragma
719
720       hoistBinding dfun dict
721       return dfun
722   where
723     tvs = tyConTyVars vect_tc
724     arg_tys = mkTyVarTys tvs
725     inst_ty = mkTyConApp vect_tc arg_tys
726
727     dfun_name = mkPADFunOcc (getOccName vect_tc)
728
729     method args (name, build)
730       = localV
731       $ do
732           expr <- build vect_tc prepr_tc arr_tc repr
733           let body = mkLams (tvs ++ args) expr
734           raw_var <- newExportedVar (method_name name) (exprType body)
735           let var = raw_var
736                       `setIdUnfolding` mkInlineRule body (Just (length args))
737                       `setInlinePragma` alwaysInlinePragma
738           hoistBinding var body
739           return var
740
741     method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args)
742
743     method_name name = mkVarOcc $ occNameString dfun_name ++ ('$' : name)
744
745
746 paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
747 paMethods = [("dictPRepr",    buildPRDict),
748              ("toPRepr",      buildToPRepr),
749              ("fromPRepr",    buildFromPRepr),
750              ("toArrPRepr",   buildToArrPRepr),
751              ("fromArrPRepr", buildFromArrPRepr)]
752
753
754 -- | Split the given tycons into two sets depending on whether they have to be
755 --   converted (first list) or not (second list). The first argument contains
756 --   information about the conversion status of external tycons:
757 --
758 --   * tycons which have converted versions are mapped to True
759 --   * tycons which are not changed by vectorisation are mapped to False
760 --   * tycons which can't be converted are not elements of the map
761 --
762 classifyTyCons :: UniqFM Bool -> [TyConGroup] -> ([TyCon], [TyCon])
763 classifyTyCons = classify [] []
764   where
765     classify conv keep _  [] = (conv, keep)
766     classify conv keep cs ((tcs, ds) : rs)
767       | can_convert && must_convert
768         = classify (tcs ++ conv) keep (cs `addListToUFM` [(tc,True) | tc <- tcs]) rs
769       | can_convert
770         = classify conv (tcs ++ keep) (cs `addListToUFM` [(tc,False) | tc <- tcs]) rs
771       | otherwise
772         = classify conv keep cs rs
773       where
774         refs = ds `delListFromUniqSet` tcs
775
776         can_convert  = isNullUFM (refs `minusUFM` cs) && all convertable tcs
777         must_convert = foldUFM (||) False (intersectUFM_C const cs refs)
778
779         convertable tc = isDataTyCon tc && all isVanillaDataCon (tyConDataCons tc)
780
781 -- | Compute mutually recursive groups of tycons in topological order
782 --
783 tyConGroups :: [TyCon] -> [TyConGroup]
784 tyConGroups tcs = map mk_grp (stronglyConnCompFromEdgedVertices edges)
785   where
786     edges = [((tc, ds), tc, uniqSetToList ds) | tc <- tcs
787                                 , let ds = tyConsOfTyCon tc]
788
789     mk_grp (AcyclicSCC (tc, ds)) = ([tc], ds)
790     mk_grp (CyclicSCC els)       = (tcs, unionManyUniqSets dss)
791       where
792         (tcs, dss) = unzip els
793
794 tyConsOfTyCon :: TyCon -> UniqSet TyCon
795 tyConsOfTyCon
796   = tyConsOfTypes . concatMap dataConRepArgTys . tyConDataCons
797
798 tyConsOfType :: Type -> UniqSet TyCon
799 tyConsOfType ty
800   | Just ty' <- coreView ty    = tyConsOfType ty'
801 tyConsOfType (TyVarTy _)       = emptyUniqSet
802 tyConsOfType (TyConApp tc tys) = extend (tyConsOfTypes tys)
803   where
804     extend | isUnLiftedTyCon tc
805            || isTupleTyCon   tc = id
806
807            | otherwise          = (`addOneToUniqSet` tc)
808
809 tyConsOfType (AppTy a b)       = tyConsOfType a `unionUniqSets` tyConsOfType b
810 tyConsOfType (FunTy a b)       = (tyConsOfType a `unionUniqSets` tyConsOfType b)
811                                  `addOneToUniqSet` funTyCon
812 tyConsOfType (ForAllTy _ ty)   = tyConsOfType ty
813 tyConsOfType other             = pprPanic "ClosureConv.tyConsOfType" $ ppr other
814
815 tyConsOfTypes :: [Type] -> UniqSet TyCon
816 tyConsOfTypes = unionManyUniqSets . map tyConsOfType
817
818
819 -- ----------------------------------------------------------------------------
820 -- Conversions
821
822 -- | Build an expression that calls the vectorised version of some 
823 --   function from a `Closure`.
824 --
825 --   For example
826 --   @   
827 --      \(x :: Double) -> 
828 --      \(y :: Double) -> 
829 --      ($v_foo $: x) $: y
830 --   @
831 --
832 --   We use the type of the original binding to work out how many
833 --   outer lambdas to add.
834 --
835 fromVect 
836         :: Type         -- ^ The type of the original binding.
837         -> CoreExpr     -- ^ Expression giving the closure to use, eg @$v_foo@.
838         -> VM CoreExpr
839         
840 -- Convert the type to the core view if it isn't already.
841 fromVect ty expr 
842         | Just ty' <- coreView ty 
843         = fromVect ty' expr
844
845 -- For each function constructor in the original type we add an outer 
846 -- lambda to bind the parameter variable, and an inner application of it.
847 fromVect (FunTy arg_ty res_ty) expr
848   = do
849       arg     <- newLocalVar (fsLit "x") arg_ty
850       varg    <- toVect arg_ty (Var arg)
851       varg_ty <- vectType arg_ty
852       vres_ty <- vectType res_ty
853       apply   <- builtin applyVar
854       body    <- fromVect res_ty
855                $ Var apply `mkTyApps` [varg_ty, vres_ty] `mkApps` [expr, varg]
856       return $ Lam arg body
857
858 -- If the type isn't a function then it's time to call on the closure.
859 fromVect ty expr
860   = identityConv ty >> return expr
861
862
863 toVect :: Type -> CoreExpr -> VM CoreExpr
864 toVect ty expr = identityConv ty >> return expr
865
866
867 identityConv :: Type -> VM ()
868 identityConv ty | Just ty' <- coreView ty = identityConv ty'
869 identityConv (TyConApp tycon tys)
870   = do
871       mapM_ identityConv tys
872       identityConvTyCon tycon
873 identityConv _ = noV
874
875 identityConvTyCon :: TyCon -> VM ()
876 identityConvTyCon tc
877   | isBoxedTupleTyCon tc = return ()
878   | isUnLiftedTyCon tc   = return ()
879   | otherwise            = do
880                              tc' <- maybeV (lookupTyCon tc)
881                              if tc == tc' then return () else noV
882