Break out type vectorisation into own module
[ghc-hetmet.git] / compiler / vectorise / VectType.hs
1 {-# OPTIONS -fno-warn-missing-signatures #-}
2
3 module VectType ( vectTyCon, vectAndLiftType, vectType, vectTypeEnv,
4                   -- arrSumArity, pdataCompTys, pdataCompVars,
5                   buildPADict,
6                   fromVect )
7 where
8
9 import VectUtils
10 import Vectorise.Env
11 import Vectorise.Vect
12 import Vectorise.Monad
13 import Vectorise.Builtins
14 import Vectorise.Type.Type
15
16 import HscTypes          ( TypeEnv, extendTypeEnvList, typeEnvTyCons )
17 import BasicTypes
18 import CoreSyn
19 import CoreUtils
20 import CoreUnfold
21 import MkCore            ( mkWildCase )
22 import BuildTyCl
23 import DataCon
24 import TyCon
25 import Class
26 import Type
27 import TypeRep
28 import Coercion
29 import FamInstEnv        ( FamInst, mkLocalFamInst )
30 import OccName
31 import Id
32 import MkId
33 import Var
34 import Name              ( Name, getOccName )
35 import NameEnv
36
37 import Unique
38 import UniqFM
39 import UniqSet
40 import Util
41 import Digraph           ( SCC(..), stronglyConnCompFromEdgedVertices )
42
43 import Outputable
44 import FastString
45
46 import MonadUtils     ( zipWith3M, foldrM, concatMapM )
47 import Control.Monad  ( liftM, liftM2, zipWithM, zipWithM_, mapAndUnzipM )
48 import Data.List
49
50 debug           = False
51 dtrace s x      = if debug then pprTrace "VectType" s x else x
52
53 -- ----------------------------------------------------------------------------
54 -- Types
55
56
57 -- ----------------------------------------------------------------------------
58 -- Type definitions
59
60 type TyConGroup = ([TyCon], UniqSet TyCon)
61
62 -- | Vectorise a type environment.
63 --   The type environment contains all the type things defined in a module.
64 vectTypeEnv :: TypeEnv -> VM (TypeEnv, [FamInst], [(Var, CoreExpr)])
65 vectTypeEnv env
66  = dtrace (ppr env)
67  $ do
68       cs <- readGEnv $ mk_map . global_tycons
69
70       -- Split the list of TyCons into the ones we have to vectorise vs the
71       -- ones we can pass through unchanged. We also pass through algebraic 
72       -- types that use non Haskell98 features, as we don't handle those.
73       let (conv_tcs, keep_tcs) = classifyTyCons cs groups
74           keep_dcs             = concatMap tyConDataCons keep_tcs
75
76       dtrace (text "conv_tcs = " <> ppr conv_tcs) $ return ()
77
78       zipWithM_ defTyCon   keep_tcs keep_tcs
79       zipWithM_ defDataCon keep_dcs keep_dcs
80
81       new_tcs <- vectTyConDecls conv_tcs
82
83       dtrace (text "new_tcs = " <> ppr new_tcs) $ return ()
84
85       let orig_tcs = keep_tcs ++ conv_tcs
86
87       -- We don't need to make new representation types for dictionary
88       -- constructors. The constructors are always fully applied, and we don't 
89       -- need to lift them to arrays as a dictionary of a particular type
90       -- always has the same value.
91       let vect_tcs = filter (not . isClassTyCon) 
92                    $ keep_tcs ++ new_tcs
93
94       dtrace (text "vect_tcs = " <> ppr vect_tcs) $ return ()
95
96       mapM_ dumpTycon $ new_tcs
97
98
99       (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
100         do
101           defTyConPAs (zipLazy vect_tcs dfuns')
102           reprs     <- mapM tyConRepr vect_tcs
103           repr_tcs  <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
104           pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
105
106           dfuns     <- sequence 
107                     $  zipWith5 buildTyConBindings
108                                orig_tcs
109                                vect_tcs
110                                repr_tcs
111                                pdata_tcs
112                                reprs
113
114           binds     <- takeHoisted
115           return (dfuns, binds, repr_tcs ++ pdata_tcs)
116
117       let all_new_tcs = new_tcs ++ inst_tcs
118
119       let new_env = extendTypeEnvList env
120                        (map ATyCon all_new_tcs
121                         ++ [ADataCon dc | tc <- all_new_tcs
122                                         , dc <- tyConDataCons tc])
123
124       return (new_env, map mkLocalFamInst inst_tcs, binds)
125   where
126     tycons = typeEnvTyCons env
127     groups = tyConGroups tycons
128
129     mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
130
131
132 -- | Vectorise some (possibly recursively defined) type constructors.
133 vectTyConDecls :: [TyCon] -> VM [TyCon]
134 vectTyConDecls tcs = fixV $ \tcs' ->
135   do
136     mapM_ (uncurry defTyCon) (zipLazy tcs tcs')
137     mapM vectTyConDecl tcs
138
139 dumpTycon :: TyCon -> VM ()
140 dumpTycon tycon
141         | Just cls      <- tyConClass_maybe tycon
142         = dtrace (vcat  [ ppr tycon
143                         , ppr [(m, varType m) | m <- classMethods cls ]])
144         $ return ()
145                 
146         | otherwise
147         = return ()
148
149
150 -- | Vectorise a single type construcrtor.
151 vectTyConDecl :: TyCon -> VM TyCon
152 vectTyConDecl tycon
153     -- a type class constructor.
154     -- TODO: check for no stupid theta, fds, assoc types. 
155     | isClassTyCon tycon
156     , Just cls          <- tyConClass_maybe tycon
157
158     = do    -- make the name of the vectorised class tycon.
159             name'       <- cloneName mkVectTyConOcc (tyConName tycon)
160
161             -- vectorise right of definition.
162             rhs'        <- vectAlgTyConRhs tycon (algTyConRhs tycon)
163
164             -- vectorise method selectors.
165             -- This also adds a mapping between the original and vectorised method selector
166             -- to the state.
167             methods'    <- mapM vectMethod
168                         $  [(id, defMethSpecOfDefMeth meth) 
169                                 | (id, meth)    <- classOpItems cls]
170
171             -- keep the original recursiveness flag.
172             let rec_flag = boolToRecFlag (isRecursiveTyCon tycon)
173         
174             -- Calling buildclass here attaches new quantifiers and dictionaries to the method types.
175             cls'     <- liftDs 
176                     $  buildClass
177                              False               -- include unfoldings on dictionary selectors.
178                              name'               -- new name  V_T:Class
179                              (tyConTyVars tycon) -- keep original type vars
180                              []                  -- no stupid theta
181                              []                  -- no functional dependencies
182                              []                  -- no associated types
183                              methods'            -- method info
184                              rec_flag            -- whether recursive
185
186             let tycon'  = mkClassTyCon name'
187                              (tyConKind tycon)
188                              (tyConTyVars tycon)
189                              rhs'
190                              cls'
191                              rec_flag
192
193             return $ tycon'
194                         
195     -- a regular algebraic type constructor.
196     -- TODO: check for stupid theta, generaics, GADTS etc
197     | isAlgTyCon tycon
198     = do    name'       <- cloneName mkVectTyConOcc (tyConName tycon)
199             rhs'        <- vectAlgTyConRhs tycon (algTyConRhs tycon)
200             let rec_flag =  boolToRecFlag (isRecursiveTyCon tycon)
201
202             liftDs $ buildAlgTyCon 
203                             name'               -- new name
204                             (tyConTyVars tycon) -- keep original type vars.
205                             []                  -- no stupid theta.
206                             rhs'                -- new constructor defs.
207                             rec_flag            -- FIXME: is this ok?
208                             False               -- FIXME: no generics
209                             False               -- not GADT syntax
210                             Nothing             -- not a family instance
211
212     -- some other crazy thing that we don't handle.
213     | otherwise
214     = cantVectorise "Can't vectorise type constructor: " (ppr tycon)
215
216
217 -- | Vectorise a class method.
218 vectMethod :: (Id, DefMethSpec) -> VM (Name, DefMethSpec, Type)
219 vectMethod (id, defMeth)
220  = do   
221         -- Vectorise the method type.
222         typ'    <- vectType (varType id)
223
224         -- Create a name for the vectorised method.
225         id'     <- cloneId mkVectOcc id typ'
226         defGlobalVar id id'
227
228         -- When we call buildClass in vectTyConDecl, it adds foralls and dictionaries
229         -- to the types of each method. However, the types we get back from vectType
230         -- above already already have these, so we need to chop them off here otherwise
231         -- we'll get two copies in the final version.
232         let (_tyvars, tyBody) = splitForAllTys typ'
233         let (_dict,   tyRest) = splitFunTy tyBody
234
235         return  (Var.varName id', defMeth, tyRest)
236
237
238 -- | Vectorise the RHS of an algebraic type.
239 vectAlgTyConRhs :: TyCon -> AlgTyConRhs -> VM AlgTyConRhs
240 vectAlgTyConRhs _ (DataTyCon { data_cons = data_cons
241                              , is_enum   = is_enum
242                              })
243   = do
244       data_cons' <- mapM vectDataCon data_cons
245       zipWithM_ defDataCon data_cons data_cons'
246       return $ DataTyCon { data_cons = data_cons'
247                          , is_enum   = is_enum
248                          }
249
250 vectAlgTyConRhs tc _ 
251         = cantVectorise "Can't vectorise type definition:" (ppr tc)
252
253
254 -- | Vectorise a data constructor.
255 --   Vectorises its argument and return types.
256 vectDataCon :: DataCon -> VM DataCon
257 vectDataCon dc
258   | not . null $ dataConExTyVars dc
259   = cantVectorise "Can't vectorise constructor (existentials):" (ppr dc)
260
261   | not . null $ dataConEqSpec   dc
262   = cantVectorise "Can't vectorise constructor (eq spec):" (ppr dc)
263
264   | otherwise
265   = do
266       name'    <- cloneName mkVectDataConOcc name
267       tycon'   <- vectTyCon tycon
268       arg_tys  <- mapM vectType rep_arg_tys
269
270       liftDs $ buildDataCon 
271                 name'
272                 False                          -- not infix
273                 (map (const HsNoBang) arg_tys) -- strictness annots on args.
274                 []                             -- no labelled fields
275                 univ_tvs                       -- universally quantified vars
276                 []                             -- no existential tvs for now
277                 []                             -- no eq spec for now
278                 []                             -- no context
279                 arg_tys                        -- argument types
280                 (mkFamilyTyConApp tycon' (mkTyVarTys univ_tvs)) -- return type
281                 tycon'                         -- representation tycon
282   where
283     name        = dataConName dc
284     univ_tvs    = dataConUnivTyVars dc
285     rep_arg_tys = dataConRepArgTys dc
286     tycon       = dataConTyCon dc
287
288 mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
289 mk_fam_inst fam_tc arg_tc
290   = (fam_tc, [mkTyConApp arg_tc . mkTyVarTys $ tyConTyVars arg_tc])
291
292
293 buildPReprTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
294 buildPReprTyCon orig_tc vect_tc repr
295   = do
296       name     <- cloneName mkPReprTyConOcc (tyConName orig_tc)
297       -- rhs_ty   <- buildPReprType vect_tc
298       rhs_ty   <- sumReprType repr
299       prepr_tc <- builtin preprTyCon
300       liftDs $ buildSynTyCon name
301                              tyvars
302                              (SynonymTyCon rhs_ty)
303                              (typeKind rhs_ty)
304                              (Just $ mk_fam_inst prepr_tc vect_tc)
305   where
306     tyvars = tyConTyVars vect_tc
307
308 data CompRepr = Keep Type
309                      CoreExpr     -- PR dictionary for the type
310               | Wrap Type
311
312 data ProdRepr = EmptyProd
313               | UnaryProd CompRepr
314               | Prod { repr_tup_tc   :: TyCon  -- representation tuple tycon
315                      , repr_ptup_tc  :: TyCon  -- PData representation tycon
316                      , repr_comp_tys :: [Type] -- representation types of
317                      , repr_comps    :: [CompRepr]          -- components
318                      }
319 data ConRepr  = ConRepr DataCon ProdRepr
320
321 data SumRepr  = EmptySum
322               | UnarySum ConRepr
323               | Sum  { repr_sum_tc   :: TyCon  -- representation sum tycon
324                      , repr_psum_tc  :: TyCon  -- PData representation tycon
325                      , repr_sel_ty   :: Type   -- type of selector
326                      , repr_con_tys :: [Type]  -- representation types of
327                      , repr_cons     :: [ConRepr]           -- components
328                      }
329
330 tyConRepr :: TyCon -> VM SumRepr
331 tyConRepr tc = sum_repr (tyConDataCons tc)
332   where
333     sum_repr []    = return EmptySum
334     sum_repr [con] = liftM UnarySum (con_repr con)
335     sum_repr cons  = do
336                        rs     <- mapM con_repr cons
337                        sum_tc <- builtin (sumTyCon arity)
338                        tys    <- mapM conReprType rs
339                        (psum_tc, _) <- pdataReprTyCon (mkTyConApp sum_tc tys)
340                        sel_ty <- builtin (selTy arity)
341                        return $ Sum { repr_sum_tc  = sum_tc
342                                     , repr_psum_tc = psum_tc
343                                     , repr_sel_ty  = sel_ty
344                                     , repr_con_tys = tys
345                                     , repr_cons    = rs
346                                     }
347       where
348         arity = length cons
349
350     con_repr con = liftM (ConRepr con) (prod_repr (dataConRepArgTys con))
351
352     prod_repr []   = return EmptyProd
353     prod_repr [ty] = liftM UnaryProd (comp_repr ty)
354     prod_repr tys  = do
355                        rs <- mapM comp_repr tys
356                        tup_tc <- builtin (prodTyCon arity)
357                        tys'    <- mapM compReprType rs
358                        (ptup_tc, _) <- pdataReprTyCon (mkTyConApp tup_tc tys')
359                        return $ Prod { repr_tup_tc   = tup_tc
360                                      , repr_ptup_tc  = ptup_tc
361                                      , repr_comp_tys = tys'
362                                      , repr_comps    = rs
363                                      }
364       where
365         arity = length tys
366     
367     comp_repr ty = liftM (Keep ty) (prDictOfType ty)
368                    `orElseV` return (Wrap ty)
369
370 sumReprType :: SumRepr -> VM Type
371 sumReprType EmptySum = voidType
372 sumReprType (UnarySum r) = conReprType r
373 sumReprType (Sum { repr_sum_tc  = sum_tc, repr_con_tys = tys })
374   = return $ mkTyConApp sum_tc tys
375
376 conReprType :: ConRepr -> VM Type
377 conReprType (ConRepr _ r) = prodReprType r
378
379 prodReprType :: ProdRepr -> VM Type
380 prodReprType EmptyProd = voidType
381 prodReprType (UnaryProd r) = compReprType r
382 prodReprType (Prod { repr_tup_tc = tup_tc, repr_comp_tys = tys })
383   = return $ mkTyConApp tup_tc tys
384
385 compReprType :: CompRepr -> VM Type
386 compReprType (Keep ty _) = return ty
387 compReprType (Wrap ty) = do
388                              wrap_tc <- builtin wrapTyCon
389                              return $ mkTyConApp wrap_tc [ty]
390
391 compOrigType :: CompRepr -> Type
392 compOrigType (Keep ty _) = ty
393 compOrigType (Wrap ty) = ty
394
395 buildToPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
396 buildToPRepr vect_tc repr_tc _ repr
397   = do
398       let arg_ty = mkTyConApp vect_tc ty_args
399       res_ty <- mkPReprType arg_ty
400       arg    <- newLocalVar (fsLit "x") arg_ty
401       result <- to_sum (Var arg) arg_ty res_ty repr
402       return $ Lam arg result
403   where
404     ty_args = mkTyVarTys (tyConTyVars vect_tc)
405
406     wrap_repr_inst = wrapFamInstBody repr_tc ty_args
407
408     to_sum _ _ _ EmptySum
409       = do
410           void <- builtin voidVar
411           return $ wrap_repr_inst $ Var void
412
413     to_sum arg arg_ty res_ty (UnarySum r)
414       = do
415           (pat, vars, body) <- con_alt r
416           return $ mkWildCase arg arg_ty res_ty
417                    [(pat, vars, wrap_repr_inst body)]
418
419     to_sum arg arg_ty res_ty (Sum { repr_sum_tc  = sum_tc
420                                   , repr_con_tys = tys
421                                   , repr_cons    =  cons })
422       = do
423           alts <- mapM con_alt cons
424           let alts' = [(pat, vars, wrap_repr_inst
425                                    $ mkConApp sum_con (map Type tys ++ [body]))
426                         | ((pat, vars, body), sum_con)
427                             <- zip alts (tyConDataCons sum_tc)]
428           return $ mkWildCase arg arg_ty res_ty alts'
429
430     con_alt (ConRepr con r)
431       = do
432           (vars, body) <- to_prod r
433           return (DataAlt con, vars, body)
434
435     to_prod EmptyProd
436       = do
437           void <- builtin voidVar
438           return ([], Var void)
439
440     to_prod (UnaryProd comp)
441       = do
442           var  <- newLocalVar (fsLit "x") (compOrigType comp)
443           body <- to_comp (Var var) comp
444           return ([var], body)
445
446     to_prod(Prod { repr_tup_tc   = tup_tc
447                  , repr_comp_tys = tys
448                  , repr_comps    = comps })
449       = do
450           vars  <- newLocalVars (fsLit "x") (map compOrigType comps)
451           exprs <- zipWithM to_comp (map Var vars) comps
452           return (vars, mkConApp tup_con (map Type tys ++ exprs))
453       where
454         [tup_con] = tyConDataCons tup_tc
455
456     to_comp expr (Keep _ _) = return expr
457     to_comp expr (Wrap ty)  = do
458                                 wrap_tc <- builtin wrapTyCon
459                                 return $ wrapNewTypeBody wrap_tc [ty] expr
460
461
462 buildFromPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
463 buildFromPRepr vect_tc repr_tc _ repr
464   = do
465       arg_ty <- mkPReprType res_ty
466       arg <- newLocalVar (fsLit "x") arg_ty
467
468       result <- from_sum (unwrapFamInstScrut repr_tc ty_args (Var arg))
469                          repr
470       return $ Lam arg result
471   where
472     ty_args = mkTyVarTys (tyConTyVars vect_tc)
473     res_ty  = mkTyConApp vect_tc ty_args
474
475     from_sum _ EmptySum
476       = do
477           dummy <- builtin fromVoidVar
478           return $ Var dummy `App` Type res_ty
479
480     from_sum expr (UnarySum r) = from_con expr r
481     from_sum expr (Sum { repr_sum_tc  = sum_tc
482                        , repr_con_tys = tys
483                        , repr_cons    = cons })
484       = do
485           vars  <- newLocalVars (fsLit "x") tys
486           es    <- zipWithM from_con (map Var vars) cons
487           return $ mkWildCase expr (exprType expr) res_ty
488                    [(DataAlt con, [var], e)
489                       | (con, var, e) <- zip3 (tyConDataCons sum_tc) vars es]
490
491     from_con expr (ConRepr con r)
492       = from_prod expr (mkConApp con $ map Type ty_args) r
493
494     from_prod _ con EmptyProd = return con
495     from_prod expr con (UnaryProd r)
496       = do
497           e <- from_comp expr r
498           return $ con `App` e
499      
500     from_prod expr con (Prod { repr_tup_tc   = tup_tc
501                              , repr_comp_tys = tys
502                              , repr_comps    = comps
503                              })
504       = do
505           vars <- newLocalVars (fsLit "y") tys
506           es   <- zipWithM from_comp (map Var vars) comps
507           return $ mkWildCase expr (exprType expr) res_ty
508                    [(DataAlt tup_con, vars, con `mkApps` es)]
509       where
510         [tup_con] = tyConDataCons tup_tc  
511
512     from_comp expr (Keep _ _) = return expr
513     from_comp expr (Wrap ty)
514       = do
515           wrap <- builtin wrapTyCon
516           return $ unwrapNewTypeBody wrap [ty] expr
517
518
519 buildToArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
520 buildToArrPRepr vect_tc prepr_tc pdata_tc r
521   = do
522       arg_ty <- mkPDataType el_ty
523       res_ty <- mkPDataType =<< mkPReprType el_ty
524       arg    <- newLocalVar (fsLit "xs") arg_ty
525
526       pdata_co <- mkBuiltinCo pdataTyCon
527       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
528           co           = mkAppCoercion pdata_co
529                        . mkSymCoercion
530                        $ mkTyConApp repr_co ty_args
531
532           scrut   = unwrapFamInstScrut pdata_tc ty_args (Var arg)
533
534       (vars, result) <- to_sum r
535
536       return . Lam arg
537              $ mkWildCase scrut (mkTyConApp pdata_tc ty_args) res_ty
538                [(DataAlt pdata_dc, vars, mkCoerce co result)]
539   where
540     ty_args = mkTyVarTys $ tyConTyVars vect_tc
541     el_ty   = mkTyConApp vect_tc ty_args
542
543     [pdata_dc] = tyConDataCons pdata_tc
544
545
546     to_sum EmptySum = do
547                         pvoid <- builtin pvoidVar
548                         return ([], Var pvoid)
549     to_sum (UnarySum r) = to_con r
550     to_sum (Sum { repr_psum_tc = psum_tc
551                 , repr_sel_ty  = sel_ty
552                 , repr_con_tys = tys
553                 , repr_cons    = cons
554                 })
555       = do
556           (vars, exprs) <- mapAndUnzipM to_con cons
557           sel <- newLocalVar (fsLit "sel") sel_ty
558           return (sel : concat vars, mk_result (Var sel) exprs)
559       where
560         [psum_con] = tyConDataCons psum_tc
561         mk_result sel exprs = wrapFamInstBody psum_tc tys
562                             $ mkConApp psum_con
563                             $ map Type tys ++ (sel : exprs)
564
565     to_con (ConRepr _ r) = to_prod r
566
567     to_prod EmptyProd = do
568                           pvoid <- builtin pvoidVar
569                           return ([], Var pvoid)
570     to_prod (UnaryProd r)
571       = do
572           pty  <- mkPDataType (compOrigType r)
573           var  <- newLocalVar (fsLit "x") pty
574           expr <- to_comp (Var var) r
575           return ([var], expr)
576
577     to_prod (Prod { repr_ptup_tc  = ptup_tc
578                   , repr_comp_tys = tys
579                   , repr_comps    = comps })
580       = do
581           ptys <- mapM (mkPDataType . compOrigType) comps
582           vars <- newLocalVars (fsLit "x") ptys
583           es   <- zipWithM to_comp (map Var vars) comps
584           return (vars, mk_result es)
585       where
586         [ptup_con] = tyConDataCons ptup_tc
587         mk_result exprs = wrapFamInstBody ptup_tc tys
588                         $ mkConApp ptup_con
589                         $ map Type tys ++ exprs
590
591     to_comp expr (Keep _ _) = return expr
592
593     -- FIXME: this is bound to be wrong!
594     to_comp expr (Wrap ty)
595       = do
596           wrap_tc  <- builtin wrapTyCon
597           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
598           return $ wrapNewTypeBody pwrap_tc [ty] expr
599
600
601 buildFromArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
602 buildFromArrPRepr vect_tc prepr_tc pdata_tc r
603   = do
604       arg_ty <- mkPDataType =<< mkPReprType el_ty
605       res_ty <- mkPDataType el_ty
606       arg    <- newLocalVar (fsLit "xs") arg_ty
607
608       pdata_co <- mkBuiltinCo pdataTyCon
609       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
610           co           = mkAppCoercion pdata_co
611                        $ mkTyConApp repr_co var_tys
612
613           scrut  = mkCoerce co (Var arg)
614
615           mk_result args = wrapFamInstBody pdata_tc var_tys
616                          $ mkConApp pdata_con
617                          $ map Type var_tys ++ args
618
619       (expr, _) <- fixV $ \ ~(_, args) ->
620                      from_sum res_ty (mk_result args) scrut r
621
622       return $ Lam arg expr
623     
624       -- (args, mk) <- from_sum res_ty scrut r
625       
626       -- let result = wrapFamInstBody pdata_tc var_tys
627       --           . mkConApp pdata_dc
628       --           $ map Type var_tys ++ args
629
630       -- return $ Lam arg (mk result)
631   where
632     var_tys = mkTyVarTys $ tyConTyVars vect_tc
633     el_ty   = mkTyConApp vect_tc var_tys
634
635     [pdata_con] = tyConDataCons pdata_tc
636
637     from_sum _ res _ EmptySum = return (res, [])
638     from_sum res_ty res expr (UnarySum r) = from_con res_ty res expr r
639     from_sum res_ty res expr (Sum { repr_psum_tc = psum_tc
640                                   , repr_sel_ty  = sel_ty
641                                   , repr_con_tys = tys
642                                   , repr_cons    = cons })
643       = do
644           sel  <- newLocalVar (fsLit "sel") sel_ty
645           ptys <- mapM mkPDataType tys
646           vars <- newLocalVars (fsLit "xs") ptys
647           (res', args) <- fold from_con res_ty res (map Var vars) cons
648           let scrut = unwrapFamInstScrut psum_tc tys expr
649               body  = mkWildCase scrut (exprType scrut) res_ty
650                       [(DataAlt psum_con, sel : vars, res')]
651           return (body, Var sel : args)
652       where
653         [psum_con] = tyConDataCons psum_tc
654
655
656     from_con res_ty res expr (ConRepr _ r) = from_prod res_ty res expr r
657
658     from_prod _ res _ EmptyProd = return (res, [])
659     from_prod res_ty res expr (UnaryProd r)
660       = from_comp res_ty res expr r
661     from_prod res_ty res expr (Prod { repr_ptup_tc  = ptup_tc
662                                     , repr_comp_tys = tys
663                                     , repr_comps    = comps })
664       = do
665           ptys <- mapM mkPDataType tys
666           vars <- newLocalVars (fsLit "ys") ptys
667           (res', args) <- fold from_comp res_ty res (map Var vars) comps
668           let scrut = unwrapFamInstScrut ptup_tc tys expr
669               body  = mkWildCase scrut (exprType scrut) res_ty
670                       [(DataAlt ptup_con, vars, res')]
671           return (body, args)
672       where
673         [ptup_con] = tyConDataCons ptup_tc
674
675     from_comp _ res expr (Keep _ _) = return (res, [expr])
676     from_comp _ res expr (Wrap ty)
677       = do
678           wrap_tc  <- builtin wrapTyCon
679           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
680           return (res, [unwrapNewTypeBody pwrap_tc [ty]
681                         $ unwrapFamInstScrut pwrap_tc [ty] expr])
682
683     fold f res_ty res exprs rs = foldrM f' (res, []) (zip exprs rs)
684       where
685         f' (expr, r) (res, args) = do
686                                      (res', args') <- f res_ty res expr r
687                                      return (res', args' ++ args)
688
689 buildPRDict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
690 buildPRDict vect_tc prepr_tc _ r
691   = do
692       dict <- sum_dict r
693       pr_co <- mkBuiltinCo prTyCon
694       let co = mkAppCoercion pr_co
695              . mkSymCoercion
696              $ mkTyConApp arg_co ty_args
697       return (mkCoerce co dict)
698   where
699     ty_args = mkTyVarTys (tyConTyVars vect_tc)
700     Just arg_co = tyConFamilyCoercion_maybe prepr_tc
701
702     sum_dict EmptySum = prDFunOfTyCon =<< builtin voidTyCon
703     sum_dict (UnarySum r) = con_dict r
704     sum_dict (Sum { repr_sum_tc  = sum_tc
705                   , repr_con_tys = tys
706                   , repr_cons    = cons
707                   })
708       = do
709           dicts <- mapM con_dict cons
710           dfun  <- prDFunOfTyCon sum_tc
711           return $ dfun `mkTyApps` tys `mkApps` dicts
712
713     con_dict (ConRepr _ r) = prod_dict r
714
715     prod_dict EmptyProd = prDFunOfTyCon =<< builtin voidTyCon
716     prod_dict (UnaryProd r) = comp_dict r
717     prod_dict (Prod { repr_tup_tc   = tup_tc
718                     , repr_comp_tys = tys
719                     , repr_comps    = comps })
720       = do
721           dicts <- mapM comp_dict comps
722           dfun <- prDFunOfTyCon tup_tc
723           return $ dfun `mkTyApps` tys `mkApps` dicts
724
725     comp_dict (Keep _ pr) = return pr
726     comp_dict (Wrap ty)   = wrapPR ty
727
728
729 buildPDataTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
730 buildPDataTyCon orig_tc vect_tc repr = fixV $ \repr_tc ->
731   do
732     name' <- cloneName mkPDataTyConOcc orig_name
733     rhs   <- buildPDataTyConRhs orig_name vect_tc repr_tc repr
734     pdata <- builtin pdataTyCon
735
736     liftDs $ buildAlgTyCon name'
737                            tyvars
738                            []          -- no stupid theta
739                            rhs
740                            rec_flag    -- FIXME: is this ok?
741                            False       -- FIXME: no generics
742                            False       -- not GADT syntax
743                            (Just $ mk_fam_inst pdata vect_tc)
744   where
745     orig_name = tyConName orig_tc
746     tyvars = tyConTyVars vect_tc
747     rec_flag = boolToRecFlag (isRecursiveTyCon vect_tc)
748
749
750 buildPDataTyConRhs :: Name -> TyCon -> TyCon -> SumRepr -> VM AlgTyConRhs
751 buildPDataTyConRhs orig_name vect_tc repr_tc repr
752   = do
753       data_con <- buildPDataDataCon orig_name vect_tc repr_tc repr
754       return $ DataTyCon { data_cons = [data_con], is_enum = False }
755
756 buildPDataDataCon :: Name -> TyCon -> TyCon -> SumRepr -> VM DataCon
757 buildPDataDataCon orig_name vect_tc repr_tc repr
758   = do
759       dc_name  <- cloneName mkPDataDataConOcc orig_name
760       comp_tys <- sum_tys repr
761
762       liftDs $ buildDataCon dc_name
763                             False                  -- not infix
764                             (map (const HsNoBang) comp_tys)
765                             []                     -- no field labels
766                             tvs
767                             []                     -- no existentials
768                             []                     -- no eq spec
769                             []                     -- no context
770                             comp_tys
771                             (mkFamilyTyConApp repr_tc (mkTyVarTys tvs))
772                             repr_tc
773   where
774     tvs   = tyConTyVars vect_tc
775
776     sum_tys EmptySum = return []
777     sum_tys (UnarySum r) = con_tys r
778     sum_tys (Sum { repr_sel_ty = sel_ty
779                  , repr_cons   = cons })
780       = liftM (sel_ty :) (concatMapM con_tys cons)
781
782     con_tys (ConRepr _ r) = prod_tys r
783
784     prod_tys EmptyProd = return []
785     prod_tys (UnaryProd r) = liftM singleton (comp_ty r)
786     prod_tys (Prod { repr_comps = comps }) = mapM comp_ty comps
787
788     comp_ty r = mkPDataType (compOrigType r)
789
790
791 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr 
792                    -> VM Var
793 buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr
794   = do
795       vectDataConWorkers orig_tc vect_tc pdata_tc
796       buildPADict vect_tc prepr_tc pdata_tc repr
797
798 vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
799 vectDataConWorkers orig_tc vect_tc arr_tc
800   = do
801       bs <- sequence
802           . zipWith3 def_worker  (tyConDataCons orig_tc) rep_tys
803           $ zipWith4 mk_data_con (tyConDataCons vect_tc)
804                                  rep_tys
805                                  (inits rep_tys)
806                                  (tail $ tails rep_tys)
807       mapM_ (uncurry hoistBinding) bs
808   where
809     tyvars   = tyConTyVars vect_tc
810     var_tys  = mkTyVarTys tyvars
811     ty_args  = map Type var_tys
812     res_ty   = mkTyConApp vect_tc var_tys
813
814     cons     = tyConDataCons vect_tc
815     arity    = length cons
816     [arr_dc] = tyConDataCons arr_tc
817
818     rep_tys  = map dataConRepArgTys $ tyConDataCons vect_tc
819
820
821     mk_data_con con tys pre post
822       = liftM2 (,) (vect_data_con con)
823                    (lift_data_con tys pre post (mkDataConTag con))
824
825     sel_replicate len tag
826       | arity > 1 = do
827                       rep <- builtin (selReplicate arity)
828                       return [rep `mkApps` [len, tag]]
829
830       | otherwise = return []
831
832     vect_data_con con = return $ mkConApp con ty_args
833     lift_data_con tys pre_tys post_tys tag
834       = do
835           len  <- builtin liftingContext
836           args <- mapM (newLocalVar (fsLit "xs"))
837                   =<< mapM mkPDataType tys
838
839           sel  <- sel_replicate (Var len) tag
840
841           pre   <- mapM emptyPD (concat pre_tys)
842           post  <- mapM emptyPD (concat post_tys)
843
844           return . mkLams (len : args)
845                  . wrapFamInstBody arr_tc var_tys
846                  . mkConApp arr_dc
847                  $ ty_args ++ sel ++ pre ++ map Var args ++ post
848
849     def_worker data_con arg_tys mk_body
850       = do
851           arity <- polyArity tyvars
852           body <- closedV
853                 . inBind orig_worker
854                 . polyAbstract tyvars $ \args ->
855                   liftM (mkLams (tyvars ++ args) . vectorised)
856                 $ buildClosures tyvars [] arg_tys res_ty mk_body
857
858           raw_worker <- cloneId mkVectOcc orig_worker (exprType body)
859           let vect_worker = raw_worker `setIdUnfolding`
860                               mkInlineRule body (Just arity)
861           defGlobalVar orig_worker vect_worker
862           return (vect_worker, body)
863       where
864         orig_worker = dataConWorkId data_con
865
866 buildPADict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
867 buildPADict vect_tc prepr_tc arr_tc repr
868   = polyAbstract tvs $ \args ->
869     do
870       method_ids <- mapM (method args) paMethods
871
872       pa_tc  <- builtin paTyCon
873       pa_dc  <- builtin paDataCon
874       let dict = mkLams (tvs ++ args)
875                $ mkConApp pa_dc
876                $ Type inst_ty : map (method_call args) method_ids
877
878           dfun_ty = mkForAllTys tvs
879                   $ mkFunTys (map varType args) (mkTyConApp pa_tc [inst_ty])
880
881       raw_dfun <- newExportedVar dfun_name dfun_ty
882       let dfun = raw_dfun `setIdUnfolding` mkDFunUnfolding dfun_ty (map Var method_ids)
883                           `setInlinePragma` dfunInlinePragma
884
885       hoistBinding dfun dict
886       return dfun
887   where
888     tvs = tyConTyVars vect_tc
889     arg_tys = mkTyVarTys tvs
890     inst_ty = mkTyConApp vect_tc arg_tys
891
892     dfun_name = mkPADFunOcc (getOccName vect_tc)
893
894     method args (name, build)
895       = localV
896       $ do
897           expr <- build vect_tc prepr_tc arr_tc repr
898           let body = mkLams (tvs ++ args) expr
899           raw_var <- newExportedVar (method_name name) (exprType body)
900           let var = raw_var
901                       `setIdUnfolding` mkInlineRule body (Just (length args))
902                       `setInlinePragma` alwaysInlinePragma
903           hoistBinding var body
904           return var
905
906     method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args)
907
908     method_name name = mkVarOcc $ occNameString dfun_name ++ ('$' : name)
909
910
911 paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
912 paMethods = [("dictPRepr",    buildPRDict),
913              ("toPRepr",      buildToPRepr),
914              ("fromPRepr",    buildFromPRepr),
915              ("toArrPRepr",   buildToArrPRepr),
916              ("fromArrPRepr", buildFromArrPRepr)]
917
918
919 -- | Split the given tycons into two sets depending on whether they have to be
920 --   converted (first list) or not (second list). The first argument contains
921 --   information about the conversion status of external tycons:
922 --
923 --   * tycons which have converted versions are mapped to True
924 --   * tycons which are not changed by vectorisation are mapped to False
925 --   * tycons which can't be converted are not elements of the map
926 --
927 classifyTyCons :: UniqFM Bool -> [TyConGroup] -> ([TyCon], [TyCon])
928 classifyTyCons = classify [] []
929   where
930     classify conv keep _  [] = (conv, keep)
931     classify conv keep cs ((tcs, ds) : rs)
932       | can_convert && must_convert
933         = classify (tcs ++ conv) keep (cs `addListToUFM` [(tc,True) | tc <- tcs]) rs
934       | can_convert
935         = classify conv (tcs ++ keep) (cs `addListToUFM` [(tc,False) | tc <- tcs]) rs
936       | otherwise
937         = classify conv keep cs rs
938       where
939         refs = ds `delListFromUniqSet` tcs
940
941         can_convert  = isNullUFM (refs `minusUFM` cs) && all convertable tcs
942         must_convert = foldUFM (||) False (intersectUFM_C const cs refs)
943
944         convertable tc = isDataTyCon tc && all isVanillaDataCon (tyConDataCons tc)
945
946 -- | Compute mutually recursive groups of tycons in topological order
947 --
948 tyConGroups :: [TyCon] -> [TyConGroup]
949 tyConGroups tcs = map mk_grp (stronglyConnCompFromEdgedVertices edges)
950   where
951     edges = [((tc, ds), tc, uniqSetToList ds) | tc <- tcs
952                                 , let ds = tyConsOfTyCon tc]
953
954     mk_grp (AcyclicSCC (tc, ds)) = ([tc], ds)
955     mk_grp (CyclicSCC els)       = (tcs, unionManyUniqSets dss)
956       where
957         (tcs, dss) = unzip els
958
959 tyConsOfTyCon :: TyCon -> UniqSet TyCon
960 tyConsOfTyCon
961   = tyConsOfTypes . concatMap dataConRepArgTys . tyConDataCons
962
963 tyConsOfType :: Type -> UniqSet TyCon
964 tyConsOfType ty
965   | Just ty' <- coreView ty    = tyConsOfType ty'
966 tyConsOfType (TyVarTy _)       = emptyUniqSet
967 tyConsOfType (TyConApp tc tys) = extend (tyConsOfTypes tys)
968   where
969     extend | isUnLiftedTyCon tc
970            || isTupleTyCon   tc = id
971
972            | otherwise          = (`addOneToUniqSet` tc)
973
974 tyConsOfType (AppTy a b)       = tyConsOfType a `unionUniqSets` tyConsOfType b
975 tyConsOfType (FunTy a b)       = (tyConsOfType a `unionUniqSets` tyConsOfType b)
976                                  `addOneToUniqSet` funTyCon
977 tyConsOfType (ForAllTy _ ty)   = tyConsOfType ty
978 tyConsOfType other             = pprPanic "ClosureConv.tyConsOfType" $ ppr other
979
980 tyConsOfTypes :: [Type] -> UniqSet TyCon
981 tyConsOfTypes = unionManyUniqSets . map tyConsOfType
982
983
984 -- ----------------------------------------------------------------------------
985 -- Conversions
986
987 -- | Build an expression that calls the vectorised version of some 
988 --   function from a `Closure`.
989 --
990 --   For example
991 --   @   
992 --      \(x :: Double) -> 
993 --      \(y :: Double) -> 
994 --      ($v_foo $: x) $: y
995 --   @
996 --
997 --   We use the type of the original binding to work out how many
998 --   outer lambdas to add.
999 --
1000 fromVect 
1001         :: Type         -- ^ The type of the original binding.
1002         -> CoreExpr     -- ^ Expression giving the closure to use, eg @$v_foo@.
1003         -> VM CoreExpr
1004         
1005 -- Convert the type to the core view if it isn't already.
1006 fromVect ty expr 
1007         | Just ty' <- coreView ty 
1008         = fromVect ty' expr
1009
1010 -- For each function constructor in the original type we add an outer 
1011 -- lambda to bind the parameter variable, and an inner application of it.
1012 fromVect (FunTy arg_ty res_ty) expr
1013   = do
1014       arg     <- newLocalVar (fsLit "x") arg_ty
1015       varg    <- toVect arg_ty (Var arg)
1016       varg_ty <- vectType arg_ty
1017       vres_ty <- vectType res_ty
1018       apply   <- builtin applyVar
1019       body    <- fromVect res_ty
1020                $ Var apply `mkTyApps` [varg_ty, vres_ty] `mkApps` [expr, varg]
1021       return $ Lam arg body
1022
1023 -- If the type isn't a function then it's time to call on the closure.
1024 fromVect ty expr
1025   = identityConv ty >> return expr
1026
1027
1028 toVect :: Type -> CoreExpr -> VM CoreExpr
1029 toVect ty expr = identityConv ty >> return expr
1030
1031
1032 identityConv :: Type -> VM ()
1033 identityConv ty | Just ty' <- coreView ty = identityConv ty'
1034 identityConv (TyConApp tycon tys)
1035   = do
1036       mapM_ identityConv tys
1037       identityConvTyCon tycon
1038 identityConv _ = noV
1039
1040 identityConvTyCon :: TyCon -> VM ()
1041 identityConvTyCon tc
1042   | isBoxedTupleTyCon tc = return ()
1043   | isUnLiftedTyCon tc   = return ()
1044   | otherwise            = do
1045                              tc' <- maybeV (lookupTyCon tc)
1046                              if tc == tc' then return () else noV
1047