046acb9db58f018aa468aba6a3594177bfd2ad46
[ghc-hetmet.git] / compiler / vectorise / VectType.hs
1 {-# OPTIONS -fno-warn-missing-signatures #-}
2
3 module VectType ( vectTyCon, vectAndLiftType, vectType, vectTypeEnv,
4                   -- arrSumArity, pdataCompTys, pdataCompVars,
5                   buildPADict,
6                   fromVect )
7 where
8
9 import VectUtils
10 import Vectorise.Env
11 import Vectorise.Vect
12 import Vectorise.Monad
13 import Vectorise.Builtins
14 import Vectorise.Type.Type
15 import Vectorise.Type.TyConDecl
16 import Vectorise.Type.Classify
17 import Vectorise.Utils.Closure
18 import Vectorise.Utils.Hoisting
19
20 import HscTypes          ( TypeEnv, extendTypeEnvList, typeEnvTyCons )
21 import BasicTypes
22 import CoreSyn
23 import CoreUtils
24 import CoreUnfold
25 import MkCore            ( mkWildCase )
26 import BuildTyCl
27 import DataCon
28 import TyCon
29 import Type
30 import TypeRep
31 import Coercion
32 import FamInstEnv        ( FamInst, mkLocalFamInst )
33 import OccName
34 import Id
35 import MkId
36 import Var
37 import Name              ( Name, getOccName )
38 import NameEnv
39
40 import Unique
41 import UniqFM
42 import Util
43
44 import Outputable
45 import FastString
46
47 import MonadUtils     ( zipWith3M, foldrM, concatMapM )
48 import Control.Monad  ( liftM, liftM2, zipWithM, zipWithM_, mapAndUnzipM )
49 import Data.List
50
51 debug           = False
52 dtrace s x      = if debug then pprTrace "VectType" s x else x
53
54
55 -- ----------------------------------------------------------------------------
56 -- Type definitions
57
58
59 -- | Vectorise a type environment.
60 --   The type environment contains all the type things defined in a module.
61 vectTypeEnv :: TypeEnv -> VM (TypeEnv, [FamInst], [(Var, CoreExpr)])
62 vectTypeEnv env
63  = dtrace (ppr env)
64  $ do
65       cs <- readGEnv $ mk_map . global_tycons
66
67       -- Split the list of TyCons into the ones we have to vectorise vs the
68       -- ones we can pass through unchanged. We also pass through algebraic 
69       -- types that use non Haskell98 features, as we don't handle those.
70       let (conv_tcs, keep_tcs) = classifyTyCons cs groups
71           keep_dcs             = concatMap tyConDataCons keep_tcs
72
73       zipWithM_ defTyCon   keep_tcs keep_tcs
74       zipWithM_ defDataCon keep_dcs keep_dcs
75
76       new_tcs <- vectTyConDecls conv_tcs
77
78       let orig_tcs = keep_tcs ++ conv_tcs
79
80       -- We don't need to make new representation types for dictionary
81       -- constructors. The constructors are always fully applied, and we don't 
82       -- need to lift them to arrays as a dictionary of a particular type
83       -- always has the same value.
84       let vect_tcs = filter (not . isClassTyCon) 
85                    $ keep_tcs ++ new_tcs
86
87       (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
88         do
89           defTyConPAs (zipLazy vect_tcs dfuns')
90           reprs     <- mapM tyConRepr vect_tcs
91           repr_tcs  <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
92           pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
93
94           dfuns     <- sequence 
95                     $  zipWith5 buildTyConBindings
96                                orig_tcs
97                                vect_tcs
98                                repr_tcs
99                                pdata_tcs
100                                reprs
101
102           binds     <- takeHoisted
103           return (dfuns, binds, repr_tcs ++ pdata_tcs)
104
105       let all_new_tcs = new_tcs ++ inst_tcs
106
107       let new_env = extendTypeEnvList env
108                        (map ATyCon all_new_tcs
109                         ++ [ADataCon dc | tc <- all_new_tcs
110                                         , dc <- tyConDataCons tc])
111
112       return (new_env, map mkLocalFamInst inst_tcs, binds)
113   where
114     tycons = typeEnvTyCons env
115     groups = tyConGroups tycons
116
117     mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
118
119
120 mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
121 mk_fam_inst fam_tc arg_tc
122   = (fam_tc, [mkTyConApp arg_tc . mkTyVarTys $ tyConTyVars arg_tc])
123
124
125 buildPReprTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
126 buildPReprTyCon orig_tc vect_tc repr
127   = do
128       name     <- cloneName mkPReprTyConOcc (tyConName orig_tc)
129       -- rhs_ty   <- buildPReprType vect_tc
130       rhs_ty   <- sumReprType repr
131       prepr_tc <- builtin preprTyCon
132       liftDs $ buildSynTyCon name
133                              tyvars
134                              (SynonymTyCon rhs_ty)
135                              (typeKind rhs_ty)
136                              (Just $ mk_fam_inst prepr_tc vect_tc)
137   where
138     tyvars = tyConTyVars vect_tc
139
140 data CompRepr = Keep Type
141                      CoreExpr     -- PR dictionary for the type
142               | Wrap Type
143
144 data ProdRepr = EmptyProd
145               | UnaryProd CompRepr
146               | Prod { repr_tup_tc   :: TyCon  -- representation tuple tycon
147                      , repr_ptup_tc  :: TyCon  -- PData representation tycon
148                      , repr_comp_tys :: [Type] -- representation types of
149                      , repr_comps    :: [CompRepr]          -- components
150                      }
151 data ConRepr  = ConRepr DataCon ProdRepr
152
153 data SumRepr  = EmptySum
154               | UnarySum ConRepr
155               | Sum  { repr_sum_tc   :: TyCon  -- representation sum tycon
156                      , repr_psum_tc  :: TyCon  -- PData representation tycon
157                      , repr_sel_ty   :: Type   -- type of selector
158                      , repr_con_tys :: [Type]  -- representation types of
159                      , repr_cons     :: [ConRepr]           -- components
160                      }
161
162 tyConRepr :: TyCon -> VM SumRepr
163 tyConRepr tc = sum_repr (tyConDataCons tc)
164   where
165     sum_repr []    = return EmptySum
166     sum_repr [con] = liftM UnarySum (con_repr con)
167     sum_repr cons  = do
168                        rs     <- mapM con_repr cons
169                        sum_tc <- builtin (sumTyCon arity)
170                        tys    <- mapM conReprType rs
171                        (psum_tc, _) <- pdataReprTyCon (mkTyConApp sum_tc tys)
172                        sel_ty <- builtin (selTy arity)
173                        return $ Sum { repr_sum_tc  = sum_tc
174                                     , repr_psum_tc = psum_tc
175                                     , repr_sel_ty  = sel_ty
176                                     , repr_con_tys = tys
177                                     , repr_cons    = rs
178                                     }
179       where
180         arity = length cons
181
182     con_repr con = liftM (ConRepr con) (prod_repr (dataConRepArgTys con))
183
184     prod_repr []   = return EmptyProd
185     prod_repr [ty] = liftM UnaryProd (comp_repr ty)
186     prod_repr tys  = do
187                        rs <- mapM comp_repr tys
188                        tup_tc <- builtin (prodTyCon arity)
189                        tys'    <- mapM compReprType rs
190                        (ptup_tc, _) <- pdataReprTyCon (mkTyConApp tup_tc tys')
191                        return $ Prod { repr_tup_tc   = tup_tc
192                                      , repr_ptup_tc  = ptup_tc
193                                      , repr_comp_tys = tys'
194                                      , repr_comps    = rs
195                                      }
196       where
197         arity = length tys
198     
199     comp_repr ty = liftM (Keep ty) (prDictOfType ty)
200                    `orElseV` return (Wrap ty)
201
202 sumReprType :: SumRepr -> VM Type
203 sumReprType EmptySum = voidType
204 sumReprType (UnarySum r) = conReprType r
205 sumReprType (Sum { repr_sum_tc  = sum_tc, repr_con_tys = tys })
206   = return $ mkTyConApp sum_tc tys
207
208 conReprType :: ConRepr -> VM Type
209 conReprType (ConRepr _ r) = prodReprType r
210
211 prodReprType :: ProdRepr -> VM Type
212 prodReprType EmptyProd = voidType
213 prodReprType (UnaryProd r) = compReprType r
214 prodReprType (Prod { repr_tup_tc = tup_tc, repr_comp_tys = tys })
215   = return $ mkTyConApp tup_tc tys
216
217 compReprType :: CompRepr -> VM Type
218 compReprType (Keep ty _) = return ty
219 compReprType (Wrap ty) = do
220                              wrap_tc <- builtin wrapTyCon
221                              return $ mkTyConApp wrap_tc [ty]
222
223 compOrigType :: CompRepr -> Type
224 compOrigType (Keep ty _) = ty
225 compOrigType (Wrap ty) = ty
226
227 buildToPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
228 buildToPRepr vect_tc repr_tc _ repr
229   = do
230       let arg_ty = mkTyConApp vect_tc ty_args
231       res_ty <- mkPReprType arg_ty
232       arg    <- newLocalVar (fsLit "x") arg_ty
233       result <- to_sum (Var arg) arg_ty res_ty repr
234       return $ Lam arg result
235   where
236     ty_args = mkTyVarTys (tyConTyVars vect_tc)
237
238     wrap_repr_inst = wrapFamInstBody repr_tc ty_args
239
240     to_sum _ _ _ EmptySum
241       = do
242           void <- builtin voidVar
243           return $ wrap_repr_inst $ Var void
244
245     to_sum arg arg_ty res_ty (UnarySum r)
246       = do
247           (pat, vars, body) <- con_alt r
248           return $ mkWildCase arg arg_ty res_ty
249                    [(pat, vars, wrap_repr_inst body)]
250
251     to_sum arg arg_ty res_ty (Sum { repr_sum_tc  = sum_tc
252                                   , repr_con_tys = tys
253                                   , repr_cons    =  cons })
254       = do
255           alts <- mapM con_alt cons
256           let alts' = [(pat, vars, wrap_repr_inst
257                                    $ mkConApp sum_con (map Type tys ++ [body]))
258                         | ((pat, vars, body), sum_con)
259                             <- zip alts (tyConDataCons sum_tc)]
260           return $ mkWildCase arg arg_ty res_ty alts'
261
262     con_alt (ConRepr con r)
263       = do
264           (vars, body) <- to_prod r
265           return (DataAlt con, vars, body)
266
267     to_prod EmptyProd
268       = do
269           void <- builtin voidVar
270           return ([], Var void)
271
272     to_prod (UnaryProd comp)
273       = do
274           var  <- newLocalVar (fsLit "x") (compOrigType comp)
275           body <- to_comp (Var var) comp
276           return ([var], body)
277
278     to_prod(Prod { repr_tup_tc   = tup_tc
279                  , repr_comp_tys = tys
280                  , repr_comps    = comps })
281       = do
282           vars  <- newLocalVars (fsLit "x") (map compOrigType comps)
283           exprs <- zipWithM to_comp (map Var vars) comps
284           return (vars, mkConApp tup_con (map Type tys ++ exprs))
285       where
286         [tup_con] = tyConDataCons tup_tc
287
288     to_comp expr (Keep _ _) = return expr
289     to_comp expr (Wrap ty)  = do
290                                 wrap_tc <- builtin wrapTyCon
291                                 return $ wrapNewTypeBody wrap_tc [ty] expr
292
293
294 buildFromPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
295 buildFromPRepr vect_tc repr_tc _ repr
296   = do
297       arg_ty <- mkPReprType res_ty
298       arg <- newLocalVar (fsLit "x") arg_ty
299
300       result <- from_sum (unwrapFamInstScrut repr_tc ty_args (Var arg))
301                          repr
302       return $ Lam arg result
303   where
304     ty_args = mkTyVarTys (tyConTyVars vect_tc)
305     res_ty  = mkTyConApp vect_tc ty_args
306
307     from_sum _ EmptySum
308       = do
309           dummy <- builtin fromVoidVar
310           return $ Var dummy `App` Type res_ty
311
312     from_sum expr (UnarySum r) = from_con expr r
313     from_sum expr (Sum { repr_sum_tc  = sum_tc
314                        , repr_con_tys = tys
315                        , repr_cons    = cons })
316       = do
317           vars  <- newLocalVars (fsLit "x") tys
318           es    <- zipWithM from_con (map Var vars) cons
319           return $ mkWildCase expr (exprType expr) res_ty
320                    [(DataAlt con, [var], e)
321                       | (con, var, e) <- zip3 (tyConDataCons sum_tc) vars es]
322
323     from_con expr (ConRepr con r)
324       = from_prod expr (mkConApp con $ map Type ty_args) r
325
326     from_prod _ con EmptyProd = return con
327     from_prod expr con (UnaryProd r)
328       = do
329           e <- from_comp expr r
330           return $ con `App` e
331      
332     from_prod expr con (Prod { repr_tup_tc   = tup_tc
333                              , repr_comp_tys = tys
334                              , repr_comps    = comps
335                              })
336       = do
337           vars <- newLocalVars (fsLit "y") tys
338           es   <- zipWithM from_comp (map Var vars) comps
339           return $ mkWildCase expr (exprType expr) res_ty
340                    [(DataAlt tup_con, vars, con `mkApps` es)]
341       where
342         [tup_con] = tyConDataCons tup_tc  
343
344     from_comp expr (Keep _ _) = return expr
345     from_comp expr (Wrap ty)
346       = do
347           wrap <- builtin wrapTyCon
348           return $ unwrapNewTypeBody wrap [ty] expr
349
350
351 buildToArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
352 buildToArrPRepr vect_tc prepr_tc pdata_tc r
353   = do
354       arg_ty <- mkPDataType el_ty
355       res_ty <- mkPDataType =<< mkPReprType el_ty
356       arg    <- newLocalVar (fsLit "xs") arg_ty
357
358       pdata_co <- mkBuiltinCo pdataTyCon
359       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
360           co           = mkAppCoercion pdata_co
361                        . mkSymCoercion
362                        $ mkTyConApp repr_co ty_args
363
364           scrut   = unwrapFamInstScrut pdata_tc ty_args (Var arg)
365
366       (vars, result) <- to_sum r
367
368       return . Lam arg
369              $ mkWildCase scrut (mkTyConApp pdata_tc ty_args) res_ty
370                [(DataAlt pdata_dc, vars, mkCoerce co result)]
371   where
372     ty_args = mkTyVarTys $ tyConTyVars vect_tc
373     el_ty   = mkTyConApp vect_tc ty_args
374
375     [pdata_dc] = tyConDataCons pdata_tc
376
377
378     to_sum EmptySum = do
379                         pvoid <- builtin pvoidVar
380                         return ([], Var pvoid)
381     to_sum (UnarySum r) = to_con r
382     to_sum (Sum { repr_psum_tc = psum_tc
383                 , repr_sel_ty  = sel_ty
384                 , repr_con_tys = tys
385                 , repr_cons    = cons
386                 })
387       = do
388           (vars, exprs) <- mapAndUnzipM to_con cons
389           sel <- newLocalVar (fsLit "sel") sel_ty
390           return (sel : concat vars, mk_result (Var sel) exprs)
391       where
392         [psum_con] = tyConDataCons psum_tc
393         mk_result sel exprs = wrapFamInstBody psum_tc tys
394                             $ mkConApp psum_con
395                             $ map Type tys ++ (sel : exprs)
396
397     to_con (ConRepr _ r) = to_prod r
398
399     to_prod EmptyProd = do
400                           pvoid <- builtin pvoidVar
401                           return ([], Var pvoid)
402     to_prod (UnaryProd r)
403       = do
404           pty  <- mkPDataType (compOrigType r)
405           var  <- newLocalVar (fsLit "x") pty
406           expr <- to_comp (Var var) r
407           return ([var], expr)
408
409     to_prod (Prod { repr_ptup_tc  = ptup_tc
410                   , repr_comp_tys = tys
411                   , repr_comps    = comps })
412       = do
413           ptys <- mapM (mkPDataType . compOrigType) comps
414           vars <- newLocalVars (fsLit "x") ptys
415           es   <- zipWithM to_comp (map Var vars) comps
416           return (vars, mk_result es)
417       where
418         [ptup_con] = tyConDataCons ptup_tc
419         mk_result exprs = wrapFamInstBody ptup_tc tys
420                         $ mkConApp ptup_con
421                         $ map Type tys ++ exprs
422
423     to_comp expr (Keep _ _) = return expr
424
425     -- FIXME: this is bound to be wrong!
426     to_comp expr (Wrap ty)
427       = do
428           wrap_tc  <- builtin wrapTyCon
429           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
430           return $ wrapNewTypeBody pwrap_tc [ty] expr
431
432
433 buildFromArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
434 buildFromArrPRepr vect_tc prepr_tc pdata_tc r
435   = do
436       arg_ty <- mkPDataType =<< mkPReprType el_ty
437       res_ty <- mkPDataType el_ty
438       arg    <- newLocalVar (fsLit "xs") arg_ty
439
440       pdata_co <- mkBuiltinCo pdataTyCon
441       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
442           co           = mkAppCoercion pdata_co
443                        $ mkTyConApp repr_co var_tys
444
445           scrut  = mkCoerce co (Var arg)
446
447           mk_result args = wrapFamInstBody pdata_tc var_tys
448                          $ mkConApp pdata_con
449                          $ map Type var_tys ++ args
450
451       (expr, _) <- fixV $ \ ~(_, args) ->
452                      from_sum res_ty (mk_result args) scrut r
453
454       return $ Lam arg expr
455     
456       -- (args, mk) <- from_sum res_ty scrut r
457       
458       -- let result = wrapFamInstBody pdata_tc var_tys
459       --           . mkConApp pdata_dc
460       --           $ map Type var_tys ++ args
461
462       -- return $ Lam arg (mk result)
463   where
464     var_tys = mkTyVarTys $ tyConTyVars vect_tc
465     el_ty   = mkTyConApp vect_tc var_tys
466
467     [pdata_con] = tyConDataCons pdata_tc
468
469     from_sum _ res _ EmptySum = return (res, [])
470     from_sum res_ty res expr (UnarySum r) = from_con res_ty res expr r
471     from_sum res_ty res expr (Sum { repr_psum_tc = psum_tc
472                                   , repr_sel_ty  = sel_ty
473                                   , repr_con_tys = tys
474                                   , repr_cons    = cons })
475       = do
476           sel  <- newLocalVar (fsLit "sel") sel_ty
477           ptys <- mapM mkPDataType tys
478           vars <- newLocalVars (fsLit "xs") ptys
479           (res', args) <- fold from_con res_ty res (map Var vars) cons
480           let scrut = unwrapFamInstScrut psum_tc tys expr
481               body  = mkWildCase scrut (exprType scrut) res_ty
482                       [(DataAlt psum_con, sel : vars, res')]
483           return (body, Var sel : args)
484       where
485         [psum_con] = tyConDataCons psum_tc
486
487
488     from_con res_ty res expr (ConRepr _ r) = from_prod res_ty res expr r
489
490     from_prod _ res _ EmptyProd = return (res, [])
491     from_prod res_ty res expr (UnaryProd r)
492       = from_comp res_ty res expr r
493     from_prod res_ty res expr (Prod { repr_ptup_tc  = ptup_tc
494                                     , repr_comp_tys = tys
495                                     , repr_comps    = comps })
496       = do
497           ptys <- mapM mkPDataType tys
498           vars <- newLocalVars (fsLit "ys") ptys
499           (res', args) <- fold from_comp res_ty res (map Var vars) comps
500           let scrut = unwrapFamInstScrut ptup_tc tys expr
501               body  = mkWildCase scrut (exprType scrut) res_ty
502                       [(DataAlt ptup_con, vars, res')]
503           return (body, args)
504       where
505         [ptup_con] = tyConDataCons ptup_tc
506
507     from_comp _ res expr (Keep _ _) = return (res, [expr])
508     from_comp _ res expr (Wrap ty)
509       = do
510           wrap_tc  <- builtin wrapTyCon
511           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
512           return (res, [unwrapNewTypeBody pwrap_tc [ty]
513                         $ unwrapFamInstScrut pwrap_tc [ty] expr])
514
515     fold f res_ty res exprs rs = foldrM f' (res, []) (zip exprs rs)
516       where
517         f' (expr, r) (res, args) = do
518                                      (res', args') <- f res_ty res expr r
519                                      return (res', args' ++ args)
520
521 buildPRDict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
522 buildPRDict vect_tc prepr_tc _ r
523   = do
524       dict <- sum_dict r
525       pr_co <- mkBuiltinCo prTyCon
526       let co = mkAppCoercion pr_co
527              . mkSymCoercion
528              $ mkTyConApp arg_co ty_args
529       return (mkCoerce co dict)
530   where
531     ty_args = mkTyVarTys (tyConTyVars vect_tc)
532     Just arg_co = tyConFamilyCoercion_maybe prepr_tc
533
534     sum_dict EmptySum = prDFunOfTyCon =<< builtin voidTyCon
535     sum_dict (UnarySum r) = con_dict r
536     sum_dict (Sum { repr_sum_tc  = sum_tc
537                   , repr_con_tys = tys
538                   , repr_cons    = cons
539                   })
540       = do
541           dicts <- mapM con_dict cons
542           dfun  <- prDFunOfTyCon sum_tc
543           return $ dfun `mkTyApps` tys `mkApps` dicts
544
545     con_dict (ConRepr _ r) = prod_dict r
546
547     prod_dict EmptyProd = prDFunOfTyCon =<< builtin voidTyCon
548     prod_dict (UnaryProd r) = comp_dict r
549     prod_dict (Prod { repr_tup_tc   = tup_tc
550                     , repr_comp_tys = tys
551                     , repr_comps    = comps })
552       = do
553           dicts <- mapM comp_dict comps
554           dfun <- prDFunOfTyCon tup_tc
555           return $ dfun `mkTyApps` tys `mkApps` dicts
556
557     comp_dict (Keep _ pr) = return pr
558     comp_dict (Wrap ty)   = wrapPR ty
559
560
561 buildPDataTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
562 buildPDataTyCon orig_tc vect_tc repr = fixV $ \repr_tc ->
563   do
564     name' <- cloneName mkPDataTyConOcc orig_name
565     rhs   <- buildPDataTyConRhs orig_name vect_tc repr_tc repr
566     pdata <- builtin pdataTyCon
567
568     liftDs $ buildAlgTyCon name'
569                            tyvars
570                            []          -- no stupid theta
571                            rhs
572                            rec_flag    -- FIXME: is this ok?
573                            False       -- FIXME: no generics
574                            False       -- not GADT syntax
575                            (Just $ mk_fam_inst pdata vect_tc)
576   where
577     orig_name = tyConName orig_tc
578     tyvars = tyConTyVars vect_tc
579     rec_flag = boolToRecFlag (isRecursiveTyCon vect_tc)
580
581
582 buildPDataTyConRhs :: Name -> TyCon -> TyCon -> SumRepr -> VM AlgTyConRhs
583 buildPDataTyConRhs orig_name vect_tc repr_tc repr
584   = do
585       data_con <- buildPDataDataCon orig_name vect_tc repr_tc repr
586       return $ DataTyCon { data_cons = [data_con], is_enum = False }
587
588 buildPDataDataCon :: Name -> TyCon -> TyCon -> SumRepr -> VM DataCon
589 buildPDataDataCon orig_name vect_tc repr_tc repr
590   = do
591       dc_name  <- cloneName mkPDataDataConOcc orig_name
592       comp_tys <- sum_tys repr
593
594       liftDs $ buildDataCon dc_name
595                             False                  -- not infix
596                             (map (const HsNoBang) comp_tys)
597                             []                     -- no field labels
598                             tvs
599                             []                     -- no existentials
600                             []                     -- no eq spec
601                             []                     -- no context
602                             comp_tys
603                             (mkFamilyTyConApp repr_tc (mkTyVarTys tvs))
604                             repr_tc
605   where
606     tvs   = tyConTyVars vect_tc
607
608     sum_tys EmptySum = return []
609     sum_tys (UnarySum r) = con_tys r
610     sum_tys (Sum { repr_sel_ty = sel_ty
611                  , repr_cons   = cons })
612       = liftM (sel_ty :) (concatMapM con_tys cons)
613
614     con_tys (ConRepr _ r) = prod_tys r
615
616     prod_tys EmptyProd = return []
617     prod_tys (UnaryProd r) = liftM singleton (comp_ty r)
618     prod_tys (Prod { repr_comps = comps }) = mapM comp_ty comps
619
620     comp_ty r = mkPDataType (compOrigType r)
621
622
623 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr 
624                    -> VM Var
625 buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr
626   = do
627       vectDataConWorkers orig_tc vect_tc pdata_tc
628       buildPADict vect_tc prepr_tc pdata_tc repr
629
630 vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
631 vectDataConWorkers orig_tc vect_tc arr_tc
632   = do
633       bs <- sequence
634           . zipWith3 def_worker  (tyConDataCons orig_tc) rep_tys
635           $ zipWith4 mk_data_con (tyConDataCons vect_tc)
636                                  rep_tys
637                                  (inits rep_tys)
638                                  (tail $ tails rep_tys)
639       mapM_ (uncurry hoistBinding) bs
640   where
641     tyvars   = tyConTyVars vect_tc
642     var_tys  = mkTyVarTys tyvars
643     ty_args  = map Type var_tys
644     res_ty   = mkTyConApp vect_tc var_tys
645
646     cons     = tyConDataCons vect_tc
647     arity    = length cons
648     [arr_dc] = tyConDataCons arr_tc
649
650     rep_tys  = map dataConRepArgTys $ tyConDataCons vect_tc
651
652
653     mk_data_con con tys pre post
654       = liftM2 (,) (vect_data_con con)
655                    (lift_data_con tys pre post (mkDataConTag con))
656
657     sel_replicate len tag
658       | arity > 1 = do
659                       rep <- builtin (selReplicate arity)
660                       return [rep `mkApps` [len, tag]]
661
662       | otherwise = return []
663
664     vect_data_con con = return $ mkConApp con ty_args
665     lift_data_con tys pre_tys post_tys tag
666       = do
667           len  <- builtin liftingContext
668           args <- mapM (newLocalVar (fsLit "xs"))
669                   =<< mapM mkPDataType tys
670
671           sel  <- sel_replicate (Var len) tag
672
673           pre   <- mapM emptyPD (concat pre_tys)
674           post  <- mapM emptyPD (concat post_tys)
675
676           return . mkLams (len : args)
677                  . wrapFamInstBody arr_tc var_tys
678                  . mkConApp arr_dc
679                  $ ty_args ++ sel ++ pre ++ map Var args ++ post
680
681     def_worker data_con arg_tys mk_body
682       = do
683           arity <- polyArity tyvars
684           body <- closedV
685                 . inBind orig_worker
686                 . polyAbstract tyvars $ \args ->
687                   liftM (mkLams (tyvars ++ args) . vectorised)
688                 $ buildClosures tyvars [] arg_tys res_ty mk_body
689
690           raw_worker <- cloneId mkVectOcc orig_worker (exprType body)
691           let vect_worker = raw_worker `setIdUnfolding`
692                               mkInlineRule body (Just arity)
693           defGlobalVar orig_worker vect_worker
694           return (vect_worker, body)
695       where
696         orig_worker = dataConWorkId data_con
697
698 buildPADict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
699 buildPADict vect_tc prepr_tc arr_tc repr
700   = polyAbstract tvs $ \args ->
701     do
702       method_ids <- mapM (method args) paMethods
703
704       pa_tc  <- builtin paTyCon
705       pa_dc  <- builtin paDataCon
706       let dict = mkLams (tvs ++ args)
707                $ mkConApp pa_dc
708                $ Type inst_ty : map (method_call args) method_ids
709
710           dfun_ty = mkForAllTys tvs
711                   $ mkFunTys (map varType args) (mkTyConApp pa_tc [inst_ty])
712
713       raw_dfun <- newExportedVar dfun_name dfun_ty
714       let dfun = raw_dfun `setIdUnfolding` mkDFunUnfolding dfun_ty (map Var method_ids)
715                           `setInlinePragma` dfunInlinePragma
716
717       hoistBinding dfun dict
718       return dfun
719   where
720     tvs = tyConTyVars vect_tc
721     arg_tys = mkTyVarTys tvs
722     inst_ty = mkTyConApp vect_tc arg_tys
723
724     dfun_name = mkPADFunOcc (getOccName vect_tc)
725
726     method args (name, build)
727       = localV
728       $ do
729           expr <- build vect_tc prepr_tc arr_tc repr
730           let body = mkLams (tvs ++ args) expr
731           raw_var <- newExportedVar (method_name name) (exprType body)
732           let var = raw_var
733                       `setIdUnfolding` mkInlineRule body (Just (length args))
734                       `setInlinePragma` alwaysInlinePragma
735           hoistBinding var body
736           return var
737
738     method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args)
739
740     method_name name = mkVarOcc $ occNameString dfun_name ++ ('$' : name)
741
742
743 paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
744 paMethods = [("dictPRepr",    buildPRDict),
745              ("toPRepr",      buildToPRepr),
746              ("fromPRepr",    buildFromPRepr),
747              ("toArrPRepr",   buildToArrPRepr),
748              ("fromArrPRepr", buildFromArrPRepr)]
749
750
751 -- ----------------------------------------------------------------------------
752 -- Conversions
753
754 -- | Build an expression that calls the vectorised version of some 
755 --   function from a `Closure`.
756 --
757 --   For example
758 --   @   
759 --      \(x :: Double) -> 
760 --      \(y :: Double) -> 
761 --      ($v_foo $: x) $: y
762 --   @
763 --
764 --   We use the type of the original binding to work out how many
765 --   outer lambdas to add.
766 --
767 fromVect 
768         :: Type         -- ^ The type of the original binding.
769         -> CoreExpr     -- ^ Expression giving the closure to use, eg @$v_foo@.
770         -> VM CoreExpr
771         
772 -- Convert the type to the core view if it isn't already.
773 fromVect ty expr 
774         | Just ty' <- coreView ty 
775         = fromVect ty' expr
776
777 -- For each function constructor in the original type we add an outer 
778 -- lambda to bind the parameter variable, and an inner application of it.
779 fromVect (FunTy arg_ty res_ty) expr
780   = do
781       arg     <- newLocalVar (fsLit "x") arg_ty
782       varg    <- toVect arg_ty (Var arg)
783       varg_ty <- vectType arg_ty
784       vres_ty <- vectType res_ty
785       apply   <- builtin applyVar
786       body    <- fromVect res_ty
787                $ Var apply `mkTyApps` [varg_ty, vres_ty] `mkApps` [expr, varg]
788       return $ Lam arg body
789
790 -- If the type isn't a function then it's time to call on the closure.
791 fromVect ty expr
792   = identityConv ty >> return expr
793
794
795 -- TODO: What is this really doing?
796 toVect :: Type -> CoreExpr -> VM CoreExpr
797 toVect ty expr = identityConv ty >> return expr
798
799
800 -- | Check that we have the vectorised versions of all the
801 --   type constructors in this type.
802 identityConv :: Type -> VM ()
803 identityConv ty 
804   | Just ty' <- coreView ty 
805   = identityConv ty'
806
807 identityConv (TyConApp tycon tys)
808  = do mapM_ identityConv tys
809       identityConvTyCon tycon
810
811 identityConv _ = noV
812
813
814 -- | Check that we have the vectorised version of this type constructor.
815 identityConvTyCon :: TyCon -> VM ()
816 identityConvTyCon tc
817   | isBoxedTupleTyCon tc = return ()
818   | isUnLiftedTyCon tc   = return ()
819   | otherwise 
820   = do tc' <- maybeV (lookupTyCon tc)
821        if tc == tc' then return () else noV
822
823