Break out vectorisation of expressions into own module
[ghc-hetmet.git] / compiler / vectorise / VectType.hs
1 {-# OPTIONS -fno-warn-missing-signatures #-}
2
3 module VectType ( vectTyCon, vectAndLiftType, vectType, vectTypeEnv,
4                   -- arrSumArity, pdataCompTys, pdataCompVars,
5                   buildPADict,
6                   fromVect )
7 where
8
9 import VectUtils
10 import Vectorise.Env
11 import Vectorise.Vect
12 import Vectorise.Monad
13 import Vectorise.Builtins
14 import Vectorise.Type.Type
15 import Vectorise.Type.TyConDecl
16 import Vectorise.Type.Classify
17
18 import HscTypes          ( TypeEnv, extendTypeEnvList, typeEnvTyCons )
19 import BasicTypes
20 import CoreSyn
21 import CoreUtils
22 import CoreUnfold
23 import MkCore            ( mkWildCase )
24 import BuildTyCl
25 import DataCon
26 import TyCon
27 import Type
28 import TypeRep
29 import Coercion
30 import FamInstEnv        ( FamInst, mkLocalFamInst )
31 import OccName
32 import Id
33 import MkId
34 import Var
35 import Name              ( Name, getOccName )
36 import NameEnv
37
38 import Unique
39 import UniqFM
40 import Util
41
42 import Outputable
43 import FastString
44
45 import MonadUtils     ( zipWith3M, foldrM, concatMapM )
46 import Control.Monad  ( liftM, liftM2, zipWithM, zipWithM_, mapAndUnzipM )
47 import Data.List
48
49 debug           = False
50 dtrace s x      = if debug then pprTrace "VectType" s x else x
51
52
53 -- ----------------------------------------------------------------------------
54 -- Type definitions
55
56
57 -- | Vectorise a type environment.
58 --   The type environment contains all the type things defined in a module.
59 vectTypeEnv :: TypeEnv -> VM (TypeEnv, [FamInst], [(Var, CoreExpr)])
60 vectTypeEnv env
61  = dtrace (ppr env)
62  $ do
63       cs <- readGEnv $ mk_map . global_tycons
64
65       -- Split the list of TyCons into the ones we have to vectorise vs the
66       -- ones we can pass through unchanged. We also pass through algebraic 
67       -- types that use non Haskell98 features, as we don't handle those.
68       let (conv_tcs, keep_tcs) = classifyTyCons cs groups
69           keep_dcs             = concatMap tyConDataCons keep_tcs
70
71       zipWithM_ defTyCon   keep_tcs keep_tcs
72       zipWithM_ defDataCon keep_dcs keep_dcs
73
74       new_tcs <- vectTyConDecls conv_tcs
75
76       let orig_tcs = keep_tcs ++ conv_tcs
77
78       -- We don't need to make new representation types for dictionary
79       -- constructors. The constructors are always fully applied, and we don't 
80       -- need to lift them to arrays as a dictionary of a particular type
81       -- always has the same value.
82       let vect_tcs = filter (not . isClassTyCon) 
83                    $ keep_tcs ++ new_tcs
84
85       (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
86         do
87           defTyConPAs (zipLazy vect_tcs dfuns')
88           reprs     <- mapM tyConRepr vect_tcs
89           repr_tcs  <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
90           pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
91
92           dfuns     <- sequence 
93                     $  zipWith5 buildTyConBindings
94                                orig_tcs
95                                vect_tcs
96                                repr_tcs
97                                pdata_tcs
98                                reprs
99
100           binds     <- takeHoisted
101           return (dfuns, binds, repr_tcs ++ pdata_tcs)
102
103       let all_new_tcs = new_tcs ++ inst_tcs
104
105       let new_env = extendTypeEnvList env
106                        (map ATyCon all_new_tcs
107                         ++ [ADataCon dc | tc <- all_new_tcs
108                                         , dc <- tyConDataCons tc])
109
110       return (new_env, map mkLocalFamInst inst_tcs, binds)
111   where
112     tycons = typeEnvTyCons env
113     groups = tyConGroups tycons
114
115     mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
116
117
118 mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
119 mk_fam_inst fam_tc arg_tc
120   = (fam_tc, [mkTyConApp arg_tc . mkTyVarTys $ tyConTyVars arg_tc])
121
122
123 buildPReprTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
124 buildPReprTyCon orig_tc vect_tc repr
125   = do
126       name     <- cloneName mkPReprTyConOcc (tyConName orig_tc)
127       -- rhs_ty   <- buildPReprType vect_tc
128       rhs_ty   <- sumReprType repr
129       prepr_tc <- builtin preprTyCon
130       liftDs $ buildSynTyCon name
131                              tyvars
132                              (SynonymTyCon rhs_ty)
133                              (typeKind rhs_ty)
134                              (Just $ mk_fam_inst prepr_tc vect_tc)
135   where
136     tyvars = tyConTyVars vect_tc
137
138 data CompRepr = Keep Type
139                      CoreExpr     -- PR dictionary for the type
140               | Wrap Type
141
142 data ProdRepr = EmptyProd
143               | UnaryProd CompRepr
144               | Prod { repr_tup_tc   :: TyCon  -- representation tuple tycon
145                      , repr_ptup_tc  :: TyCon  -- PData representation tycon
146                      , repr_comp_tys :: [Type] -- representation types of
147                      , repr_comps    :: [CompRepr]          -- components
148                      }
149 data ConRepr  = ConRepr DataCon ProdRepr
150
151 data SumRepr  = EmptySum
152               | UnarySum ConRepr
153               | Sum  { repr_sum_tc   :: TyCon  -- representation sum tycon
154                      , repr_psum_tc  :: TyCon  -- PData representation tycon
155                      , repr_sel_ty   :: Type   -- type of selector
156                      , repr_con_tys :: [Type]  -- representation types of
157                      , repr_cons     :: [ConRepr]           -- components
158                      }
159
160 tyConRepr :: TyCon -> VM SumRepr
161 tyConRepr tc = sum_repr (tyConDataCons tc)
162   where
163     sum_repr []    = return EmptySum
164     sum_repr [con] = liftM UnarySum (con_repr con)
165     sum_repr cons  = do
166                        rs     <- mapM con_repr cons
167                        sum_tc <- builtin (sumTyCon arity)
168                        tys    <- mapM conReprType rs
169                        (psum_tc, _) <- pdataReprTyCon (mkTyConApp sum_tc tys)
170                        sel_ty <- builtin (selTy arity)
171                        return $ Sum { repr_sum_tc  = sum_tc
172                                     , repr_psum_tc = psum_tc
173                                     , repr_sel_ty  = sel_ty
174                                     , repr_con_tys = tys
175                                     , repr_cons    = rs
176                                     }
177       where
178         arity = length cons
179
180     con_repr con = liftM (ConRepr con) (prod_repr (dataConRepArgTys con))
181
182     prod_repr []   = return EmptyProd
183     prod_repr [ty] = liftM UnaryProd (comp_repr ty)
184     prod_repr tys  = do
185                        rs <- mapM comp_repr tys
186                        tup_tc <- builtin (prodTyCon arity)
187                        tys'    <- mapM compReprType rs
188                        (ptup_tc, _) <- pdataReprTyCon (mkTyConApp tup_tc tys')
189                        return $ Prod { repr_tup_tc   = tup_tc
190                                      , repr_ptup_tc  = ptup_tc
191                                      , repr_comp_tys = tys'
192                                      , repr_comps    = rs
193                                      }
194       where
195         arity = length tys
196     
197     comp_repr ty = liftM (Keep ty) (prDictOfType ty)
198                    `orElseV` return (Wrap ty)
199
200 sumReprType :: SumRepr -> VM Type
201 sumReprType EmptySum = voidType
202 sumReprType (UnarySum r) = conReprType r
203 sumReprType (Sum { repr_sum_tc  = sum_tc, repr_con_tys = tys })
204   = return $ mkTyConApp sum_tc tys
205
206 conReprType :: ConRepr -> VM Type
207 conReprType (ConRepr _ r) = prodReprType r
208
209 prodReprType :: ProdRepr -> VM Type
210 prodReprType EmptyProd = voidType
211 prodReprType (UnaryProd r) = compReprType r
212 prodReprType (Prod { repr_tup_tc = tup_tc, repr_comp_tys = tys })
213   = return $ mkTyConApp tup_tc tys
214
215 compReprType :: CompRepr -> VM Type
216 compReprType (Keep ty _) = return ty
217 compReprType (Wrap ty) = do
218                              wrap_tc <- builtin wrapTyCon
219                              return $ mkTyConApp wrap_tc [ty]
220
221 compOrigType :: CompRepr -> Type
222 compOrigType (Keep ty _) = ty
223 compOrigType (Wrap ty) = ty
224
225 buildToPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
226 buildToPRepr vect_tc repr_tc _ repr
227   = do
228       let arg_ty = mkTyConApp vect_tc ty_args
229       res_ty <- mkPReprType arg_ty
230       arg    <- newLocalVar (fsLit "x") arg_ty
231       result <- to_sum (Var arg) arg_ty res_ty repr
232       return $ Lam arg result
233   where
234     ty_args = mkTyVarTys (tyConTyVars vect_tc)
235
236     wrap_repr_inst = wrapFamInstBody repr_tc ty_args
237
238     to_sum _ _ _ EmptySum
239       = do
240           void <- builtin voidVar
241           return $ wrap_repr_inst $ Var void
242
243     to_sum arg arg_ty res_ty (UnarySum r)
244       = do
245           (pat, vars, body) <- con_alt r
246           return $ mkWildCase arg arg_ty res_ty
247                    [(pat, vars, wrap_repr_inst body)]
248
249     to_sum arg arg_ty res_ty (Sum { repr_sum_tc  = sum_tc
250                                   , repr_con_tys = tys
251                                   , repr_cons    =  cons })
252       = do
253           alts <- mapM con_alt cons
254           let alts' = [(pat, vars, wrap_repr_inst
255                                    $ mkConApp sum_con (map Type tys ++ [body]))
256                         | ((pat, vars, body), sum_con)
257                             <- zip alts (tyConDataCons sum_tc)]
258           return $ mkWildCase arg arg_ty res_ty alts'
259
260     con_alt (ConRepr con r)
261       = do
262           (vars, body) <- to_prod r
263           return (DataAlt con, vars, body)
264
265     to_prod EmptyProd
266       = do
267           void <- builtin voidVar
268           return ([], Var void)
269
270     to_prod (UnaryProd comp)
271       = do
272           var  <- newLocalVar (fsLit "x") (compOrigType comp)
273           body <- to_comp (Var var) comp
274           return ([var], body)
275
276     to_prod(Prod { repr_tup_tc   = tup_tc
277                  , repr_comp_tys = tys
278                  , repr_comps    = comps })
279       = do
280           vars  <- newLocalVars (fsLit "x") (map compOrigType comps)
281           exprs <- zipWithM to_comp (map Var vars) comps
282           return (vars, mkConApp tup_con (map Type tys ++ exprs))
283       where
284         [tup_con] = tyConDataCons tup_tc
285
286     to_comp expr (Keep _ _) = return expr
287     to_comp expr (Wrap ty)  = do
288                                 wrap_tc <- builtin wrapTyCon
289                                 return $ wrapNewTypeBody wrap_tc [ty] expr
290
291
292 buildFromPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
293 buildFromPRepr vect_tc repr_tc _ repr
294   = do
295       arg_ty <- mkPReprType res_ty
296       arg <- newLocalVar (fsLit "x") arg_ty
297
298       result <- from_sum (unwrapFamInstScrut repr_tc ty_args (Var arg))
299                          repr
300       return $ Lam arg result
301   where
302     ty_args = mkTyVarTys (tyConTyVars vect_tc)
303     res_ty  = mkTyConApp vect_tc ty_args
304
305     from_sum _ EmptySum
306       = do
307           dummy <- builtin fromVoidVar
308           return $ Var dummy `App` Type res_ty
309
310     from_sum expr (UnarySum r) = from_con expr r
311     from_sum expr (Sum { repr_sum_tc  = sum_tc
312                        , repr_con_tys = tys
313                        , repr_cons    = cons })
314       = do
315           vars  <- newLocalVars (fsLit "x") tys
316           es    <- zipWithM from_con (map Var vars) cons
317           return $ mkWildCase expr (exprType expr) res_ty
318                    [(DataAlt con, [var], e)
319                       | (con, var, e) <- zip3 (tyConDataCons sum_tc) vars es]
320
321     from_con expr (ConRepr con r)
322       = from_prod expr (mkConApp con $ map Type ty_args) r
323
324     from_prod _ con EmptyProd = return con
325     from_prod expr con (UnaryProd r)
326       = do
327           e <- from_comp expr r
328           return $ con `App` e
329      
330     from_prod expr con (Prod { repr_tup_tc   = tup_tc
331                              , repr_comp_tys = tys
332                              , repr_comps    = comps
333                              })
334       = do
335           vars <- newLocalVars (fsLit "y") tys
336           es   <- zipWithM from_comp (map Var vars) comps
337           return $ mkWildCase expr (exprType expr) res_ty
338                    [(DataAlt tup_con, vars, con `mkApps` es)]
339       where
340         [tup_con] = tyConDataCons tup_tc  
341
342     from_comp expr (Keep _ _) = return expr
343     from_comp expr (Wrap ty)
344       = do
345           wrap <- builtin wrapTyCon
346           return $ unwrapNewTypeBody wrap [ty] expr
347
348
349 buildToArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
350 buildToArrPRepr vect_tc prepr_tc pdata_tc r
351   = do
352       arg_ty <- mkPDataType el_ty
353       res_ty <- mkPDataType =<< mkPReprType el_ty
354       arg    <- newLocalVar (fsLit "xs") arg_ty
355
356       pdata_co <- mkBuiltinCo pdataTyCon
357       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
358           co           = mkAppCoercion pdata_co
359                        . mkSymCoercion
360                        $ mkTyConApp repr_co ty_args
361
362           scrut   = unwrapFamInstScrut pdata_tc ty_args (Var arg)
363
364       (vars, result) <- to_sum r
365
366       return . Lam arg
367              $ mkWildCase scrut (mkTyConApp pdata_tc ty_args) res_ty
368                [(DataAlt pdata_dc, vars, mkCoerce co result)]
369   where
370     ty_args = mkTyVarTys $ tyConTyVars vect_tc
371     el_ty   = mkTyConApp vect_tc ty_args
372
373     [pdata_dc] = tyConDataCons pdata_tc
374
375
376     to_sum EmptySum = do
377                         pvoid <- builtin pvoidVar
378                         return ([], Var pvoid)
379     to_sum (UnarySum r) = to_con r
380     to_sum (Sum { repr_psum_tc = psum_tc
381                 , repr_sel_ty  = sel_ty
382                 , repr_con_tys = tys
383                 , repr_cons    = cons
384                 })
385       = do
386           (vars, exprs) <- mapAndUnzipM to_con cons
387           sel <- newLocalVar (fsLit "sel") sel_ty
388           return (sel : concat vars, mk_result (Var sel) exprs)
389       where
390         [psum_con] = tyConDataCons psum_tc
391         mk_result sel exprs = wrapFamInstBody psum_tc tys
392                             $ mkConApp psum_con
393                             $ map Type tys ++ (sel : exprs)
394
395     to_con (ConRepr _ r) = to_prod r
396
397     to_prod EmptyProd = do
398                           pvoid <- builtin pvoidVar
399                           return ([], Var pvoid)
400     to_prod (UnaryProd r)
401       = do
402           pty  <- mkPDataType (compOrigType r)
403           var  <- newLocalVar (fsLit "x") pty
404           expr <- to_comp (Var var) r
405           return ([var], expr)
406
407     to_prod (Prod { repr_ptup_tc  = ptup_tc
408                   , repr_comp_tys = tys
409                   , repr_comps    = comps })
410       = do
411           ptys <- mapM (mkPDataType . compOrigType) comps
412           vars <- newLocalVars (fsLit "x") ptys
413           es   <- zipWithM to_comp (map Var vars) comps
414           return (vars, mk_result es)
415       where
416         [ptup_con] = tyConDataCons ptup_tc
417         mk_result exprs = wrapFamInstBody ptup_tc tys
418                         $ mkConApp ptup_con
419                         $ map Type tys ++ exprs
420
421     to_comp expr (Keep _ _) = return expr
422
423     -- FIXME: this is bound to be wrong!
424     to_comp expr (Wrap ty)
425       = do
426           wrap_tc  <- builtin wrapTyCon
427           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
428           return $ wrapNewTypeBody pwrap_tc [ty] expr
429
430
431 buildFromArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
432 buildFromArrPRepr vect_tc prepr_tc pdata_tc r
433   = do
434       arg_ty <- mkPDataType =<< mkPReprType el_ty
435       res_ty <- mkPDataType el_ty
436       arg    <- newLocalVar (fsLit "xs") arg_ty
437
438       pdata_co <- mkBuiltinCo pdataTyCon
439       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
440           co           = mkAppCoercion pdata_co
441                        $ mkTyConApp repr_co var_tys
442
443           scrut  = mkCoerce co (Var arg)
444
445           mk_result args = wrapFamInstBody pdata_tc var_tys
446                          $ mkConApp pdata_con
447                          $ map Type var_tys ++ args
448
449       (expr, _) <- fixV $ \ ~(_, args) ->
450                      from_sum res_ty (mk_result args) scrut r
451
452       return $ Lam arg expr
453     
454       -- (args, mk) <- from_sum res_ty scrut r
455       
456       -- let result = wrapFamInstBody pdata_tc var_tys
457       --           . mkConApp pdata_dc
458       --           $ map Type var_tys ++ args
459
460       -- return $ Lam arg (mk result)
461   where
462     var_tys = mkTyVarTys $ tyConTyVars vect_tc
463     el_ty   = mkTyConApp vect_tc var_tys
464
465     [pdata_con] = tyConDataCons pdata_tc
466
467     from_sum _ res _ EmptySum = return (res, [])
468     from_sum res_ty res expr (UnarySum r) = from_con res_ty res expr r
469     from_sum res_ty res expr (Sum { repr_psum_tc = psum_tc
470                                   , repr_sel_ty  = sel_ty
471                                   , repr_con_tys = tys
472                                   , repr_cons    = cons })
473       = do
474           sel  <- newLocalVar (fsLit "sel") sel_ty
475           ptys <- mapM mkPDataType tys
476           vars <- newLocalVars (fsLit "xs") ptys
477           (res', args) <- fold from_con res_ty res (map Var vars) cons
478           let scrut = unwrapFamInstScrut psum_tc tys expr
479               body  = mkWildCase scrut (exprType scrut) res_ty
480                       [(DataAlt psum_con, sel : vars, res')]
481           return (body, Var sel : args)
482       where
483         [psum_con] = tyConDataCons psum_tc
484
485
486     from_con res_ty res expr (ConRepr _ r) = from_prod res_ty res expr r
487
488     from_prod _ res _ EmptyProd = return (res, [])
489     from_prod res_ty res expr (UnaryProd r)
490       = from_comp res_ty res expr r
491     from_prod res_ty res expr (Prod { repr_ptup_tc  = ptup_tc
492                                     , repr_comp_tys = tys
493                                     , repr_comps    = comps })
494       = do
495           ptys <- mapM mkPDataType tys
496           vars <- newLocalVars (fsLit "ys") ptys
497           (res', args) <- fold from_comp res_ty res (map Var vars) comps
498           let scrut = unwrapFamInstScrut ptup_tc tys expr
499               body  = mkWildCase scrut (exprType scrut) res_ty
500                       [(DataAlt ptup_con, vars, res')]
501           return (body, args)
502       where
503         [ptup_con] = tyConDataCons ptup_tc
504
505     from_comp _ res expr (Keep _ _) = return (res, [expr])
506     from_comp _ res expr (Wrap ty)
507       = do
508           wrap_tc  <- builtin wrapTyCon
509           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
510           return (res, [unwrapNewTypeBody pwrap_tc [ty]
511                         $ unwrapFamInstScrut pwrap_tc [ty] expr])
512
513     fold f res_ty res exprs rs = foldrM f' (res, []) (zip exprs rs)
514       where
515         f' (expr, r) (res, args) = do
516                                      (res', args') <- f res_ty res expr r
517                                      return (res', args' ++ args)
518
519 buildPRDict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
520 buildPRDict vect_tc prepr_tc _ r
521   = do
522       dict <- sum_dict r
523       pr_co <- mkBuiltinCo prTyCon
524       let co = mkAppCoercion pr_co
525              . mkSymCoercion
526              $ mkTyConApp arg_co ty_args
527       return (mkCoerce co dict)
528   where
529     ty_args = mkTyVarTys (tyConTyVars vect_tc)
530     Just arg_co = tyConFamilyCoercion_maybe prepr_tc
531
532     sum_dict EmptySum = prDFunOfTyCon =<< builtin voidTyCon
533     sum_dict (UnarySum r) = con_dict r
534     sum_dict (Sum { repr_sum_tc  = sum_tc
535                   , repr_con_tys = tys
536                   , repr_cons    = cons
537                   })
538       = do
539           dicts <- mapM con_dict cons
540           dfun  <- prDFunOfTyCon sum_tc
541           return $ dfun `mkTyApps` tys `mkApps` dicts
542
543     con_dict (ConRepr _ r) = prod_dict r
544
545     prod_dict EmptyProd = prDFunOfTyCon =<< builtin voidTyCon
546     prod_dict (UnaryProd r) = comp_dict r
547     prod_dict (Prod { repr_tup_tc   = tup_tc
548                     , repr_comp_tys = tys
549                     , repr_comps    = comps })
550       = do
551           dicts <- mapM comp_dict comps
552           dfun <- prDFunOfTyCon tup_tc
553           return $ dfun `mkTyApps` tys `mkApps` dicts
554
555     comp_dict (Keep _ pr) = return pr
556     comp_dict (Wrap ty)   = wrapPR ty
557
558
559 buildPDataTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
560 buildPDataTyCon orig_tc vect_tc repr = fixV $ \repr_tc ->
561   do
562     name' <- cloneName mkPDataTyConOcc orig_name
563     rhs   <- buildPDataTyConRhs orig_name vect_tc repr_tc repr
564     pdata <- builtin pdataTyCon
565
566     liftDs $ buildAlgTyCon name'
567                            tyvars
568                            []          -- no stupid theta
569                            rhs
570                            rec_flag    -- FIXME: is this ok?
571                            False       -- FIXME: no generics
572                            False       -- not GADT syntax
573                            (Just $ mk_fam_inst pdata vect_tc)
574   where
575     orig_name = tyConName orig_tc
576     tyvars = tyConTyVars vect_tc
577     rec_flag = boolToRecFlag (isRecursiveTyCon vect_tc)
578
579
580 buildPDataTyConRhs :: Name -> TyCon -> TyCon -> SumRepr -> VM AlgTyConRhs
581 buildPDataTyConRhs orig_name vect_tc repr_tc repr
582   = do
583       data_con <- buildPDataDataCon orig_name vect_tc repr_tc repr
584       return $ DataTyCon { data_cons = [data_con], is_enum = False }
585
586 buildPDataDataCon :: Name -> TyCon -> TyCon -> SumRepr -> VM DataCon
587 buildPDataDataCon orig_name vect_tc repr_tc repr
588   = do
589       dc_name  <- cloneName mkPDataDataConOcc orig_name
590       comp_tys <- sum_tys repr
591
592       liftDs $ buildDataCon dc_name
593                             False                  -- not infix
594                             (map (const HsNoBang) comp_tys)
595                             []                     -- no field labels
596                             tvs
597                             []                     -- no existentials
598                             []                     -- no eq spec
599                             []                     -- no context
600                             comp_tys
601                             (mkFamilyTyConApp repr_tc (mkTyVarTys tvs))
602                             repr_tc
603   where
604     tvs   = tyConTyVars vect_tc
605
606     sum_tys EmptySum = return []
607     sum_tys (UnarySum r) = con_tys r
608     sum_tys (Sum { repr_sel_ty = sel_ty
609                  , repr_cons   = cons })
610       = liftM (sel_ty :) (concatMapM con_tys cons)
611
612     con_tys (ConRepr _ r) = prod_tys r
613
614     prod_tys EmptyProd = return []
615     prod_tys (UnaryProd r) = liftM singleton (comp_ty r)
616     prod_tys (Prod { repr_comps = comps }) = mapM comp_ty comps
617
618     comp_ty r = mkPDataType (compOrigType r)
619
620
621 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr 
622                    -> VM Var
623 buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr
624   = do
625       vectDataConWorkers orig_tc vect_tc pdata_tc
626       buildPADict vect_tc prepr_tc pdata_tc repr
627
628 vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
629 vectDataConWorkers orig_tc vect_tc arr_tc
630   = do
631       bs <- sequence
632           . zipWith3 def_worker  (tyConDataCons orig_tc) rep_tys
633           $ zipWith4 mk_data_con (tyConDataCons vect_tc)
634                                  rep_tys
635                                  (inits rep_tys)
636                                  (tail $ tails rep_tys)
637       mapM_ (uncurry hoistBinding) bs
638   where
639     tyvars   = tyConTyVars vect_tc
640     var_tys  = mkTyVarTys tyvars
641     ty_args  = map Type var_tys
642     res_ty   = mkTyConApp vect_tc var_tys
643
644     cons     = tyConDataCons vect_tc
645     arity    = length cons
646     [arr_dc] = tyConDataCons arr_tc
647
648     rep_tys  = map dataConRepArgTys $ tyConDataCons vect_tc
649
650
651     mk_data_con con tys pre post
652       = liftM2 (,) (vect_data_con con)
653                    (lift_data_con tys pre post (mkDataConTag con))
654
655     sel_replicate len tag
656       | arity > 1 = do
657                       rep <- builtin (selReplicate arity)
658                       return [rep `mkApps` [len, tag]]
659
660       | otherwise = return []
661
662     vect_data_con con = return $ mkConApp con ty_args
663     lift_data_con tys pre_tys post_tys tag
664       = do
665           len  <- builtin liftingContext
666           args <- mapM (newLocalVar (fsLit "xs"))
667                   =<< mapM mkPDataType tys
668
669           sel  <- sel_replicate (Var len) tag
670
671           pre   <- mapM emptyPD (concat pre_tys)
672           post  <- mapM emptyPD (concat post_tys)
673
674           return . mkLams (len : args)
675                  . wrapFamInstBody arr_tc var_tys
676                  . mkConApp arr_dc
677                  $ ty_args ++ sel ++ pre ++ map Var args ++ post
678
679     def_worker data_con arg_tys mk_body
680       = do
681           arity <- polyArity tyvars
682           body <- closedV
683                 . inBind orig_worker
684                 . polyAbstract tyvars $ \args ->
685                   liftM (mkLams (tyvars ++ args) . vectorised)
686                 $ buildClosures tyvars [] arg_tys res_ty mk_body
687
688           raw_worker <- cloneId mkVectOcc orig_worker (exprType body)
689           let vect_worker = raw_worker `setIdUnfolding`
690                               mkInlineRule body (Just arity)
691           defGlobalVar orig_worker vect_worker
692           return (vect_worker, body)
693       where
694         orig_worker = dataConWorkId data_con
695
696 buildPADict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
697 buildPADict vect_tc prepr_tc arr_tc repr
698   = polyAbstract tvs $ \args ->
699     do
700       method_ids <- mapM (method args) paMethods
701
702       pa_tc  <- builtin paTyCon
703       pa_dc  <- builtin paDataCon
704       let dict = mkLams (tvs ++ args)
705                $ mkConApp pa_dc
706                $ Type inst_ty : map (method_call args) method_ids
707
708           dfun_ty = mkForAllTys tvs
709                   $ mkFunTys (map varType args) (mkTyConApp pa_tc [inst_ty])
710
711       raw_dfun <- newExportedVar dfun_name dfun_ty
712       let dfun = raw_dfun `setIdUnfolding` mkDFunUnfolding dfun_ty (map Var method_ids)
713                           `setInlinePragma` dfunInlinePragma
714
715       hoistBinding dfun dict
716       return dfun
717   where
718     tvs = tyConTyVars vect_tc
719     arg_tys = mkTyVarTys tvs
720     inst_ty = mkTyConApp vect_tc arg_tys
721
722     dfun_name = mkPADFunOcc (getOccName vect_tc)
723
724     method args (name, build)
725       = localV
726       $ do
727           expr <- build vect_tc prepr_tc arr_tc repr
728           let body = mkLams (tvs ++ args) expr
729           raw_var <- newExportedVar (method_name name) (exprType body)
730           let var = raw_var
731                       `setIdUnfolding` mkInlineRule body (Just (length args))
732                       `setInlinePragma` alwaysInlinePragma
733           hoistBinding var body
734           return var
735
736     method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args)
737
738     method_name name = mkVarOcc $ occNameString dfun_name ++ ('$' : name)
739
740
741 paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
742 paMethods = [("dictPRepr",    buildPRDict),
743              ("toPRepr",      buildToPRepr),
744              ("fromPRepr",    buildFromPRepr),
745              ("toArrPRepr",   buildToArrPRepr),
746              ("fromArrPRepr", buildFromArrPRepr)]
747
748
749 -- ----------------------------------------------------------------------------
750 -- Conversions
751
752 -- | Build an expression that calls the vectorised version of some 
753 --   function from a `Closure`.
754 --
755 --   For example
756 --   @   
757 --      \(x :: Double) -> 
758 --      \(y :: Double) -> 
759 --      ($v_foo $: x) $: y
760 --   @
761 --
762 --   We use the type of the original binding to work out how many
763 --   outer lambdas to add.
764 --
765 fromVect 
766         :: Type         -- ^ The type of the original binding.
767         -> CoreExpr     -- ^ Expression giving the closure to use, eg @$v_foo@.
768         -> VM CoreExpr
769         
770 -- Convert the type to the core view if it isn't already.
771 fromVect ty expr 
772         | Just ty' <- coreView ty 
773         = fromVect ty' expr
774
775 -- For each function constructor in the original type we add an outer 
776 -- lambda to bind the parameter variable, and an inner application of it.
777 fromVect (FunTy arg_ty res_ty) expr
778   = do
779       arg     <- newLocalVar (fsLit "x") arg_ty
780       varg    <- toVect arg_ty (Var arg)
781       varg_ty <- vectType arg_ty
782       vres_ty <- vectType res_ty
783       apply   <- builtin applyVar
784       body    <- fromVect res_ty
785                $ Var apply `mkTyApps` [varg_ty, vres_ty] `mkApps` [expr, varg]
786       return $ Lam arg body
787
788 -- If the type isn't a function then it's time to call on the closure.
789 fromVect ty expr
790   = identityConv ty >> return expr
791
792
793 -- TODO: What is this really doing?
794 toVect :: Type -> CoreExpr -> VM CoreExpr
795 toVect ty expr = identityConv ty >> return expr
796
797
798 -- | Check that we have the vectorised versions of all the
799 --   type constructors in this type.
800 identityConv :: Type -> VM ()
801 identityConv ty 
802   | Just ty' <- coreView ty 
803   = identityConv ty'
804
805 identityConv (TyConApp tycon tys)
806  = do mapM_ identityConv tys
807       identityConvTyCon tycon
808
809 identityConv _ = noV
810
811
812 -- | Check that we have the vectorised version of this type constructor.
813 identityConvTyCon :: TyCon -> VM ()
814 identityConvTyCon tc
815   | isBoxedTupleTyCon tc = return ()
816   | isUnLiftedTyCon tc   = return ()
817   | otherwise 
818   = do tc' <- maybeV (lookupTyCon tc)
819        if tc == tc' then return () else noV
820
821