Break out closure utils into own module
[ghc-hetmet.git] / compiler / vectorise / VectType.hs
1 {-# OPTIONS -fno-warn-missing-signatures #-}
2
3 module VectType ( vectTyCon, vectAndLiftType, vectType, vectTypeEnv,
4                   -- arrSumArity, pdataCompTys, pdataCompVars,
5                   buildPADict,
6                   fromVect )
7 where
8
9 import VectUtils
10 import Vectorise.Env
11 import Vectorise.Vect
12 import Vectorise.Monad
13 import Vectorise.Builtins
14 import Vectorise.Type.Type
15 import Vectorise.Type.TyConDecl
16 import Vectorise.Type.Classify
17 import Vectorise.Utils.Closure
18
19 import HscTypes          ( TypeEnv, extendTypeEnvList, typeEnvTyCons )
20 import BasicTypes
21 import CoreSyn
22 import CoreUtils
23 import CoreUnfold
24 import MkCore            ( mkWildCase )
25 import BuildTyCl
26 import DataCon
27 import TyCon
28 import Type
29 import TypeRep
30 import Coercion
31 import FamInstEnv        ( FamInst, mkLocalFamInst )
32 import OccName
33 import Id
34 import MkId
35 import Var
36 import Name              ( Name, getOccName )
37 import NameEnv
38
39 import Unique
40 import UniqFM
41 import Util
42
43 import Outputable
44 import FastString
45
46 import MonadUtils     ( zipWith3M, foldrM, concatMapM )
47 import Control.Monad  ( liftM, liftM2, zipWithM, zipWithM_, mapAndUnzipM )
48 import Data.List
49
50 debug           = False
51 dtrace s x      = if debug then pprTrace "VectType" s x else x
52
53
54 -- ----------------------------------------------------------------------------
55 -- Type definitions
56
57
58 -- | Vectorise a type environment.
59 --   The type environment contains all the type things defined in a module.
60 vectTypeEnv :: TypeEnv -> VM (TypeEnv, [FamInst], [(Var, CoreExpr)])
61 vectTypeEnv env
62  = dtrace (ppr env)
63  $ do
64       cs <- readGEnv $ mk_map . global_tycons
65
66       -- Split the list of TyCons into the ones we have to vectorise vs the
67       -- ones we can pass through unchanged. We also pass through algebraic 
68       -- types that use non Haskell98 features, as we don't handle those.
69       let (conv_tcs, keep_tcs) = classifyTyCons cs groups
70           keep_dcs             = concatMap tyConDataCons keep_tcs
71
72       zipWithM_ defTyCon   keep_tcs keep_tcs
73       zipWithM_ defDataCon keep_dcs keep_dcs
74
75       new_tcs <- vectTyConDecls conv_tcs
76
77       let orig_tcs = keep_tcs ++ conv_tcs
78
79       -- We don't need to make new representation types for dictionary
80       -- constructors. The constructors are always fully applied, and we don't 
81       -- need to lift them to arrays as a dictionary of a particular type
82       -- always has the same value.
83       let vect_tcs = filter (not . isClassTyCon) 
84                    $ keep_tcs ++ new_tcs
85
86       (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
87         do
88           defTyConPAs (zipLazy vect_tcs dfuns')
89           reprs     <- mapM tyConRepr vect_tcs
90           repr_tcs  <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
91           pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
92
93           dfuns     <- sequence 
94                     $  zipWith5 buildTyConBindings
95                                orig_tcs
96                                vect_tcs
97                                repr_tcs
98                                pdata_tcs
99                                reprs
100
101           binds     <- takeHoisted
102           return (dfuns, binds, repr_tcs ++ pdata_tcs)
103
104       let all_new_tcs = new_tcs ++ inst_tcs
105
106       let new_env = extendTypeEnvList env
107                        (map ATyCon all_new_tcs
108                         ++ [ADataCon dc | tc <- all_new_tcs
109                                         , dc <- tyConDataCons tc])
110
111       return (new_env, map mkLocalFamInst inst_tcs, binds)
112   where
113     tycons = typeEnvTyCons env
114     groups = tyConGroups tycons
115
116     mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
117
118
119 mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
120 mk_fam_inst fam_tc arg_tc
121   = (fam_tc, [mkTyConApp arg_tc . mkTyVarTys $ tyConTyVars arg_tc])
122
123
124 buildPReprTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
125 buildPReprTyCon orig_tc vect_tc repr
126   = do
127       name     <- cloneName mkPReprTyConOcc (tyConName orig_tc)
128       -- rhs_ty   <- buildPReprType vect_tc
129       rhs_ty   <- sumReprType repr
130       prepr_tc <- builtin preprTyCon
131       liftDs $ buildSynTyCon name
132                              tyvars
133                              (SynonymTyCon rhs_ty)
134                              (typeKind rhs_ty)
135                              (Just $ mk_fam_inst prepr_tc vect_tc)
136   where
137     tyvars = tyConTyVars vect_tc
138
139 data CompRepr = Keep Type
140                      CoreExpr     -- PR dictionary for the type
141               | Wrap Type
142
143 data ProdRepr = EmptyProd
144               | UnaryProd CompRepr
145               | Prod { repr_tup_tc   :: TyCon  -- representation tuple tycon
146                      , repr_ptup_tc  :: TyCon  -- PData representation tycon
147                      , repr_comp_tys :: [Type] -- representation types of
148                      , repr_comps    :: [CompRepr]          -- components
149                      }
150 data ConRepr  = ConRepr DataCon ProdRepr
151
152 data SumRepr  = EmptySum
153               | UnarySum ConRepr
154               | Sum  { repr_sum_tc   :: TyCon  -- representation sum tycon
155                      , repr_psum_tc  :: TyCon  -- PData representation tycon
156                      , repr_sel_ty   :: Type   -- type of selector
157                      , repr_con_tys :: [Type]  -- representation types of
158                      , repr_cons     :: [ConRepr]           -- components
159                      }
160
161 tyConRepr :: TyCon -> VM SumRepr
162 tyConRepr tc = sum_repr (tyConDataCons tc)
163   where
164     sum_repr []    = return EmptySum
165     sum_repr [con] = liftM UnarySum (con_repr con)
166     sum_repr cons  = do
167                        rs     <- mapM con_repr cons
168                        sum_tc <- builtin (sumTyCon arity)
169                        tys    <- mapM conReprType rs
170                        (psum_tc, _) <- pdataReprTyCon (mkTyConApp sum_tc tys)
171                        sel_ty <- builtin (selTy arity)
172                        return $ Sum { repr_sum_tc  = sum_tc
173                                     , repr_psum_tc = psum_tc
174                                     , repr_sel_ty  = sel_ty
175                                     , repr_con_tys = tys
176                                     , repr_cons    = rs
177                                     }
178       where
179         arity = length cons
180
181     con_repr con = liftM (ConRepr con) (prod_repr (dataConRepArgTys con))
182
183     prod_repr []   = return EmptyProd
184     prod_repr [ty] = liftM UnaryProd (comp_repr ty)
185     prod_repr tys  = do
186                        rs <- mapM comp_repr tys
187                        tup_tc <- builtin (prodTyCon arity)
188                        tys'    <- mapM compReprType rs
189                        (ptup_tc, _) <- pdataReprTyCon (mkTyConApp tup_tc tys')
190                        return $ Prod { repr_tup_tc   = tup_tc
191                                      , repr_ptup_tc  = ptup_tc
192                                      , repr_comp_tys = tys'
193                                      , repr_comps    = rs
194                                      }
195       where
196         arity = length tys
197     
198     comp_repr ty = liftM (Keep ty) (prDictOfType ty)
199                    `orElseV` return (Wrap ty)
200
201 sumReprType :: SumRepr -> VM Type
202 sumReprType EmptySum = voidType
203 sumReprType (UnarySum r) = conReprType r
204 sumReprType (Sum { repr_sum_tc  = sum_tc, repr_con_tys = tys })
205   = return $ mkTyConApp sum_tc tys
206
207 conReprType :: ConRepr -> VM Type
208 conReprType (ConRepr _ r) = prodReprType r
209
210 prodReprType :: ProdRepr -> VM Type
211 prodReprType EmptyProd = voidType
212 prodReprType (UnaryProd r) = compReprType r
213 prodReprType (Prod { repr_tup_tc = tup_tc, repr_comp_tys = tys })
214   = return $ mkTyConApp tup_tc tys
215
216 compReprType :: CompRepr -> VM Type
217 compReprType (Keep ty _) = return ty
218 compReprType (Wrap ty) = do
219                              wrap_tc <- builtin wrapTyCon
220                              return $ mkTyConApp wrap_tc [ty]
221
222 compOrigType :: CompRepr -> Type
223 compOrigType (Keep ty _) = ty
224 compOrigType (Wrap ty) = ty
225
226 buildToPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
227 buildToPRepr vect_tc repr_tc _ repr
228   = do
229       let arg_ty = mkTyConApp vect_tc ty_args
230       res_ty <- mkPReprType arg_ty
231       arg    <- newLocalVar (fsLit "x") arg_ty
232       result <- to_sum (Var arg) arg_ty res_ty repr
233       return $ Lam arg result
234   where
235     ty_args = mkTyVarTys (tyConTyVars vect_tc)
236
237     wrap_repr_inst = wrapFamInstBody repr_tc ty_args
238
239     to_sum _ _ _ EmptySum
240       = do
241           void <- builtin voidVar
242           return $ wrap_repr_inst $ Var void
243
244     to_sum arg arg_ty res_ty (UnarySum r)
245       = do
246           (pat, vars, body) <- con_alt r
247           return $ mkWildCase arg arg_ty res_ty
248                    [(pat, vars, wrap_repr_inst body)]
249
250     to_sum arg arg_ty res_ty (Sum { repr_sum_tc  = sum_tc
251                                   , repr_con_tys = tys
252                                   , repr_cons    =  cons })
253       = do
254           alts <- mapM con_alt cons
255           let alts' = [(pat, vars, wrap_repr_inst
256                                    $ mkConApp sum_con (map Type tys ++ [body]))
257                         | ((pat, vars, body), sum_con)
258                             <- zip alts (tyConDataCons sum_tc)]
259           return $ mkWildCase arg arg_ty res_ty alts'
260
261     con_alt (ConRepr con r)
262       = do
263           (vars, body) <- to_prod r
264           return (DataAlt con, vars, body)
265
266     to_prod EmptyProd
267       = do
268           void <- builtin voidVar
269           return ([], Var void)
270
271     to_prod (UnaryProd comp)
272       = do
273           var  <- newLocalVar (fsLit "x") (compOrigType comp)
274           body <- to_comp (Var var) comp
275           return ([var], body)
276
277     to_prod(Prod { repr_tup_tc   = tup_tc
278                  , repr_comp_tys = tys
279                  , repr_comps    = comps })
280       = do
281           vars  <- newLocalVars (fsLit "x") (map compOrigType comps)
282           exprs <- zipWithM to_comp (map Var vars) comps
283           return (vars, mkConApp tup_con (map Type tys ++ exprs))
284       where
285         [tup_con] = tyConDataCons tup_tc
286
287     to_comp expr (Keep _ _) = return expr
288     to_comp expr (Wrap ty)  = do
289                                 wrap_tc <- builtin wrapTyCon
290                                 return $ wrapNewTypeBody wrap_tc [ty] expr
291
292
293 buildFromPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
294 buildFromPRepr vect_tc repr_tc _ repr
295   = do
296       arg_ty <- mkPReprType res_ty
297       arg <- newLocalVar (fsLit "x") arg_ty
298
299       result <- from_sum (unwrapFamInstScrut repr_tc ty_args (Var arg))
300                          repr
301       return $ Lam arg result
302   where
303     ty_args = mkTyVarTys (tyConTyVars vect_tc)
304     res_ty  = mkTyConApp vect_tc ty_args
305
306     from_sum _ EmptySum
307       = do
308           dummy <- builtin fromVoidVar
309           return $ Var dummy `App` Type res_ty
310
311     from_sum expr (UnarySum r) = from_con expr r
312     from_sum expr (Sum { repr_sum_tc  = sum_tc
313                        , repr_con_tys = tys
314                        , repr_cons    = cons })
315       = do
316           vars  <- newLocalVars (fsLit "x") tys
317           es    <- zipWithM from_con (map Var vars) cons
318           return $ mkWildCase expr (exprType expr) res_ty
319                    [(DataAlt con, [var], e)
320                       | (con, var, e) <- zip3 (tyConDataCons sum_tc) vars es]
321
322     from_con expr (ConRepr con r)
323       = from_prod expr (mkConApp con $ map Type ty_args) r
324
325     from_prod _ con EmptyProd = return con
326     from_prod expr con (UnaryProd r)
327       = do
328           e <- from_comp expr r
329           return $ con `App` e
330      
331     from_prod expr con (Prod { repr_tup_tc   = tup_tc
332                              , repr_comp_tys = tys
333                              , repr_comps    = comps
334                              })
335       = do
336           vars <- newLocalVars (fsLit "y") tys
337           es   <- zipWithM from_comp (map Var vars) comps
338           return $ mkWildCase expr (exprType expr) res_ty
339                    [(DataAlt tup_con, vars, con `mkApps` es)]
340       where
341         [tup_con] = tyConDataCons tup_tc  
342
343     from_comp expr (Keep _ _) = return expr
344     from_comp expr (Wrap ty)
345       = do
346           wrap <- builtin wrapTyCon
347           return $ unwrapNewTypeBody wrap [ty] expr
348
349
350 buildToArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
351 buildToArrPRepr vect_tc prepr_tc pdata_tc r
352   = do
353       arg_ty <- mkPDataType el_ty
354       res_ty <- mkPDataType =<< mkPReprType el_ty
355       arg    <- newLocalVar (fsLit "xs") arg_ty
356
357       pdata_co <- mkBuiltinCo pdataTyCon
358       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
359           co           = mkAppCoercion pdata_co
360                        . mkSymCoercion
361                        $ mkTyConApp repr_co ty_args
362
363           scrut   = unwrapFamInstScrut pdata_tc ty_args (Var arg)
364
365       (vars, result) <- to_sum r
366
367       return . Lam arg
368              $ mkWildCase scrut (mkTyConApp pdata_tc ty_args) res_ty
369                [(DataAlt pdata_dc, vars, mkCoerce co result)]
370   where
371     ty_args = mkTyVarTys $ tyConTyVars vect_tc
372     el_ty   = mkTyConApp vect_tc ty_args
373
374     [pdata_dc] = tyConDataCons pdata_tc
375
376
377     to_sum EmptySum = do
378                         pvoid <- builtin pvoidVar
379                         return ([], Var pvoid)
380     to_sum (UnarySum r) = to_con r
381     to_sum (Sum { repr_psum_tc = psum_tc
382                 , repr_sel_ty  = sel_ty
383                 , repr_con_tys = tys
384                 , repr_cons    = cons
385                 })
386       = do
387           (vars, exprs) <- mapAndUnzipM to_con cons
388           sel <- newLocalVar (fsLit "sel") sel_ty
389           return (sel : concat vars, mk_result (Var sel) exprs)
390       where
391         [psum_con] = tyConDataCons psum_tc
392         mk_result sel exprs = wrapFamInstBody psum_tc tys
393                             $ mkConApp psum_con
394                             $ map Type tys ++ (sel : exprs)
395
396     to_con (ConRepr _ r) = to_prod r
397
398     to_prod EmptyProd = do
399                           pvoid <- builtin pvoidVar
400                           return ([], Var pvoid)
401     to_prod (UnaryProd r)
402       = do
403           pty  <- mkPDataType (compOrigType r)
404           var  <- newLocalVar (fsLit "x") pty
405           expr <- to_comp (Var var) r
406           return ([var], expr)
407
408     to_prod (Prod { repr_ptup_tc  = ptup_tc
409                   , repr_comp_tys = tys
410                   , repr_comps    = comps })
411       = do
412           ptys <- mapM (mkPDataType . compOrigType) comps
413           vars <- newLocalVars (fsLit "x") ptys
414           es   <- zipWithM to_comp (map Var vars) comps
415           return (vars, mk_result es)
416       where
417         [ptup_con] = tyConDataCons ptup_tc
418         mk_result exprs = wrapFamInstBody ptup_tc tys
419                         $ mkConApp ptup_con
420                         $ map Type tys ++ exprs
421
422     to_comp expr (Keep _ _) = return expr
423
424     -- FIXME: this is bound to be wrong!
425     to_comp expr (Wrap ty)
426       = do
427           wrap_tc  <- builtin wrapTyCon
428           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
429           return $ wrapNewTypeBody pwrap_tc [ty] expr
430
431
432 buildFromArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
433 buildFromArrPRepr vect_tc prepr_tc pdata_tc r
434   = do
435       arg_ty <- mkPDataType =<< mkPReprType el_ty
436       res_ty <- mkPDataType el_ty
437       arg    <- newLocalVar (fsLit "xs") arg_ty
438
439       pdata_co <- mkBuiltinCo pdataTyCon
440       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
441           co           = mkAppCoercion pdata_co
442                        $ mkTyConApp repr_co var_tys
443
444           scrut  = mkCoerce co (Var arg)
445
446           mk_result args = wrapFamInstBody pdata_tc var_tys
447                          $ mkConApp pdata_con
448                          $ map Type var_tys ++ args
449
450       (expr, _) <- fixV $ \ ~(_, args) ->
451                      from_sum res_ty (mk_result args) scrut r
452
453       return $ Lam arg expr
454     
455       -- (args, mk) <- from_sum res_ty scrut r
456       
457       -- let result = wrapFamInstBody pdata_tc var_tys
458       --           . mkConApp pdata_dc
459       --           $ map Type var_tys ++ args
460
461       -- return $ Lam arg (mk result)
462   where
463     var_tys = mkTyVarTys $ tyConTyVars vect_tc
464     el_ty   = mkTyConApp vect_tc var_tys
465
466     [pdata_con] = tyConDataCons pdata_tc
467
468     from_sum _ res _ EmptySum = return (res, [])
469     from_sum res_ty res expr (UnarySum r) = from_con res_ty res expr r
470     from_sum res_ty res expr (Sum { repr_psum_tc = psum_tc
471                                   , repr_sel_ty  = sel_ty
472                                   , repr_con_tys = tys
473                                   , repr_cons    = cons })
474       = do
475           sel  <- newLocalVar (fsLit "sel") sel_ty
476           ptys <- mapM mkPDataType tys
477           vars <- newLocalVars (fsLit "xs") ptys
478           (res', args) <- fold from_con res_ty res (map Var vars) cons
479           let scrut = unwrapFamInstScrut psum_tc tys expr
480               body  = mkWildCase scrut (exprType scrut) res_ty
481                       [(DataAlt psum_con, sel : vars, res')]
482           return (body, Var sel : args)
483       where
484         [psum_con] = tyConDataCons psum_tc
485
486
487     from_con res_ty res expr (ConRepr _ r) = from_prod res_ty res expr r
488
489     from_prod _ res _ EmptyProd = return (res, [])
490     from_prod res_ty res expr (UnaryProd r)
491       = from_comp res_ty res expr r
492     from_prod res_ty res expr (Prod { repr_ptup_tc  = ptup_tc
493                                     , repr_comp_tys = tys
494                                     , repr_comps    = comps })
495       = do
496           ptys <- mapM mkPDataType tys
497           vars <- newLocalVars (fsLit "ys") ptys
498           (res', args) <- fold from_comp res_ty res (map Var vars) comps
499           let scrut = unwrapFamInstScrut ptup_tc tys expr
500               body  = mkWildCase scrut (exprType scrut) res_ty
501                       [(DataAlt ptup_con, vars, res')]
502           return (body, args)
503       where
504         [ptup_con] = tyConDataCons ptup_tc
505
506     from_comp _ res expr (Keep _ _) = return (res, [expr])
507     from_comp _ res expr (Wrap ty)
508       = do
509           wrap_tc  <- builtin wrapTyCon
510           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
511           return (res, [unwrapNewTypeBody pwrap_tc [ty]
512                         $ unwrapFamInstScrut pwrap_tc [ty] expr])
513
514     fold f res_ty res exprs rs = foldrM f' (res, []) (zip exprs rs)
515       where
516         f' (expr, r) (res, args) = do
517                                      (res', args') <- f res_ty res expr r
518                                      return (res', args' ++ args)
519
520 buildPRDict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
521 buildPRDict vect_tc prepr_tc _ r
522   = do
523       dict <- sum_dict r
524       pr_co <- mkBuiltinCo prTyCon
525       let co = mkAppCoercion pr_co
526              . mkSymCoercion
527              $ mkTyConApp arg_co ty_args
528       return (mkCoerce co dict)
529   where
530     ty_args = mkTyVarTys (tyConTyVars vect_tc)
531     Just arg_co = tyConFamilyCoercion_maybe prepr_tc
532
533     sum_dict EmptySum = prDFunOfTyCon =<< builtin voidTyCon
534     sum_dict (UnarySum r) = con_dict r
535     sum_dict (Sum { repr_sum_tc  = sum_tc
536                   , repr_con_tys = tys
537                   , repr_cons    = cons
538                   })
539       = do
540           dicts <- mapM con_dict cons
541           dfun  <- prDFunOfTyCon sum_tc
542           return $ dfun `mkTyApps` tys `mkApps` dicts
543
544     con_dict (ConRepr _ r) = prod_dict r
545
546     prod_dict EmptyProd = prDFunOfTyCon =<< builtin voidTyCon
547     prod_dict (UnaryProd r) = comp_dict r
548     prod_dict (Prod { repr_tup_tc   = tup_tc
549                     , repr_comp_tys = tys
550                     , repr_comps    = comps })
551       = do
552           dicts <- mapM comp_dict comps
553           dfun <- prDFunOfTyCon tup_tc
554           return $ dfun `mkTyApps` tys `mkApps` dicts
555
556     comp_dict (Keep _ pr) = return pr
557     comp_dict (Wrap ty)   = wrapPR ty
558
559
560 buildPDataTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
561 buildPDataTyCon orig_tc vect_tc repr = fixV $ \repr_tc ->
562   do
563     name' <- cloneName mkPDataTyConOcc orig_name
564     rhs   <- buildPDataTyConRhs orig_name vect_tc repr_tc repr
565     pdata <- builtin pdataTyCon
566
567     liftDs $ buildAlgTyCon name'
568                            tyvars
569                            []          -- no stupid theta
570                            rhs
571                            rec_flag    -- FIXME: is this ok?
572                            False       -- FIXME: no generics
573                            False       -- not GADT syntax
574                            (Just $ mk_fam_inst pdata vect_tc)
575   where
576     orig_name = tyConName orig_tc
577     tyvars = tyConTyVars vect_tc
578     rec_flag = boolToRecFlag (isRecursiveTyCon vect_tc)
579
580
581 buildPDataTyConRhs :: Name -> TyCon -> TyCon -> SumRepr -> VM AlgTyConRhs
582 buildPDataTyConRhs orig_name vect_tc repr_tc repr
583   = do
584       data_con <- buildPDataDataCon orig_name vect_tc repr_tc repr
585       return $ DataTyCon { data_cons = [data_con], is_enum = False }
586
587 buildPDataDataCon :: Name -> TyCon -> TyCon -> SumRepr -> VM DataCon
588 buildPDataDataCon orig_name vect_tc repr_tc repr
589   = do
590       dc_name  <- cloneName mkPDataDataConOcc orig_name
591       comp_tys <- sum_tys repr
592
593       liftDs $ buildDataCon dc_name
594                             False                  -- not infix
595                             (map (const HsNoBang) comp_tys)
596                             []                     -- no field labels
597                             tvs
598                             []                     -- no existentials
599                             []                     -- no eq spec
600                             []                     -- no context
601                             comp_tys
602                             (mkFamilyTyConApp repr_tc (mkTyVarTys tvs))
603                             repr_tc
604   where
605     tvs   = tyConTyVars vect_tc
606
607     sum_tys EmptySum = return []
608     sum_tys (UnarySum r) = con_tys r
609     sum_tys (Sum { repr_sel_ty = sel_ty
610                  , repr_cons   = cons })
611       = liftM (sel_ty :) (concatMapM con_tys cons)
612
613     con_tys (ConRepr _ r) = prod_tys r
614
615     prod_tys EmptyProd = return []
616     prod_tys (UnaryProd r) = liftM singleton (comp_ty r)
617     prod_tys (Prod { repr_comps = comps }) = mapM comp_ty comps
618
619     comp_ty r = mkPDataType (compOrigType r)
620
621
622 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr 
623                    -> VM Var
624 buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr
625   = do
626       vectDataConWorkers orig_tc vect_tc pdata_tc
627       buildPADict vect_tc prepr_tc pdata_tc repr
628
629 vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
630 vectDataConWorkers orig_tc vect_tc arr_tc
631   = do
632       bs <- sequence
633           . zipWith3 def_worker  (tyConDataCons orig_tc) rep_tys
634           $ zipWith4 mk_data_con (tyConDataCons vect_tc)
635                                  rep_tys
636                                  (inits rep_tys)
637                                  (tail $ tails rep_tys)
638       mapM_ (uncurry hoistBinding) bs
639   where
640     tyvars   = tyConTyVars vect_tc
641     var_tys  = mkTyVarTys tyvars
642     ty_args  = map Type var_tys
643     res_ty   = mkTyConApp vect_tc var_tys
644
645     cons     = tyConDataCons vect_tc
646     arity    = length cons
647     [arr_dc] = tyConDataCons arr_tc
648
649     rep_tys  = map dataConRepArgTys $ tyConDataCons vect_tc
650
651
652     mk_data_con con tys pre post
653       = liftM2 (,) (vect_data_con con)
654                    (lift_data_con tys pre post (mkDataConTag con))
655
656     sel_replicate len tag
657       | arity > 1 = do
658                       rep <- builtin (selReplicate arity)
659                       return [rep `mkApps` [len, tag]]
660
661       | otherwise = return []
662
663     vect_data_con con = return $ mkConApp con ty_args
664     lift_data_con tys pre_tys post_tys tag
665       = do
666           len  <- builtin liftingContext
667           args <- mapM (newLocalVar (fsLit "xs"))
668                   =<< mapM mkPDataType tys
669
670           sel  <- sel_replicate (Var len) tag
671
672           pre   <- mapM emptyPD (concat pre_tys)
673           post  <- mapM emptyPD (concat post_tys)
674
675           return . mkLams (len : args)
676                  . wrapFamInstBody arr_tc var_tys
677                  . mkConApp arr_dc
678                  $ ty_args ++ sel ++ pre ++ map Var args ++ post
679
680     def_worker data_con arg_tys mk_body
681       = do
682           arity <- polyArity tyvars
683           body <- closedV
684                 . inBind orig_worker
685                 . polyAbstract tyvars $ \args ->
686                   liftM (mkLams (tyvars ++ args) . vectorised)
687                 $ buildClosures tyvars [] arg_tys res_ty mk_body
688
689           raw_worker <- cloneId mkVectOcc orig_worker (exprType body)
690           let vect_worker = raw_worker `setIdUnfolding`
691                               mkInlineRule body (Just arity)
692           defGlobalVar orig_worker vect_worker
693           return (vect_worker, body)
694       where
695         orig_worker = dataConWorkId data_con
696
697 buildPADict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
698 buildPADict vect_tc prepr_tc arr_tc repr
699   = polyAbstract tvs $ \args ->
700     do
701       method_ids <- mapM (method args) paMethods
702
703       pa_tc  <- builtin paTyCon
704       pa_dc  <- builtin paDataCon
705       let dict = mkLams (tvs ++ args)
706                $ mkConApp pa_dc
707                $ Type inst_ty : map (method_call args) method_ids
708
709           dfun_ty = mkForAllTys tvs
710                   $ mkFunTys (map varType args) (mkTyConApp pa_tc [inst_ty])
711
712       raw_dfun <- newExportedVar dfun_name dfun_ty
713       let dfun = raw_dfun `setIdUnfolding` mkDFunUnfolding dfun_ty (map Var method_ids)
714                           `setInlinePragma` dfunInlinePragma
715
716       hoistBinding dfun dict
717       return dfun
718   where
719     tvs = tyConTyVars vect_tc
720     arg_tys = mkTyVarTys tvs
721     inst_ty = mkTyConApp vect_tc arg_tys
722
723     dfun_name = mkPADFunOcc (getOccName vect_tc)
724
725     method args (name, build)
726       = localV
727       $ do
728           expr <- build vect_tc prepr_tc arr_tc repr
729           let body = mkLams (tvs ++ args) expr
730           raw_var <- newExportedVar (method_name name) (exprType body)
731           let var = raw_var
732                       `setIdUnfolding` mkInlineRule body (Just (length args))
733                       `setInlinePragma` alwaysInlinePragma
734           hoistBinding var body
735           return var
736
737     method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args)
738
739     method_name name = mkVarOcc $ occNameString dfun_name ++ ('$' : name)
740
741
742 paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
743 paMethods = [("dictPRepr",    buildPRDict),
744              ("toPRepr",      buildToPRepr),
745              ("fromPRepr",    buildFromPRepr),
746              ("toArrPRepr",   buildToArrPRepr),
747              ("fromArrPRepr", buildFromArrPRepr)]
748
749
750 -- ----------------------------------------------------------------------------
751 -- Conversions
752
753 -- | Build an expression that calls the vectorised version of some 
754 --   function from a `Closure`.
755 --
756 --   For example
757 --   @   
758 --      \(x :: Double) -> 
759 --      \(y :: Double) -> 
760 --      ($v_foo $: x) $: y
761 --   @
762 --
763 --   We use the type of the original binding to work out how many
764 --   outer lambdas to add.
765 --
766 fromVect 
767         :: Type         -- ^ The type of the original binding.
768         -> CoreExpr     -- ^ Expression giving the closure to use, eg @$v_foo@.
769         -> VM CoreExpr
770         
771 -- Convert the type to the core view if it isn't already.
772 fromVect ty expr 
773         | Just ty' <- coreView ty 
774         = fromVect ty' expr
775
776 -- For each function constructor in the original type we add an outer 
777 -- lambda to bind the parameter variable, and an inner application of it.
778 fromVect (FunTy arg_ty res_ty) expr
779   = do
780       arg     <- newLocalVar (fsLit "x") arg_ty
781       varg    <- toVect arg_ty (Var arg)
782       varg_ty <- vectType arg_ty
783       vres_ty <- vectType res_ty
784       apply   <- builtin applyVar
785       body    <- fromVect res_ty
786                $ Var apply `mkTyApps` [varg_ty, vres_ty] `mkApps` [expr, varg]
787       return $ Lam arg body
788
789 -- If the type isn't a function then it's time to call on the closure.
790 fromVect ty expr
791   = identityConv ty >> return expr
792
793
794 -- TODO: What is this really doing?
795 toVect :: Type -> CoreExpr -> VM CoreExpr
796 toVect ty expr = identityConv ty >> return expr
797
798
799 -- | Check that we have the vectorised versions of all the
800 --   type constructors in this type.
801 identityConv :: Type -> VM ()
802 identityConv ty 
803   | Just ty' <- coreView ty 
804   = identityConv ty'
805
806 identityConv (TyConApp tycon tys)
807  = do mapM_ identityConv tys
808       identityConvTyCon tycon
809
810 identityConv _ = noV
811
812
813 -- | Check that we have the vectorised version of this type constructor.
814 identityConvTyCon :: TyCon -> VM ()
815 identityConvTyCon tc
816   | isBoxedTupleTyCon tc = return ()
817   | isUnLiftedTyCon tc   = return ()
818   | otherwise 
819   = do tc' <- maybeV (lookupTyCon tc)
820        if tc == tc' then return () else noV
821
822