Vectorisation of method types
[ghc-hetmet.git] / compiler / vectorise / VectType.hs
1 {-# OPTIONS -fno-warn-missing-signatures #-}
2
3 module VectType ( vectTyCon, vectAndLiftType, vectType, vectTypeEnv,
4                   -- arrSumArity, pdataCompTys, pdataCompVars,
5                   buildPADict,
6                   fromVect )
7 where
8
9 import VectMonad
10 import VectUtils
11 import VectCore
12
13 import HscTypes          ( TypeEnv, extendTypeEnvList, typeEnvTyCons )
14 import BasicTypes
15 import CoreSyn
16 import CoreUtils
17 import CoreUnfold
18 import MkCore            ( mkWildCase )
19 import BuildTyCl
20 import DataCon
21 import TyCon
22 import Class
23 import Type
24 import TypeRep
25 import Coercion
26 import FamInstEnv        ( FamInst, mkLocalFamInst )
27 import OccName
28 import Id
29 import MkId
30 import Var               ( Var, TyVar, varType, varName )
31 import Name              ( Name, getOccName )
32 import NameEnv
33
34 import Unique
35 import UniqFM
36 import UniqSet
37 import Util
38 import Digraph           ( SCC(..), stronglyConnCompFromEdgedVertices )
39
40 import Outputable
41 import FastString
42
43 import MonadUtils     ( zipWith3M, foldrM, concatMapM )
44 import Control.Monad  ( liftM, liftM2, zipWithM, zipWithM_, mapAndUnzipM )
45 import Data.List
46 import Data.Maybe
47
48 debug           = False
49 dtrace s x      = if debug then pprTrace "VectType" s x else x
50
51 -- ----------------------------------------------------------------------------
52 -- Types
53
54 -- | Vectorise a type constructor.
55 vectTyCon :: TyCon -> VM TyCon
56 vectTyCon tc
57   | isFunTyCon tc        = builtin closureTyCon
58   | isBoxedTupleTyCon tc = return tc
59   | isUnLiftedTyCon tc   = return tc
60   | otherwise            
61   = maybeCantVectoriseM "Tycon not vectorised: " (ppr tc)
62         $ lookupTyCon tc
63
64
65 vectAndLiftType :: Type -> VM (Type, Type)
66 vectAndLiftType ty | Just ty' <- coreView ty = vectAndLiftType ty'
67 vectAndLiftType ty
68   = do
69       mdicts   <- mapM paDictArgType tyvars
70       let dicts = [dict | Just dict <- mdicts]
71       vmono_ty <- vectType mono_ty
72       lmono_ty <- mkPDataType vmono_ty
73       return (abstractType tyvars dicts vmono_ty,
74               abstractType tyvars dicts lmono_ty)
75   where
76     (tyvars, mono_ty) = splitForAllTys ty
77
78
79 -- | Vectorise a type.
80 vectType :: Type -> VM Type
81 vectType ty
82         | Just ty'      <- coreView ty
83         = vectType ty'
84         
85 vectType (TyVarTy tv)           = return $ TyVarTy tv
86 vectType (AppTy ty1 ty2)        = liftM2 AppTy    (vectType ty1) (vectType ty2)
87 vectType (TyConApp tc tys)      = liftM2 TyConApp (vectTyCon tc) (mapM vectType tys)
88 vectType (FunTy ty1 ty2)        = liftM2 TyConApp (builtin closureTyCon)
89                                                   (mapM vectAndBoxType [ty1,ty2])
90
91 -- For each quantified var we need to add a PA dictionary out the front of the type.
92 -- So          forall a. C  a => a -> a   
93 -- turns into  forall a. Cv a => PA a => a :-> a
94 vectType ty@(ForAllTy _ _)
95  = do
96       -- split the type into the quantified vars, its dictionaries and the body.
97       let (tyvars, tyBody)   = splitForAllTys ty
98       let (tyArgs, tyResult) = splitFunTys    tyBody
99
100       let (tyArgs_dict, tyArgs_regular) 
101                   = partition isDictType tyArgs
102
103       -- vectorise the body.
104       let tyBody' = mkFunTys tyArgs_regular tyResult
105       tyBody''    <- vectType tyBody'
106
107       -- vectorise the dictionary parameters.
108       dictsVect   <- mapM vectType tyArgs_dict
109
110       -- make a PA dictionary for each of the type variables.
111       dictsPA     <- liftM catMaybes $ mapM paDictArgType tyvars
112
113       -- pack it all back together.
114       return $ abstractType tyvars (dictsVect ++ dictsPA) tyBody''
115
116 vectType ty = cantVectorise "Can't vectorise type" (ppr ty)
117
118
119 -- | Add quantified vars and dictionary parameters to the front of a type.
120 abstractType :: [TyVar] -> [Type] -> Type -> Type
121 abstractType tyvars dicts = mkForAllTys tyvars . mkFunTys dicts
122
123
124 -- | Check if some type is a type class dictionary.
125 isDictType :: Type -> Bool
126 isDictType ty
127  = case splitTyConApp_maybe ty of
128         Just (tyCon, _)         -> isClassTyCon tyCon
129         _                       -> False
130
131         
132 -- ----------------------------------------------------------------------------
133 -- Boxing
134
135 boxType :: Type -> VM Type
136 boxType ty
137   | Just (tycon, []) <- splitTyConApp_maybe ty
138   , isUnLiftedTyCon tycon
139   = do
140       r <- lookupBoxedTyCon tycon
141       case r of
142         Just tycon' -> return $ mkTyConApp tycon' []
143         Nothing     -> return ty
144
145 boxType ty = return ty
146
147 vectAndBoxType :: Type -> VM Type
148 vectAndBoxType ty = vectType ty >>= boxType
149
150
151 -- ----------------------------------------------------------------------------
152 -- Type definitions
153
154 type TyConGroup = ([TyCon], UniqSet TyCon)
155
156 -- | Vectorise a type environment.
157 --   The type environment contains all the type things defined in a module.
158 vectTypeEnv :: TypeEnv -> VM (TypeEnv, [FamInst], [(Var, CoreExpr)])
159 vectTypeEnv env
160  = dtrace (ppr env)
161  $ do
162       cs <- readGEnv $ mk_map . global_tycons
163
164       -- Split the list of TyCons into the ones we have to vectorise vs the
165       -- ones we can pass through unchanged. We also pass through algebraic 
166       -- types that use non Haskell98 features, as we don't handle those.
167       let (conv_tcs, keep_tcs) = classifyTyCons cs groups
168           keep_dcs             = concatMap tyConDataCons keep_tcs
169
170       dtrace (text "conv_tcs = " <> ppr conv_tcs) $ return ()
171
172       zipWithM_ defTyCon   keep_tcs keep_tcs
173       zipWithM_ defDataCon keep_dcs keep_dcs
174
175       new_tcs <- vectTyConDecls conv_tcs
176
177       dtrace (text "new_tcs = " <> ppr new_tcs) $ return ()
178
179       let orig_tcs = keep_tcs ++ conv_tcs
180
181       -- We don't need to make new representation types for dictionary
182       -- constructors. The constructors are always fully applied, and we don't 
183       -- need to lift them to arrays as a dictionary of a particular type
184       -- always has the same value.
185       let vect_tcs = filter (not . isClassTyCon) 
186                    $ keep_tcs ++ new_tcs
187
188       dtrace (text "vect_tcs = " <> ppr vect_tcs) $ return ()
189
190       mapM_ dumpTycon $ new_tcs
191
192
193       (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
194         do
195           defTyConPAs (zipLazy vect_tcs dfuns')
196           reprs     <- mapM tyConRepr vect_tcs
197           repr_tcs  <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
198           pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
199
200           dfuns     <- sequence 
201                     $  zipWith5 buildTyConBindings
202                                orig_tcs
203                                vect_tcs
204                                repr_tcs
205                                pdata_tcs
206                                reprs
207
208           binds     <- takeHoisted
209           return (dfuns, binds, repr_tcs ++ pdata_tcs)
210
211       let all_new_tcs = new_tcs ++ inst_tcs
212
213       let new_env = extendTypeEnvList env
214                        (map ATyCon all_new_tcs
215                         ++ [ADataCon dc | tc <- all_new_tcs
216                                         , dc <- tyConDataCons tc])
217
218       return (new_env, map mkLocalFamInst inst_tcs, binds)
219   where
220     tycons = typeEnvTyCons env
221     groups = tyConGroups tycons
222
223     mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
224
225
226 -- | Vectorise some (possibly recursively defined) type constructors.
227 vectTyConDecls :: [TyCon] -> VM [TyCon]
228 vectTyConDecls tcs = fixV $ \tcs' ->
229   do
230     mapM_ (uncurry defTyCon) (zipLazy tcs tcs')
231     mapM vectTyConDecl tcs
232
233 dumpTycon :: TyCon -> VM ()
234 dumpTycon tycon
235         | Just cls      <- tyConClass_maybe tycon
236         = dtrace (vcat  [ ppr tycon
237                         , ppr [(m, varType m) | m <- classMethods cls ]])
238         $ return ()
239                 
240         | otherwise
241         = return ()
242
243
244 -- | Vectorise a single type construcrtor.
245 vectTyConDecl :: TyCon -> VM TyCon
246 vectTyConDecl tycon
247     -- a type class constructor.
248     -- TODO: check for no stupid theta, fds, assoc types. 
249     | isClassTyCon tycon
250     , Just cls          <- tyConClass_maybe tycon
251
252     = do    -- make the name of the vectorised class tycon.
253             name'       <- cloneName mkVectTyConOcc (tyConName tycon)
254
255             -- vectorise right of definition.
256             rhs'        <- vectAlgTyConRhs tycon (algTyConRhs tycon)
257
258             -- vectorise method selectors.
259             -- This also adds a mapping between the original and vectorised method selector
260             -- to the state.
261             methods'    <- mapM vectMethod
262                         $  [(id, defMethSpecOfDefMeth meth) 
263                                 | (id, meth)    <- classOpItems cls]
264
265             -- keep the original recursiveness flag.
266             let rec_flag = boolToRecFlag (isRecursiveTyCon tycon)
267         
268             -- Calling buildclass here attaches new quantifiers and dictionaries to the method types.
269             cls'     <- liftDs 
270                     $  buildClass
271                              False               -- include unfoldings on dictionary selectors.
272                              name'               -- new name  V_T:Class
273                              (tyConTyVars tycon) -- keep original type vars
274                              []                  -- no stupid theta
275                              []                  -- no functional dependencies
276                              []                  -- no associated types
277                              methods'            -- method info
278                              rec_flag            -- whether recursive
279
280             let tycon'  = mkClassTyCon name'
281                              (tyConKind tycon)
282                              (tyConTyVars tycon)
283                              rhs'
284                              cls'
285                              rec_flag
286
287             return $ tycon'
288                         
289     -- a regular algebraic type constructor.
290     -- TODO: check for stupid theta, generaics, GADTS etc
291     | isAlgTyCon tycon
292     = do    name'       <- cloneName mkVectTyConOcc (tyConName tycon)
293             rhs'        <- vectAlgTyConRhs tycon (algTyConRhs tycon)
294             let rec_flag =  boolToRecFlag (isRecursiveTyCon tycon)
295
296             liftDs $ buildAlgTyCon 
297                             name'               -- new name
298                             (tyConTyVars tycon) -- keep original type vars.
299                             []                  -- no stupid theta.
300                             rhs'                -- new constructor defs.
301                             rec_flag            -- FIXME: is this ok?
302                             False               -- FIXME: no generics
303                             False               -- not GADT syntax
304                             Nothing             -- not a family instance
305
306     -- some other crazy thing that we don't handle.
307     | otherwise
308     = cantVectorise "Can't vectorise type constructor: " (ppr tycon)
309
310
311 -- | Vectorise a class method.
312 vectMethod :: (Id, DefMethSpec) -> VM (Name, DefMethSpec, Type)
313 vectMethod (id, defMeth)
314  = do   
315         -- Vectorise the method type.
316         typ'    <- vectType (varType id)
317
318         -- Create a name for the vectorised method.
319         id'     <- cloneId mkVectOcc id typ'
320         defGlobalVar id id'
321
322         -- When we call buildClass in vectTyConDecl, it adds foralls and dictionaries
323         -- to the types of each method. However, the types we get back from vectType
324         -- above already already have these, so we need to chop them off here otherwise
325         -- we'll get two copies in the final version.
326         let (_tyvars, tyBody) = splitForAllTys typ'
327         let (_dict,   tyRest) = splitFunTy tyBody
328
329         return  (Var.varName id', defMeth, tyRest)
330
331
332 -- | Vectorise the RHS of an algebraic type.
333 vectAlgTyConRhs :: TyCon -> AlgTyConRhs -> VM AlgTyConRhs
334 vectAlgTyConRhs _ (DataTyCon { data_cons = data_cons
335                              , is_enum   = is_enum
336                              })
337   = do
338       data_cons' <- mapM vectDataCon data_cons
339       zipWithM_ defDataCon data_cons data_cons'
340       return $ DataTyCon { data_cons = data_cons'
341                          , is_enum   = is_enum
342                          }
343
344 vectAlgTyConRhs tc _ 
345         = cantVectorise "Can't vectorise type definition:" (ppr tc)
346
347
348 -- | Vectorise a data constructor.
349 --   Vectorises its argument and return types.
350 vectDataCon :: DataCon -> VM DataCon
351 vectDataCon dc
352   | not . null $ dataConExTyVars dc
353   = cantVectorise "Can't vectorise constructor (existentials):" (ppr dc)
354
355   | not . null $ dataConEqSpec   dc
356   = cantVectorise "Can't vectorise constructor (eq spec):" (ppr dc)
357
358   | otherwise
359   = do
360       name'    <- cloneName mkVectDataConOcc name
361       tycon'   <- vectTyCon tycon
362       arg_tys  <- mapM vectType rep_arg_tys
363
364       liftDs $ buildDataCon 
365                 name'
366                 False                          -- not infix
367                 (map (const HsNoBang) arg_tys) -- strictness annots on args.
368                 []                             -- no labelled fields
369                 univ_tvs                       -- universally quantified vars
370                 []                             -- no existential tvs for now
371                 []                             -- no eq spec for now
372                 []                             -- no context
373                 arg_tys                        -- argument types
374                 (mkFamilyTyConApp tycon' (mkTyVarTys univ_tvs)) -- return type
375                 tycon'                         -- representation tycon
376   where
377     name        = dataConName dc
378     univ_tvs    = dataConUnivTyVars dc
379     rep_arg_tys = dataConRepArgTys dc
380     tycon       = dataConTyCon dc
381
382 mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
383 mk_fam_inst fam_tc arg_tc
384   = (fam_tc, [mkTyConApp arg_tc . mkTyVarTys $ tyConTyVars arg_tc])
385
386
387 buildPReprTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
388 buildPReprTyCon orig_tc vect_tc repr
389   = do
390       name     <- cloneName mkPReprTyConOcc (tyConName orig_tc)
391       -- rhs_ty   <- buildPReprType vect_tc
392       rhs_ty   <- sumReprType repr
393       prepr_tc <- builtin preprTyCon
394       liftDs $ buildSynTyCon name
395                              tyvars
396                              (SynonymTyCon rhs_ty)
397                              (typeKind rhs_ty)
398                              (Just $ mk_fam_inst prepr_tc vect_tc)
399   where
400     tyvars = tyConTyVars vect_tc
401
402 data CompRepr = Keep Type
403                      CoreExpr     -- PR dictionary for the type
404               | Wrap Type
405
406 data ProdRepr = EmptyProd
407               | UnaryProd CompRepr
408               | Prod { repr_tup_tc   :: TyCon  -- representation tuple tycon
409                      , repr_ptup_tc  :: TyCon  -- PData representation tycon
410                      , repr_comp_tys :: [Type] -- representation types of
411                      , repr_comps    :: [CompRepr]          -- components
412                      }
413 data ConRepr  = ConRepr DataCon ProdRepr
414
415 data SumRepr  = EmptySum
416               | UnarySum ConRepr
417               | Sum  { repr_sum_tc   :: TyCon  -- representation sum tycon
418                      , repr_psum_tc  :: TyCon  -- PData representation tycon
419                      , repr_sel_ty   :: Type   -- type of selector
420                      , repr_con_tys :: [Type]  -- representation types of
421                      , repr_cons     :: [ConRepr]           -- components
422                      }
423
424 tyConRepr :: TyCon -> VM SumRepr
425 tyConRepr tc = sum_repr (tyConDataCons tc)
426   where
427     sum_repr []    = return EmptySum
428     sum_repr [con] = liftM UnarySum (con_repr con)
429     sum_repr cons  = do
430                        rs     <- mapM con_repr cons
431                        sum_tc <- builtin (sumTyCon arity)
432                        tys    <- mapM conReprType rs
433                        (psum_tc, _) <- pdataReprTyCon (mkTyConApp sum_tc tys)
434                        sel_ty <- builtin (selTy arity)
435                        return $ Sum { repr_sum_tc  = sum_tc
436                                     , repr_psum_tc = psum_tc
437                                     , repr_sel_ty  = sel_ty
438                                     , repr_con_tys = tys
439                                     , repr_cons    = rs
440                                     }
441       where
442         arity = length cons
443
444     con_repr con = liftM (ConRepr con) (prod_repr (dataConRepArgTys con))
445
446     prod_repr []   = return EmptyProd
447     prod_repr [ty] = liftM UnaryProd (comp_repr ty)
448     prod_repr tys  = do
449                        rs <- mapM comp_repr tys
450                        tup_tc <- builtin (prodTyCon arity)
451                        tys'    <- mapM compReprType rs
452                        (ptup_tc, _) <- pdataReprTyCon (mkTyConApp tup_tc tys')
453                        return $ Prod { repr_tup_tc   = tup_tc
454                                      , repr_ptup_tc  = ptup_tc
455                                      , repr_comp_tys = tys'
456                                      , repr_comps    = rs
457                                      }
458       where
459         arity = length tys
460     
461     comp_repr ty = liftM (Keep ty) (prDictOfType ty)
462                    `orElseV` return (Wrap ty)
463
464 sumReprType :: SumRepr -> VM Type
465 sumReprType EmptySum = voidType
466 sumReprType (UnarySum r) = conReprType r
467 sumReprType (Sum { repr_sum_tc  = sum_tc, repr_con_tys = tys })
468   = return $ mkTyConApp sum_tc tys
469
470 conReprType :: ConRepr -> VM Type
471 conReprType (ConRepr _ r) = prodReprType r
472
473 prodReprType :: ProdRepr -> VM Type
474 prodReprType EmptyProd = voidType
475 prodReprType (UnaryProd r) = compReprType r
476 prodReprType (Prod { repr_tup_tc = tup_tc, repr_comp_tys = tys })
477   = return $ mkTyConApp tup_tc tys
478
479 compReprType :: CompRepr -> VM Type
480 compReprType (Keep ty _) = return ty
481 compReprType (Wrap ty) = do
482                              wrap_tc <- builtin wrapTyCon
483                              return $ mkTyConApp wrap_tc [ty]
484
485 compOrigType :: CompRepr -> Type
486 compOrigType (Keep ty _) = ty
487 compOrigType (Wrap ty) = ty
488
489 buildToPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
490 buildToPRepr vect_tc repr_tc _ repr
491   = do
492       let arg_ty = mkTyConApp vect_tc ty_args
493       res_ty <- mkPReprType arg_ty
494       arg    <- newLocalVar (fsLit "x") arg_ty
495       result <- to_sum (Var arg) arg_ty res_ty repr
496       return $ Lam arg result
497   where
498     ty_args = mkTyVarTys (tyConTyVars vect_tc)
499
500     wrap_repr_inst = wrapFamInstBody repr_tc ty_args
501
502     to_sum _ _ _ EmptySum
503       = do
504           void <- builtin voidVar
505           return $ wrap_repr_inst $ Var void
506
507     to_sum arg arg_ty res_ty (UnarySum r)
508       = do
509           (pat, vars, body) <- con_alt r
510           return $ mkWildCase arg arg_ty res_ty
511                    [(pat, vars, wrap_repr_inst body)]
512
513     to_sum arg arg_ty res_ty (Sum { repr_sum_tc  = sum_tc
514                                   , repr_con_tys = tys
515                                   , repr_cons    =  cons })
516       = do
517           alts <- mapM con_alt cons
518           let alts' = [(pat, vars, wrap_repr_inst
519                                    $ mkConApp sum_con (map Type tys ++ [body]))
520                         | ((pat, vars, body), sum_con)
521                             <- zip alts (tyConDataCons sum_tc)]
522           return $ mkWildCase arg arg_ty res_ty alts'
523
524     con_alt (ConRepr con r)
525       = do
526           (vars, body) <- to_prod r
527           return (DataAlt con, vars, body)
528
529     to_prod EmptyProd
530       = do
531           void <- builtin voidVar
532           return ([], Var void)
533
534     to_prod (UnaryProd comp)
535       = do
536           var  <- newLocalVar (fsLit "x") (compOrigType comp)
537           body <- to_comp (Var var) comp
538           return ([var], body)
539
540     to_prod(Prod { repr_tup_tc   = tup_tc
541                  , repr_comp_tys = tys
542                  , repr_comps    = comps })
543       = do
544           vars  <- newLocalVars (fsLit "x") (map compOrigType comps)
545           exprs <- zipWithM to_comp (map Var vars) comps
546           return (vars, mkConApp tup_con (map Type tys ++ exprs))
547       where
548         [tup_con] = tyConDataCons tup_tc
549
550     to_comp expr (Keep _ _) = return expr
551     to_comp expr (Wrap ty)  = do
552                                 wrap_tc <- builtin wrapTyCon
553                                 return $ wrapNewTypeBody wrap_tc [ty] expr
554
555
556 buildFromPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
557 buildFromPRepr vect_tc repr_tc _ repr
558   = do
559       arg_ty <- mkPReprType res_ty
560       arg <- newLocalVar (fsLit "x") arg_ty
561
562       result <- from_sum (unwrapFamInstScrut repr_tc ty_args (Var arg))
563                          repr
564       return $ Lam arg result
565   where
566     ty_args = mkTyVarTys (tyConTyVars vect_tc)
567     res_ty  = mkTyConApp vect_tc ty_args
568
569     from_sum _ EmptySum
570       = do
571           dummy <- builtin fromVoidVar
572           return $ Var dummy `App` Type res_ty
573
574     from_sum expr (UnarySum r) = from_con expr r
575     from_sum expr (Sum { repr_sum_tc  = sum_tc
576                        , repr_con_tys = tys
577                        , repr_cons    = cons })
578       = do
579           vars  <- newLocalVars (fsLit "x") tys
580           es    <- zipWithM from_con (map Var vars) cons
581           return $ mkWildCase expr (exprType expr) res_ty
582                    [(DataAlt con, [var], e)
583                       | (con, var, e) <- zip3 (tyConDataCons sum_tc) vars es]
584
585     from_con expr (ConRepr con r)
586       = from_prod expr (mkConApp con $ map Type ty_args) r
587
588     from_prod _ con EmptyProd = return con
589     from_prod expr con (UnaryProd r)
590       = do
591           e <- from_comp expr r
592           return $ con `App` e
593      
594     from_prod expr con (Prod { repr_tup_tc   = tup_tc
595                              , repr_comp_tys = tys
596                              , repr_comps    = comps
597                              })
598       = do
599           vars <- newLocalVars (fsLit "y") tys
600           es   <- zipWithM from_comp (map Var vars) comps
601           return $ mkWildCase expr (exprType expr) res_ty
602                    [(DataAlt tup_con, vars, con `mkApps` es)]
603       where
604         [tup_con] = tyConDataCons tup_tc  
605
606     from_comp expr (Keep _ _) = return expr
607     from_comp expr (Wrap ty)
608       = do
609           wrap <- builtin wrapTyCon
610           return $ unwrapNewTypeBody wrap [ty] expr
611
612
613 buildToArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
614 buildToArrPRepr vect_tc prepr_tc pdata_tc r
615   = do
616       arg_ty <- mkPDataType el_ty
617       res_ty <- mkPDataType =<< mkPReprType el_ty
618       arg    <- newLocalVar (fsLit "xs") arg_ty
619
620       pdata_co <- mkBuiltinCo pdataTyCon
621       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
622           co           = mkAppCoercion pdata_co
623                        . mkSymCoercion
624                        $ mkTyConApp repr_co ty_args
625
626           scrut   = unwrapFamInstScrut pdata_tc ty_args (Var arg)
627
628       (vars, result) <- to_sum r
629
630       return . Lam arg
631              $ mkWildCase scrut (mkTyConApp pdata_tc ty_args) res_ty
632                [(DataAlt pdata_dc, vars, mkCoerce co result)]
633   where
634     ty_args = mkTyVarTys $ tyConTyVars vect_tc
635     el_ty   = mkTyConApp vect_tc ty_args
636
637     [pdata_dc] = tyConDataCons pdata_tc
638
639
640     to_sum EmptySum = do
641                         pvoid <- builtin pvoidVar
642                         return ([], Var pvoid)
643     to_sum (UnarySum r) = to_con r
644     to_sum (Sum { repr_psum_tc = psum_tc
645                 , repr_sel_ty  = sel_ty
646                 , repr_con_tys = tys
647                 , repr_cons    = cons
648                 })
649       = do
650           (vars, exprs) <- mapAndUnzipM to_con cons
651           sel <- newLocalVar (fsLit "sel") sel_ty
652           return (sel : concat vars, mk_result (Var sel) exprs)
653       where
654         [psum_con] = tyConDataCons psum_tc
655         mk_result sel exprs = wrapFamInstBody psum_tc tys
656                             $ mkConApp psum_con
657                             $ map Type tys ++ (sel : exprs)
658
659     to_con (ConRepr _ r) = to_prod r
660
661     to_prod EmptyProd = do
662                           pvoid <- builtin pvoidVar
663                           return ([], Var pvoid)
664     to_prod (UnaryProd r)
665       = do
666           pty  <- mkPDataType (compOrigType r)
667           var  <- newLocalVar (fsLit "x") pty
668           expr <- to_comp (Var var) r
669           return ([var], expr)
670
671     to_prod (Prod { repr_ptup_tc  = ptup_tc
672                   , repr_comp_tys = tys
673                   , repr_comps    = comps })
674       = do
675           ptys <- mapM (mkPDataType . compOrigType) comps
676           vars <- newLocalVars (fsLit "x") ptys
677           es   <- zipWithM to_comp (map Var vars) comps
678           return (vars, mk_result es)
679       where
680         [ptup_con] = tyConDataCons ptup_tc
681         mk_result exprs = wrapFamInstBody ptup_tc tys
682                         $ mkConApp ptup_con
683                         $ map Type tys ++ exprs
684
685     to_comp expr (Keep _ _) = return expr
686
687     -- FIXME: this is bound to be wrong!
688     to_comp expr (Wrap ty)
689       = do
690           wrap_tc  <- builtin wrapTyCon
691           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
692           return $ wrapNewTypeBody pwrap_tc [ty] expr
693
694
695 buildFromArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
696 buildFromArrPRepr vect_tc prepr_tc pdata_tc r
697   = do
698       arg_ty <- mkPDataType =<< mkPReprType el_ty
699       res_ty <- mkPDataType el_ty
700       arg    <- newLocalVar (fsLit "xs") arg_ty
701
702       pdata_co <- mkBuiltinCo pdataTyCon
703       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
704           co           = mkAppCoercion pdata_co
705                        $ mkTyConApp repr_co var_tys
706
707           scrut  = mkCoerce co (Var arg)
708
709           mk_result args = wrapFamInstBody pdata_tc var_tys
710                          $ mkConApp pdata_con
711                          $ map Type var_tys ++ args
712
713       (expr, _) <- fixV $ \ ~(_, args) ->
714                      from_sum res_ty (mk_result args) scrut r
715
716       return $ Lam arg expr
717     
718       -- (args, mk) <- from_sum res_ty scrut r
719       
720       -- let result = wrapFamInstBody pdata_tc var_tys
721       --           . mkConApp pdata_dc
722       --           $ map Type var_tys ++ args
723
724       -- return $ Lam arg (mk result)
725   where
726     var_tys = mkTyVarTys $ tyConTyVars vect_tc
727     el_ty   = mkTyConApp vect_tc var_tys
728
729     [pdata_con] = tyConDataCons pdata_tc
730
731     from_sum _ res _ EmptySum = return (res, [])
732     from_sum res_ty res expr (UnarySum r) = from_con res_ty res expr r
733     from_sum res_ty res expr (Sum { repr_psum_tc = psum_tc
734                                   , repr_sel_ty  = sel_ty
735                                   , repr_con_tys = tys
736                                   , repr_cons    = cons })
737       = do
738           sel  <- newLocalVar (fsLit "sel") sel_ty
739           ptys <- mapM mkPDataType tys
740           vars <- newLocalVars (fsLit "xs") ptys
741           (res', args) <- fold from_con res_ty res (map Var vars) cons
742           let scrut = unwrapFamInstScrut psum_tc tys expr
743               body  = mkWildCase scrut (exprType scrut) res_ty
744                       [(DataAlt psum_con, sel : vars, res')]
745           return (body, Var sel : args)
746       where
747         [psum_con] = tyConDataCons psum_tc
748
749
750     from_con res_ty res expr (ConRepr _ r) = from_prod res_ty res expr r
751
752     from_prod _ res _ EmptyProd = return (res, [])
753     from_prod res_ty res expr (UnaryProd r)
754       = from_comp res_ty res expr r
755     from_prod res_ty res expr (Prod { repr_ptup_tc  = ptup_tc
756                                     , repr_comp_tys = tys
757                                     , repr_comps    = comps })
758       = do
759           ptys <- mapM mkPDataType tys
760           vars <- newLocalVars (fsLit "ys") ptys
761           (res', args) <- fold from_comp res_ty res (map Var vars) comps
762           let scrut = unwrapFamInstScrut ptup_tc tys expr
763               body  = mkWildCase scrut (exprType scrut) res_ty
764                       [(DataAlt ptup_con, vars, res')]
765           return (body, args)
766       where
767         [ptup_con] = tyConDataCons ptup_tc
768
769     from_comp _ res expr (Keep _ _) = return (res, [expr])
770     from_comp _ res expr (Wrap ty)
771       = do
772           wrap_tc  <- builtin wrapTyCon
773           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
774           return (res, [unwrapNewTypeBody pwrap_tc [ty]
775                         $ unwrapFamInstScrut pwrap_tc [ty] expr])
776
777     fold f res_ty res exprs rs = foldrM f' (res, []) (zip exprs rs)
778       where
779         f' (expr, r) (res, args) = do
780                                      (res', args') <- f res_ty res expr r
781                                      return (res', args' ++ args)
782
783 buildPRDict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
784 buildPRDict vect_tc prepr_tc _ r
785   = do
786       dict <- sum_dict r
787       pr_co <- mkBuiltinCo prTyCon
788       let co = mkAppCoercion pr_co
789              . mkSymCoercion
790              $ mkTyConApp arg_co ty_args
791       return (mkCoerce co dict)
792   where
793     ty_args = mkTyVarTys (tyConTyVars vect_tc)
794     Just arg_co = tyConFamilyCoercion_maybe prepr_tc
795
796     sum_dict EmptySum = prDFunOfTyCon =<< builtin voidTyCon
797     sum_dict (UnarySum r) = con_dict r
798     sum_dict (Sum { repr_sum_tc  = sum_tc
799                   , repr_con_tys = tys
800                   , repr_cons    = cons
801                   })
802       = do
803           dicts <- mapM con_dict cons
804           dfun  <- prDFunOfTyCon sum_tc
805           return $ dfun `mkTyApps` tys `mkApps` dicts
806
807     con_dict (ConRepr _ r) = prod_dict r
808
809     prod_dict EmptyProd = prDFunOfTyCon =<< builtin voidTyCon
810     prod_dict (UnaryProd r) = comp_dict r
811     prod_dict (Prod { repr_tup_tc   = tup_tc
812                     , repr_comp_tys = tys
813                     , repr_comps    = comps })
814       = do
815           dicts <- mapM comp_dict comps
816           dfun <- prDFunOfTyCon tup_tc
817           return $ dfun `mkTyApps` tys `mkApps` dicts
818
819     comp_dict (Keep _ pr) = return pr
820     comp_dict (Wrap ty)   = wrapPR ty
821
822
823 buildPDataTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
824 buildPDataTyCon orig_tc vect_tc repr = fixV $ \repr_tc ->
825   do
826     name' <- cloneName mkPDataTyConOcc orig_name
827     rhs   <- buildPDataTyConRhs orig_name vect_tc repr_tc repr
828     pdata <- builtin pdataTyCon
829
830     liftDs $ buildAlgTyCon name'
831                            tyvars
832                            []          -- no stupid theta
833                            rhs
834                            rec_flag    -- FIXME: is this ok?
835                            False       -- FIXME: no generics
836                            False       -- not GADT syntax
837                            (Just $ mk_fam_inst pdata vect_tc)
838   where
839     orig_name = tyConName orig_tc
840     tyvars = tyConTyVars vect_tc
841     rec_flag = boolToRecFlag (isRecursiveTyCon vect_tc)
842
843
844 buildPDataTyConRhs :: Name -> TyCon -> TyCon -> SumRepr -> VM AlgTyConRhs
845 buildPDataTyConRhs orig_name vect_tc repr_tc repr
846   = do
847       data_con <- buildPDataDataCon orig_name vect_tc repr_tc repr
848       return $ DataTyCon { data_cons = [data_con], is_enum = False }
849
850 buildPDataDataCon :: Name -> TyCon -> TyCon -> SumRepr -> VM DataCon
851 buildPDataDataCon orig_name vect_tc repr_tc repr
852   = do
853       dc_name  <- cloneName mkPDataDataConOcc orig_name
854       comp_tys <- sum_tys repr
855
856       liftDs $ buildDataCon dc_name
857                             False                  -- not infix
858                             (map (const HsNoBang) comp_tys)
859                             []                     -- no field labels
860                             tvs
861                             []                     -- no existentials
862                             []                     -- no eq spec
863                             []                     -- no context
864                             comp_tys
865                             (mkFamilyTyConApp repr_tc (mkTyVarTys tvs))
866                             repr_tc
867   where
868     tvs   = tyConTyVars vect_tc
869
870     sum_tys EmptySum = return []
871     sum_tys (UnarySum r) = con_tys r
872     sum_tys (Sum { repr_sel_ty = sel_ty
873                  , repr_cons   = cons })
874       = liftM (sel_ty :) (concatMapM con_tys cons)
875
876     con_tys (ConRepr _ r) = prod_tys r
877
878     prod_tys EmptyProd = return []
879     prod_tys (UnaryProd r) = liftM singleton (comp_ty r)
880     prod_tys (Prod { repr_comps = comps }) = mapM comp_ty comps
881
882     comp_ty r = mkPDataType (compOrigType r)
883
884
885 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr 
886                    -> VM Var
887 buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr
888   = do
889       vectDataConWorkers orig_tc vect_tc pdata_tc
890       buildPADict vect_tc prepr_tc pdata_tc repr
891
892 vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
893 vectDataConWorkers orig_tc vect_tc arr_tc
894   = do
895       bs <- sequence
896           . zipWith3 def_worker  (tyConDataCons orig_tc) rep_tys
897           $ zipWith4 mk_data_con (tyConDataCons vect_tc)
898                                  rep_tys
899                                  (inits rep_tys)
900                                  (tail $ tails rep_tys)
901       mapM_ (uncurry hoistBinding) bs
902   where
903     tyvars   = tyConTyVars vect_tc
904     var_tys  = mkTyVarTys tyvars
905     ty_args  = map Type var_tys
906     res_ty   = mkTyConApp vect_tc var_tys
907
908     cons     = tyConDataCons vect_tc
909     arity    = length cons
910     [arr_dc] = tyConDataCons arr_tc
911
912     rep_tys  = map dataConRepArgTys $ tyConDataCons vect_tc
913
914
915     mk_data_con con tys pre post
916       = liftM2 (,) (vect_data_con con)
917                    (lift_data_con tys pre post (mkDataConTag con))
918
919     sel_replicate len tag
920       | arity > 1 = do
921                       rep <- builtin (selReplicate arity)
922                       return [rep `mkApps` [len, tag]]
923
924       | otherwise = return []
925
926     vect_data_con con = return $ mkConApp con ty_args
927     lift_data_con tys pre_tys post_tys tag
928       = do
929           len  <- builtin liftingContext
930           args <- mapM (newLocalVar (fsLit "xs"))
931                   =<< mapM mkPDataType tys
932
933           sel  <- sel_replicate (Var len) tag
934
935           pre   <- mapM emptyPD (concat pre_tys)
936           post  <- mapM emptyPD (concat post_tys)
937
938           return . mkLams (len : args)
939                  . wrapFamInstBody arr_tc var_tys
940                  . mkConApp arr_dc
941                  $ ty_args ++ sel ++ pre ++ map Var args ++ post
942
943     def_worker data_con arg_tys mk_body
944       = do
945           arity <- polyArity tyvars
946           body <- closedV
947                 . inBind orig_worker
948                 . polyAbstract tyvars $ \args ->
949                   liftM (mkLams (tyvars ++ args) . vectorised)
950                 $ buildClosures tyvars [] arg_tys res_ty mk_body
951
952           raw_worker <- cloneId mkVectOcc orig_worker (exprType body)
953           let vect_worker = raw_worker `setIdUnfolding`
954                               mkInlineRule body (Just arity)
955           defGlobalVar orig_worker vect_worker
956           return (vect_worker, body)
957       where
958         orig_worker = dataConWorkId data_con
959
960 buildPADict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
961 buildPADict vect_tc prepr_tc arr_tc repr
962   = polyAbstract tvs $ \args ->
963     do
964       method_ids <- mapM (method args) paMethods
965
966       pa_tc  <- builtin paTyCon
967       pa_dc  <- builtin paDataCon
968       let dict = mkLams (tvs ++ args)
969                $ mkConApp pa_dc
970                $ Type inst_ty : map (method_call args) method_ids
971
972           dfun_ty = mkForAllTys tvs
973                   $ mkFunTys (map varType args) (mkTyConApp pa_tc [inst_ty])
974
975       raw_dfun <- newExportedVar dfun_name dfun_ty
976       let dfun = raw_dfun `setIdUnfolding` mkDFunUnfolding dfun_ty (map Var method_ids)
977                           `setInlinePragma` dfunInlinePragma
978
979       hoistBinding dfun dict
980       return dfun
981   where
982     tvs = tyConTyVars vect_tc
983     arg_tys = mkTyVarTys tvs
984     inst_ty = mkTyConApp vect_tc arg_tys
985
986     dfun_name = mkPADFunOcc (getOccName vect_tc)
987
988     method args (name, build)
989       = localV
990       $ do
991           expr <- build vect_tc prepr_tc arr_tc repr
992           let body = mkLams (tvs ++ args) expr
993           raw_var <- newExportedVar (method_name name) (exprType body)
994           let var = raw_var
995                       `setIdUnfolding` mkInlineRule body (Just (length args))
996                       `setInlinePragma` alwaysInlinePragma
997           hoistBinding var body
998           return var
999
1000     method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args)
1001
1002     method_name name = mkVarOcc $ occNameString dfun_name ++ ('$' : name)
1003
1004
1005 paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
1006 paMethods = [("dictPRepr",    buildPRDict),
1007              ("toPRepr",      buildToPRepr),
1008              ("fromPRepr",    buildFromPRepr),
1009              ("toArrPRepr",   buildToArrPRepr),
1010              ("fromArrPRepr", buildFromArrPRepr)]
1011
1012
1013 -- | Split the given tycons into two sets depending on whether they have to be
1014 --   converted (first list) or not (second list). The first argument contains
1015 --   information about the conversion status of external tycons:
1016 --
1017 --   * tycons which have converted versions are mapped to True
1018 --   * tycons which are not changed by vectorisation are mapped to False
1019 --   * tycons which can't be converted are not elements of the map
1020 --
1021 classifyTyCons :: UniqFM Bool -> [TyConGroup] -> ([TyCon], [TyCon])
1022 classifyTyCons = classify [] []
1023   where
1024     classify conv keep _  [] = (conv, keep)
1025     classify conv keep cs ((tcs, ds) : rs)
1026       | can_convert && must_convert
1027         = classify (tcs ++ conv) keep (cs `addListToUFM` [(tc,True) | tc <- tcs]) rs
1028       | can_convert
1029         = classify conv (tcs ++ keep) (cs `addListToUFM` [(tc,False) | tc <- tcs]) rs
1030       | otherwise
1031         = classify conv keep cs rs
1032       where
1033         refs = ds `delListFromUniqSet` tcs
1034
1035         can_convert  = isNullUFM (refs `minusUFM` cs) && all convertable tcs
1036         must_convert = foldUFM (||) False (intersectUFM_C const cs refs)
1037
1038         convertable tc = isDataTyCon tc && all isVanillaDataCon (tyConDataCons tc)
1039
1040 -- | Compute mutually recursive groups of tycons in topological order
1041 --
1042 tyConGroups :: [TyCon] -> [TyConGroup]
1043 tyConGroups tcs = map mk_grp (stronglyConnCompFromEdgedVertices edges)
1044   where
1045     edges = [((tc, ds), tc, uniqSetToList ds) | tc <- tcs
1046                                 , let ds = tyConsOfTyCon tc]
1047
1048     mk_grp (AcyclicSCC (tc, ds)) = ([tc], ds)
1049     mk_grp (CyclicSCC els)       = (tcs, unionManyUniqSets dss)
1050       where
1051         (tcs, dss) = unzip els
1052
1053 tyConsOfTyCon :: TyCon -> UniqSet TyCon
1054 tyConsOfTyCon
1055   = tyConsOfTypes . concatMap dataConRepArgTys . tyConDataCons
1056
1057 tyConsOfType :: Type -> UniqSet TyCon
1058 tyConsOfType ty
1059   | Just ty' <- coreView ty    = tyConsOfType ty'
1060 tyConsOfType (TyVarTy _)       = emptyUniqSet
1061 tyConsOfType (TyConApp tc tys) = extend (tyConsOfTypes tys)
1062   where
1063     extend | isUnLiftedTyCon tc
1064            || isTupleTyCon   tc = id
1065
1066            | otherwise          = (`addOneToUniqSet` tc)
1067
1068 tyConsOfType (AppTy a b)       = tyConsOfType a `unionUniqSets` tyConsOfType b
1069 tyConsOfType (FunTy a b)       = (tyConsOfType a `unionUniqSets` tyConsOfType b)
1070                                  `addOneToUniqSet` funTyCon
1071 tyConsOfType (ForAllTy _ ty)   = tyConsOfType ty
1072 tyConsOfType other             = pprPanic "ClosureConv.tyConsOfType" $ ppr other
1073
1074 tyConsOfTypes :: [Type] -> UniqSet TyCon
1075 tyConsOfTypes = unionManyUniqSets . map tyConsOfType
1076
1077
1078 -- ----------------------------------------------------------------------------
1079 -- Conversions
1080
1081 -- | Build an expression that calls the vectorised version of some 
1082 --   function from a `Closure`.
1083 --
1084 --   For example
1085 --   @   
1086 --      \(x :: Double) -> 
1087 --      \(y :: Double) -> 
1088 --      ($v_foo $: x) $: y
1089 --   @
1090 --
1091 --   We use the type of the original binding to work out how many
1092 --   outer lambdas to add.
1093 --
1094 fromVect 
1095         :: Type         -- ^ The type of the original binding.
1096         -> CoreExpr     -- ^ Expression giving the closure to use, eg @$v_foo@.
1097         -> VM CoreExpr
1098         
1099 -- Convert the type to the core view if it isn't already.
1100 fromVect ty expr 
1101         | Just ty' <- coreView ty 
1102         = fromVect ty' expr
1103
1104 -- For each function constructor in the original type we add an outer 
1105 -- lambda to bind the parameter variable, and an inner application of it.
1106 fromVect (FunTy arg_ty res_ty) expr
1107   = do
1108       arg     <- newLocalVar (fsLit "x") arg_ty
1109       varg    <- toVect arg_ty (Var arg)
1110       varg_ty <- vectType arg_ty
1111       vres_ty <- vectType res_ty
1112       apply   <- builtin applyVar
1113       body    <- fromVect res_ty
1114                $ Var apply `mkTyApps` [varg_ty, vres_ty] `mkApps` [expr, varg]
1115       return $ Lam arg body
1116
1117 -- If the type isn't a function then it's time to call on the closure.
1118 fromVect ty expr
1119   = identityConv ty >> return expr
1120
1121
1122 toVect :: Type -> CoreExpr -> VM CoreExpr
1123 toVect ty expr = identityConv ty >> return expr
1124
1125
1126 identityConv :: Type -> VM ()
1127 identityConv ty | Just ty' <- coreView ty = identityConv ty'
1128 identityConv (TyConApp tycon tys)
1129   = do
1130       mapM_ identityConv tys
1131       identityConvTyCon tycon
1132 identityConv _ = noV
1133
1134 identityConvTyCon :: TyCon -> VM ()
1135 identityConvTyCon tc
1136   | isBoxedTupleTyCon tc = return ()
1137   | isUnLiftedTyCon tc   = return ()
1138   | otherwise            = do
1139                              tc' <- maybeV (lookupTyCon tc)
1140                              if tc == tc' then return () else noV
1141