0004defeba2d78a0dc8104675a5a6f51d95cf0c3
[ghc-hetmet.git] / compiler / vectorise / VectType.hs
1 {-# OPTIONS -fno-warn-missing-signatures #-}
2
3 module VectType ( vectTyCon, vectAndLiftType, vectType, vectTypeEnv,
4                   -- arrSumArity, pdataCompTys, pdataCompVars,
5                   buildPADict,
6                   fromVect )
7 where
8
9 import VectUtils
10 import Vectorise.Env
11 import Vectorise.Vect
12 import Vectorise.Monad
13 import Vectorise.Builtins
14
15 import HscTypes          ( TypeEnv, extendTypeEnvList, typeEnvTyCons )
16 import BasicTypes
17 import CoreSyn
18 import CoreUtils
19 import CoreUnfold
20 import MkCore            ( mkWildCase )
21 import BuildTyCl
22 import DataCon
23 import TyCon
24 import Class
25 import Type
26 import TypeRep
27 import Coercion
28 import FamInstEnv        ( FamInst, mkLocalFamInst )
29 import OccName
30 import Id
31 import MkId
32 import Var               ( Var, TyVar, varType, varName )
33 import Name              ( Name, getOccName )
34 import NameEnv
35
36 import Unique
37 import UniqFM
38 import UniqSet
39 import Util
40 import Digraph           ( SCC(..), stronglyConnCompFromEdgedVertices )
41
42 import Outputable
43 import FastString
44
45 import MonadUtils     ( zipWith3M, foldrM, concatMapM )
46 import Control.Monad  ( liftM, liftM2, zipWithM, zipWithM_, mapAndUnzipM )
47 import Data.List
48 import Data.Maybe
49
50 debug           = False
51 dtrace s x      = if debug then pprTrace "VectType" s x else x
52
53 -- ----------------------------------------------------------------------------
54 -- Types
55
56 -- | Vectorise a type constructor.
57 vectTyCon :: TyCon -> VM TyCon
58 vectTyCon tc
59   | isFunTyCon tc        = builtin closureTyCon
60   | isBoxedTupleTyCon tc = return tc
61   | isUnLiftedTyCon tc   = return tc
62   | otherwise            
63   = maybeCantVectoriseM "Tycon not vectorised: " (ppr tc)
64         $ lookupTyCon tc
65
66
67 vectAndLiftType :: Type -> VM (Type, Type)
68 vectAndLiftType ty | Just ty' <- coreView ty = vectAndLiftType ty'
69 vectAndLiftType ty
70   = do
71       mdicts   <- mapM paDictArgType tyvars
72       let dicts = [dict | Just dict <- mdicts]
73       vmono_ty <- vectType mono_ty
74       lmono_ty <- mkPDataType vmono_ty
75       return (abstractType tyvars dicts vmono_ty,
76               abstractType tyvars dicts lmono_ty)
77   where
78     (tyvars, mono_ty) = splitForAllTys ty
79
80
81 -- | Vectorise a type.
82 vectType :: Type -> VM Type
83 vectType ty
84         | Just ty'      <- coreView ty
85         = vectType ty'
86         
87 vectType (TyVarTy tv)           = return $ TyVarTy tv
88 vectType (AppTy ty1 ty2)        = liftM2 AppTy    (vectType ty1) (vectType ty2)
89 vectType (TyConApp tc tys)      = liftM2 TyConApp (vectTyCon tc) (mapM vectType tys)
90 vectType (FunTy ty1 ty2)        = liftM2 TyConApp (builtin closureTyCon)
91                                                   (mapM vectAndBoxType [ty1,ty2])
92
93 -- For each quantified var we need to add a PA dictionary out the front of the type.
94 -- So          forall a. C  a => a -> a   
95 -- turns into  forall a. Cv a => PA a => a :-> a
96 vectType ty@(ForAllTy _ _)
97  = do
98       -- split the type into the quantified vars, its dictionaries and the body.
99       let (tyvars, tyBody)   = splitForAllTys ty
100       let (tyArgs, tyResult) = splitFunTys    tyBody
101
102       let (tyArgs_dict, tyArgs_regular) 
103                   = partition isDictType tyArgs
104
105       -- vectorise the body.
106       let tyBody' = mkFunTys tyArgs_regular tyResult
107       tyBody''    <- vectType tyBody'
108
109       -- vectorise the dictionary parameters.
110       dictsVect   <- mapM vectType tyArgs_dict
111
112       -- make a PA dictionary for each of the type variables.
113       dictsPA     <- liftM catMaybes $ mapM paDictArgType tyvars
114
115       -- pack it all back together.
116       return $ abstractType tyvars (dictsVect ++ dictsPA) tyBody''
117
118 vectType ty = cantVectorise "Can't vectorise type" (ppr ty)
119
120
121 -- | Add quantified vars and dictionary parameters to the front of a type.
122 abstractType :: [TyVar] -> [Type] -> Type -> Type
123 abstractType tyvars dicts = mkForAllTys tyvars . mkFunTys dicts
124
125
126 -- | Check if some type is a type class dictionary.
127 isDictType :: Type -> Bool
128 isDictType ty
129  = case splitTyConApp_maybe ty of
130         Just (tyCon, _)         -> isClassTyCon tyCon
131         _                       -> False
132
133         
134 -- ----------------------------------------------------------------------------
135 -- Boxing
136
137 boxType :: Type -> VM Type
138 boxType ty
139   | Just (tycon, []) <- splitTyConApp_maybe ty
140   , isUnLiftedTyCon tycon
141   = do
142       r <- lookupBoxedTyCon tycon
143       case r of
144         Just tycon' -> return $ mkTyConApp tycon' []
145         Nothing     -> return ty
146
147 boxType ty = return ty
148
149 vectAndBoxType :: Type -> VM Type
150 vectAndBoxType ty = vectType ty >>= boxType
151
152
153 -- ----------------------------------------------------------------------------
154 -- Type definitions
155
156 type TyConGroup = ([TyCon], UniqSet TyCon)
157
158 -- | Vectorise a type environment.
159 --   The type environment contains all the type things defined in a module.
160 vectTypeEnv :: TypeEnv -> VM (TypeEnv, [FamInst], [(Var, CoreExpr)])
161 vectTypeEnv env
162  = dtrace (ppr env)
163  $ do
164       cs <- readGEnv $ mk_map . global_tycons
165
166       -- Split the list of TyCons into the ones we have to vectorise vs the
167       -- ones we can pass through unchanged. We also pass through algebraic 
168       -- types that use non Haskell98 features, as we don't handle those.
169       let (conv_tcs, keep_tcs) = classifyTyCons cs groups
170           keep_dcs             = concatMap tyConDataCons keep_tcs
171
172       dtrace (text "conv_tcs = " <> ppr conv_tcs) $ return ()
173
174       zipWithM_ defTyCon   keep_tcs keep_tcs
175       zipWithM_ defDataCon keep_dcs keep_dcs
176
177       new_tcs <- vectTyConDecls conv_tcs
178
179       dtrace (text "new_tcs = " <> ppr new_tcs) $ return ()
180
181       let orig_tcs = keep_tcs ++ conv_tcs
182
183       -- We don't need to make new representation types for dictionary
184       -- constructors. The constructors are always fully applied, and we don't 
185       -- need to lift them to arrays as a dictionary of a particular type
186       -- always has the same value.
187       let vect_tcs = filter (not . isClassTyCon) 
188                    $ keep_tcs ++ new_tcs
189
190       dtrace (text "vect_tcs = " <> ppr vect_tcs) $ return ()
191
192       mapM_ dumpTycon $ new_tcs
193
194
195       (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
196         do
197           defTyConPAs (zipLazy vect_tcs dfuns')
198           reprs     <- mapM tyConRepr vect_tcs
199           repr_tcs  <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
200           pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
201
202           dfuns     <- sequence 
203                     $  zipWith5 buildTyConBindings
204                                orig_tcs
205                                vect_tcs
206                                repr_tcs
207                                pdata_tcs
208                                reprs
209
210           binds     <- takeHoisted
211           return (dfuns, binds, repr_tcs ++ pdata_tcs)
212
213       let all_new_tcs = new_tcs ++ inst_tcs
214
215       let new_env = extendTypeEnvList env
216                        (map ATyCon all_new_tcs
217                         ++ [ADataCon dc | tc <- all_new_tcs
218                                         , dc <- tyConDataCons tc])
219
220       return (new_env, map mkLocalFamInst inst_tcs, binds)
221   where
222     tycons = typeEnvTyCons env
223     groups = tyConGroups tycons
224
225     mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
226
227
228 -- | Vectorise some (possibly recursively defined) type constructors.
229 vectTyConDecls :: [TyCon] -> VM [TyCon]
230 vectTyConDecls tcs = fixV $ \tcs' ->
231   do
232     mapM_ (uncurry defTyCon) (zipLazy tcs tcs')
233     mapM vectTyConDecl tcs
234
235 dumpTycon :: TyCon -> VM ()
236 dumpTycon tycon
237         | Just cls      <- tyConClass_maybe tycon
238         = dtrace (vcat  [ ppr tycon
239                         , ppr [(m, varType m) | m <- classMethods cls ]])
240         $ return ()
241                 
242         | otherwise
243         = return ()
244
245
246 -- | Vectorise a single type construcrtor.
247 vectTyConDecl :: TyCon -> VM TyCon
248 vectTyConDecl tycon
249     -- a type class constructor.
250     -- TODO: check for no stupid theta, fds, assoc types. 
251     | isClassTyCon tycon
252     , Just cls          <- tyConClass_maybe tycon
253
254     = do    -- make the name of the vectorised class tycon.
255             name'       <- cloneName mkVectTyConOcc (tyConName tycon)
256
257             -- vectorise right of definition.
258             rhs'        <- vectAlgTyConRhs tycon (algTyConRhs tycon)
259
260             -- vectorise method selectors.
261             -- This also adds a mapping between the original and vectorised method selector
262             -- to the state.
263             methods'    <- mapM vectMethod
264                         $  [(id, defMethSpecOfDefMeth meth) 
265                                 | (id, meth)    <- classOpItems cls]
266
267             -- keep the original recursiveness flag.
268             let rec_flag = boolToRecFlag (isRecursiveTyCon tycon)
269         
270             -- Calling buildclass here attaches new quantifiers and dictionaries to the method types.
271             cls'     <- liftDs 
272                     $  buildClass
273                              False               -- include unfoldings on dictionary selectors.
274                              name'               -- new name  V_T:Class
275                              (tyConTyVars tycon) -- keep original type vars
276                              []                  -- no stupid theta
277                              []                  -- no functional dependencies
278                              []                  -- no associated types
279                              methods'            -- method info
280                              rec_flag            -- whether recursive
281
282             let tycon'  = mkClassTyCon name'
283                              (tyConKind tycon)
284                              (tyConTyVars tycon)
285                              rhs'
286                              cls'
287                              rec_flag
288
289             return $ tycon'
290                         
291     -- a regular algebraic type constructor.
292     -- TODO: check for stupid theta, generaics, GADTS etc
293     | isAlgTyCon tycon
294     = do    name'       <- cloneName mkVectTyConOcc (tyConName tycon)
295             rhs'        <- vectAlgTyConRhs tycon (algTyConRhs tycon)
296             let rec_flag =  boolToRecFlag (isRecursiveTyCon tycon)
297
298             liftDs $ buildAlgTyCon 
299                             name'               -- new name
300                             (tyConTyVars tycon) -- keep original type vars.
301                             []                  -- no stupid theta.
302                             rhs'                -- new constructor defs.
303                             rec_flag            -- FIXME: is this ok?
304                             False               -- FIXME: no generics
305                             False               -- not GADT syntax
306                             Nothing             -- not a family instance
307
308     -- some other crazy thing that we don't handle.
309     | otherwise
310     = cantVectorise "Can't vectorise type constructor: " (ppr tycon)
311
312
313 -- | Vectorise a class method.
314 vectMethod :: (Id, DefMethSpec) -> VM (Name, DefMethSpec, Type)
315 vectMethod (id, defMeth)
316  = do   
317         -- Vectorise the method type.
318         typ'    <- vectType (varType id)
319
320         -- Create a name for the vectorised method.
321         id'     <- cloneId mkVectOcc id typ'
322         defGlobalVar id id'
323
324         -- When we call buildClass in vectTyConDecl, it adds foralls and dictionaries
325         -- to the types of each method. However, the types we get back from vectType
326         -- above already already have these, so we need to chop them off here otherwise
327         -- we'll get two copies in the final version.
328         let (_tyvars, tyBody) = splitForAllTys typ'
329         let (_dict,   tyRest) = splitFunTy tyBody
330
331         return  (Var.varName id', defMeth, tyRest)
332
333
334 -- | Vectorise the RHS of an algebraic type.
335 vectAlgTyConRhs :: TyCon -> AlgTyConRhs -> VM AlgTyConRhs
336 vectAlgTyConRhs _ (DataTyCon { data_cons = data_cons
337                              , is_enum   = is_enum
338                              })
339   = do
340       data_cons' <- mapM vectDataCon data_cons
341       zipWithM_ defDataCon data_cons data_cons'
342       return $ DataTyCon { data_cons = data_cons'
343                          , is_enum   = is_enum
344                          }
345
346 vectAlgTyConRhs tc _ 
347         = cantVectorise "Can't vectorise type definition:" (ppr tc)
348
349
350 -- | Vectorise a data constructor.
351 --   Vectorises its argument and return types.
352 vectDataCon :: DataCon -> VM DataCon
353 vectDataCon dc
354   | not . null $ dataConExTyVars dc
355   = cantVectorise "Can't vectorise constructor (existentials):" (ppr dc)
356
357   | not . null $ dataConEqSpec   dc
358   = cantVectorise "Can't vectorise constructor (eq spec):" (ppr dc)
359
360   | otherwise
361   = do
362       name'    <- cloneName mkVectDataConOcc name
363       tycon'   <- vectTyCon tycon
364       arg_tys  <- mapM vectType rep_arg_tys
365
366       liftDs $ buildDataCon 
367                 name'
368                 False                          -- not infix
369                 (map (const HsNoBang) arg_tys) -- strictness annots on args.
370                 []                             -- no labelled fields
371                 univ_tvs                       -- universally quantified vars
372                 []                             -- no existential tvs for now
373                 []                             -- no eq spec for now
374                 []                             -- no context
375                 arg_tys                        -- argument types
376                 (mkFamilyTyConApp tycon' (mkTyVarTys univ_tvs)) -- return type
377                 tycon'                         -- representation tycon
378   where
379     name        = dataConName dc
380     univ_tvs    = dataConUnivTyVars dc
381     rep_arg_tys = dataConRepArgTys dc
382     tycon       = dataConTyCon dc
383
384 mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
385 mk_fam_inst fam_tc arg_tc
386   = (fam_tc, [mkTyConApp arg_tc . mkTyVarTys $ tyConTyVars arg_tc])
387
388
389 buildPReprTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
390 buildPReprTyCon orig_tc vect_tc repr
391   = do
392       name     <- cloneName mkPReprTyConOcc (tyConName orig_tc)
393       -- rhs_ty   <- buildPReprType vect_tc
394       rhs_ty   <- sumReprType repr
395       prepr_tc <- builtin preprTyCon
396       liftDs $ buildSynTyCon name
397                              tyvars
398                              (SynonymTyCon rhs_ty)
399                              (typeKind rhs_ty)
400                              (Just $ mk_fam_inst prepr_tc vect_tc)
401   where
402     tyvars = tyConTyVars vect_tc
403
404 data CompRepr = Keep Type
405                      CoreExpr     -- PR dictionary for the type
406               | Wrap Type
407
408 data ProdRepr = EmptyProd
409               | UnaryProd CompRepr
410               | Prod { repr_tup_tc   :: TyCon  -- representation tuple tycon
411                      , repr_ptup_tc  :: TyCon  -- PData representation tycon
412                      , repr_comp_tys :: [Type] -- representation types of
413                      , repr_comps    :: [CompRepr]          -- components
414                      }
415 data ConRepr  = ConRepr DataCon ProdRepr
416
417 data SumRepr  = EmptySum
418               | UnarySum ConRepr
419               | Sum  { repr_sum_tc   :: TyCon  -- representation sum tycon
420                      , repr_psum_tc  :: TyCon  -- PData representation tycon
421                      , repr_sel_ty   :: Type   -- type of selector
422                      , repr_con_tys :: [Type]  -- representation types of
423                      , repr_cons     :: [ConRepr]           -- components
424                      }
425
426 tyConRepr :: TyCon -> VM SumRepr
427 tyConRepr tc = sum_repr (tyConDataCons tc)
428   where
429     sum_repr []    = return EmptySum
430     sum_repr [con] = liftM UnarySum (con_repr con)
431     sum_repr cons  = do
432                        rs     <- mapM con_repr cons
433                        sum_tc <- builtin (sumTyCon arity)
434                        tys    <- mapM conReprType rs
435                        (psum_tc, _) <- pdataReprTyCon (mkTyConApp sum_tc tys)
436                        sel_ty <- builtin (selTy arity)
437                        return $ Sum { repr_sum_tc  = sum_tc
438                                     , repr_psum_tc = psum_tc
439                                     , repr_sel_ty  = sel_ty
440                                     , repr_con_tys = tys
441                                     , repr_cons    = rs
442                                     }
443       where
444         arity = length cons
445
446     con_repr con = liftM (ConRepr con) (prod_repr (dataConRepArgTys con))
447
448     prod_repr []   = return EmptyProd
449     prod_repr [ty] = liftM UnaryProd (comp_repr ty)
450     prod_repr tys  = do
451                        rs <- mapM comp_repr tys
452                        tup_tc <- builtin (prodTyCon arity)
453                        tys'    <- mapM compReprType rs
454                        (ptup_tc, _) <- pdataReprTyCon (mkTyConApp tup_tc tys')
455                        return $ Prod { repr_tup_tc   = tup_tc
456                                      , repr_ptup_tc  = ptup_tc
457                                      , repr_comp_tys = tys'
458                                      , repr_comps    = rs
459                                      }
460       where
461         arity = length tys
462     
463     comp_repr ty = liftM (Keep ty) (prDictOfType ty)
464                    `orElseV` return (Wrap ty)
465
466 sumReprType :: SumRepr -> VM Type
467 sumReprType EmptySum = voidType
468 sumReprType (UnarySum r) = conReprType r
469 sumReprType (Sum { repr_sum_tc  = sum_tc, repr_con_tys = tys })
470   = return $ mkTyConApp sum_tc tys
471
472 conReprType :: ConRepr -> VM Type
473 conReprType (ConRepr _ r) = prodReprType r
474
475 prodReprType :: ProdRepr -> VM Type
476 prodReprType EmptyProd = voidType
477 prodReprType (UnaryProd r) = compReprType r
478 prodReprType (Prod { repr_tup_tc = tup_tc, repr_comp_tys = tys })
479   = return $ mkTyConApp tup_tc tys
480
481 compReprType :: CompRepr -> VM Type
482 compReprType (Keep ty _) = return ty
483 compReprType (Wrap ty) = do
484                              wrap_tc <- builtin wrapTyCon
485                              return $ mkTyConApp wrap_tc [ty]
486
487 compOrigType :: CompRepr -> Type
488 compOrigType (Keep ty _) = ty
489 compOrigType (Wrap ty) = ty
490
491 buildToPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
492 buildToPRepr vect_tc repr_tc _ repr
493   = do
494       let arg_ty = mkTyConApp vect_tc ty_args
495       res_ty <- mkPReprType arg_ty
496       arg    <- newLocalVar (fsLit "x") arg_ty
497       result <- to_sum (Var arg) arg_ty res_ty repr
498       return $ Lam arg result
499   where
500     ty_args = mkTyVarTys (tyConTyVars vect_tc)
501
502     wrap_repr_inst = wrapFamInstBody repr_tc ty_args
503
504     to_sum _ _ _ EmptySum
505       = do
506           void <- builtin voidVar
507           return $ wrap_repr_inst $ Var void
508
509     to_sum arg arg_ty res_ty (UnarySum r)
510       = do
511           (pat, vars, body) <- con_alt r
512           return $ mkWildCase arg arg_ty res_ty
513                    [(pat, vars, wrap_repr_inst body)]
514
515     to_sum arg arg_ty res_ty (Sum { repr_sum_tc  = sum_tc
516                                   , repr_con_tys = tys
517                                   , repr_cons    =  cons })
518       = do
519           alts <- mapM con_alt cons
520           let alts' = [(pat, vars, wrap_repr_inst
521                                    $ mkConApp sum_con (map Type tys ++ [body]))
522                         | ((pat, vars, body), sum_con)
523                             <- zip alts (tyConDataCons sum_tc)]
524           return $ mkWildCase arg arg_ty res_ty alts'
525
526     con_alt (ConRepr con r)
527       = do
528           (vars, body) <- to_prod r
529           return (DataAlt con, vars, body)
530
531     to_prod EmptyProd
532       = do
533           void <- builtin voidVar
534           return ([], Var void)
535
536     to_prod (UnaryProd comp)
537       = do
538           var  <- newLocalVar (fsLit "x") (compOrigType comp)
539           body <- to_comp (Var var) comp
540           return ([var], body)
541
542     to_prod(Prod { repr_tup_tc   = tup_tc
543                  , repr_comp_tys = tys
544                  , repr_comps    = comps })
545       = do
546           vars  <- newLocalVars (fsLit "x") (map compOrigType comps)
547           exprs <- zipWithM to_comp (map Var vars) comps
548           return (vars, mkConApp tup_con (map Type tys ++ exprs))
549       where
550         [tup_con] = tyConDataCons tup_tc
551
552     to_comp expr (Keep _ _) = return expr
553     to_comp expr (Wrap ty)  = do
554                                 wrap_tc <- builtin wrapTyCon
555                                 return $ wrapNewTypeBody wrap_tc [ty] expr
556
557
558 buildFromPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
559 buildFromPRepr vect_tc repr_tc _ repr
560   = do
561       arg_ty <- mkPReprType res_ty
562       arg <- newLocalVar (fsLit "x") arg_ty
563
564       result <- from_sum (unwrapFamInstScrut repr_tc ty_args (Var arg))
565                          repr
566       return $ Lam arg result
567   where
568     ty_args = mkTyVarTys (tyConTyVars vect_tc)
569     res_ty  = mkTyConApp vect_tc ty_args
570
571     from_sum _ EmptySum
572       = do
573           dummy <- builtin fromVoidVar
574           return $ Var dummy `App` Type res_ty
575
576     from_sum expr (UnarySum r) = from_con expr r
577     from_sum expr (Sum { repr_sum_tc  = sum_tc
578                        , repr_con_tys = tys
579                        , repr_cons    = cons })
580       = do
581           vars  <- newLocalVars (fsLit "x") tys
582           es    <- zipWithM from_con (map Var vars) cons
583           return $ mkWildCase expr (exprType expr) res_ty
584                    [(DataAlt con, [var], e)
585                       | (con, var, e) <- zip3 (tyConDataCons sum_tc) vars es]
586
587     from_con expr (ConRepr con r)
588       = from_prod expr (mkConApp con $ map Type ty_args) r
589
590     from_prod _ con EmptyProd = return con
591     from_prod expr con (UnaryProd r)
592       = do
593           e <- from_comp expr r
594           return $ con `App` e
595      
596     from_prod expr con (Prod { repr_tup_tc   = tup_tc
597                              , repr_comp_tys = tys
598                              , repr_comps    = comps
599                              })
600       = do
601           vars <- newLocalVars (fsLit "y") tys
602           es   <- zipWithM from_comp (map Var vars) comps
603           return $ mkWildCase expr (exprType expr) res_ty
604                    [(DataAlt tup_con, vars, con `mkApps` es)]
605       where
606         [tup_con] = tyConDataCons tup_tc  
607
608     from_comp expr (Keep _ _) = return expr
609     from_comp expr (Wrap ty)
610       = do
611           wrap <- builtin wrapTyCon
612           return $ unwrapNewTypeBody wrap [ty] expr
613
614
615 buildToArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
616 buildToArrPRepr vect_tc prepr_tc pdata_tc r
617   = do
618       arg_ty <- mkPDataType el_ty
619       res_ty <- mkPDataType =<< mkPReprType el_ty
620       arg    <- newLocalVar (fsLit "xs") arg_ty
621
622       pdata_co <- mkBuiltinCo pdataTyCon
623       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
624           co           = mkAppCoercion pdata_co
625                        . mkSymCoercion
626                        $ mkTyConApp repr_co ty_args
627
628           scrut   = unwrapFamInstScrut pdata_tc ty_args (Var arg)
629
630       (vars, result) <- to_sum r
631
632       return . Lam arg
633              $ mkWildCase scrut (mkTyConApp pdata_tc ty_args) res_ty
634                [(DataAlt pdata_dc, vars, mkCoerce co result)]
635   where
636     ty_args = mkTyVarTys $ tyConTyVars vect_tc
637     el_ty   = mkTyConApp vect_tc ty_args
638
639     [pdata_dc] = tyConDataCons pdata_tc
640
641
642     to_sum EmptySum = do
643                         pvoid <- builtin pvoidVar
644                         return ([], Var pvoid)
645     to_sum (UnarySum r) = to_con r
646     to_sum (Sum { repr_psum_tc = psum_tc
647                 , repr_sel_ty  = sel_ty
648                 , repr_con_tys = tys
649                 , repr_cons    = cons
650                 })
651       = do
652           (vars, exprs) <- mapAndUnzipM to_con cons
653           sel <- newLocalVar (fsLit "sel") sel_ty
654           return (sel : concat vars, mk_result (Var sel) exprs)
655       where
656         [psum_con] = tyConDataCons psum_tc
657         mk_result sel exprs = wrapFamInstBody psum_tc tys
658                             $ mkConApp psum_con
659                             $ map Type tys ++ (sel : exprs)
660
661     to_con (ConRepr _ r) = to_prod r
662
663     to_prod EmptyProd = do
664                           pvoid <- builtin pvoidVar
665                           return ([], Var pvoid)
666     to_prod (UnaryProd r)
667       = do
668           pty  <- mkPDataType (compOrigType r)
669           var  <- newLocalVar (fsLit "x") pty
670           expr <- to_comp (Var var) r
671           return ([var], expr)
672
673     to_prod (Prod { repr_ptup_tc  = ptup_tc
674                   , repr_comp_tys = tys
675                   , repr_comps    = comps })
676       = do
677           ptys <- mapM (mkPDataType . compOrigType) comps
678           vars <- newLocalVars (fsLit "x") ptys
679           es   <- zipWithM to_comp (map Var vars) comps
680           return (vars, mk_result es)
681       where
682         [ptup_con] = tyConDataCons ptup_tc
683         mk_result exprs = wrapFamInstBody ptup_tc tys
684                         $ mkConApp ptup_con
685                         $ map Type tys ++ exprs
686
687     to_comp expr (Keep _ _) = return expr
688
689     -- FIXME: this is bound to be wrong!
690     to_comp expr (Wrap ty)
691       = do
692           wrap_tc  <- builtin wrapTyCon
693           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
694           return $ wrapNewTypeBody pwrap_tc [ty] expr
695
696
697 buildFromArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
698 buildFromArrPRepr vect_tc prepr_tc pdata_tc r
699   = do
700       arg_ty <- mkPDataType =<< mkPReprType el_ty
701       res_ty <- mkPDataType el_ty
702       arg    <- newLocalVar (fsLit "xs") arg_ty
703
704       pdata_co <- mkBuiltinCo pdataTyCon
705       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
706           co           = mkAppCoercion pdata_co
707                        $ mkTyConApp repr_co var_tys
708
709           scrut  = mkCoerce co (Var arg)
710
711           mk_result args = wrapFamInstBody pdata_tc var_tys
712                          $ mkConApp pdata_con
713                          $ map Type var_tys ++ args
714
715       (expr, _) <- fixV $ \ ~(_, args) ->
716                      from_sum res_ty (mk_result args) scrut r
717
718       return $ Lam arg expr
719     
720       -- (args, mk) <- from_sum res_ty scrut r
721       
722       -- let result = wrapFamInstBody pdata_tc var_tys
723       --           . mkConApp pdata_dc
724       --           $ map Type var_tys ++ args
725
726       -- return $ Lam arg (mk result)
727   where
728     var_tys = mkTyVarTys $ tyConTyVars vect_tc
729     el_ty   = mkTyConApp vect_tc var_tys
730
731     [pdata_con] = tyConDataCons pdata_tc
732
733     from_sum _ res _ EmptySum = return (res, [])
734     from_sum res_ty res expr (UnarySum r) = from_con res_ty res expr r
735     from_sum res_ty res expr (Sum { repr_psum_tc = psum_tc
736                                   , repr_sel_ty  = sel_ty
737                                   , repr_con_tys = tys
738                                   , repr_cons    = cons })
739       = do
740           sel  <- newLocalVar (fsLit "sel") sel_ty
741           ptys <- mapM mkPDataType tys
742           vars <- newLocalVars (fsLit "xs") ptys
743           (res', args) <- fold from_con res_ty res (map Var vars) cons
744           let scrut = unwrapFamInstScrut psum_tc tys expr
745               body  = mkWildCase scrut (exprType scrut) res_ty
746                       [(DataAlt psum_con, sel : vars, res')]
747           return (body, Var sel : args)
748       where
749         [psum_con] = tyConDataCons psum_tc
750
751
752     from_con res_ty res expr (ConRepr _ r) = from_prod res_ty res expr r
753
754     from_prod _ res _ EmptyProd = return (res, [])
755     from_prod res_ty res expr (UnaryProd r)
756       = from_comp res_ty res expr r
757     from_prod res_ty res expr (Prod { repr_ptup_tc  = ptup_tc
758                                     , repr_comp_tys = tys
759                                     , repr_comps    = comps })
760       = do
761           ptys <- mapM mkPDataType tys
762           vars <- newLocalVars (fsLit "ys") ptys
763           (res', args) <- fold from_comp res_ty res (map Var vars) comps
764           let scrut = unwrapFamInstScrut ptup_tc tys expr
765               body  = mkWildCase scrut (exprType scrut) res_ty
766                       [(DataAlt ptup_con, vars, res')]
767           return (body, args)
768       where
769         [ptup_con] = tyConDataCons ptup_tc
770
771     from_comp _ res expr (Keep _ _) = return (res, [expr])
772     from_comp _ res expr (Wrap ty)
773       = do
774           wrap_tc  <- builtin wrapTyCon
775           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
776           return (res, [unwrapNewTypeBody pwrap_tc [ty]
777                         $ unwrapFamInstScrut pwrap_tc [ty] expr])
778
779     fold f res_ty res exprs rs = foldrM f' (res, []) (zip exprs rs)
780       where
781         f' (expr, r) (res, args) = do
782                                      (res', args') <- f res_ty res expr r
783                                      return (res', args' ++ args)
784
785 buildPRDict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
786 buildPRDict vect_tc prepr_tc _ r
787   = do
788       dict <- sum_dict r
789       pr_co <- mkBuiltinCo prTyCon
790       let co = mkAppCoercion pr_co
791              . mkSymCoercion
792              $ mkTyConApp arg_co ty_args
793       return (mkCoerce co dict)
794   where
795     ty_args = mkTyVarTys (tyConTyVars vect_tc)
796     Just arg_co = tyConFamilyCoercion_maybe prepr_tc
797
798     sum_dict EmptySum = prDFunOfTyCon =<< builtin voidTyCon
799     sum_dict (UnarySum r) = con_dict r
800     sum_dict (Sum { repr_sum_tc  = sum_tc
801                   , repr_con_tys = tys
802                   , repr_cons    = cons
803                   })
804       = do
805           dicts <- mapM con_dict cons
806           dfun  <- prDFunOfTyCon sum_tc
807           return $ dfun `mkTyApps` tys `mkApps` dicts
808
809     con_dict (ConRepr _ r) = prod_dict r
810
811     prod_dict EmptyProd = prDFunOfTyCon =<< builtin voidTyCon
812     prod_dict (UnaryProd r) = comp_dict r
813     prod_dict (Prod { repr_tup_tc   = tup_tc
814                     , repr_comp_tys = tys
815                     , repr_comps    = comps })
816       = do
817           dicts <- mapM comp_dict comps
818           dfun <- prDFunOfTyCon tup_tc
819           return $ dfun `mkTyApps` tys `mkApps` dicts
820
821     comp_dict (Keep _ pr) = return pr
822     comp_dict (Wrap ty)   = wrapPR ty
823
824
825 buildPDataTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
826 buildPDataTyCon orig_tc vect_tc repr = fixV $ \repr_tc ->
827   do
828     name' <- cloneName mkPDataTyConOcc orig_name
829     rhs   <- buildPDataTyConRhs orig_name vect_tc repr_tc repr
830     pdata <- builtin pdataTyCon
831
832     liftDs $ buildAlgTyCon name'
833                            tyvars
834                            []          -- no stupid theta
835                            rhs
836                            rec_flag    -- FIXME: is this ok?
837                            False       -- FIXME: no generics
838                            False       -- not GADT syntax
839                            (Just $ mk_fam_inst pdata vect_tc)
840   where
841     orig_name = tyConName orig_tc
842     tyvars = tyConTyVars vect_tc
843     rec_flag = boolToRecFlag (isRecursiveTyCon vect_tc)
844
845
846 buildPDataTyConRhs :: Name -> TyCon -> TyCon -> SumRepr -> VM AlgTyConRhs
847 buildPDataTyConRhs orig_name vect_tc repr_tc repr
848   = do
849       data_con <- buildPDataDataCon orig_name vect_tc repr_tc repr
850       return $ DataTyCon { data_cons = [data_con], is_enum = False }
851
852 buildPDataDataCon :: Name -> TyCon -> TyCon -> SumRepr -> VM DataCon
853 buildPDataDataCon orig_name vect_tc repr_tc repr
854   = do
855       dc_name  <- cloneName mkPDataDataConOcc orig_name
856       comp_tys <- sum_tys repr
857
858       liftDs $ buildDataCon dc_name
859                             False                  -- not infix
860                             (map (const HsNoBang) comp_tys)
861                             []                     -- no field labels
862                             tvs
863                             []                     -- no existentials
864                             []                     -- no eq spec
865                             []                     -- no context
866                             comp_tys
867                             (mkFamilyTyConApp repr_tc (mkTyVarTys tvs))
868                             repr_tc
869   where
870     tvs   = tyConTyVars vect_tc
871
872     sum_tys EmptySum = return []
873     sum_tys (UnarySum r) = con_tys r
874     sum_tys (Sum { repr_sel_ty = sel_ty
875                  , repr_cons   = cons })
876       = liftM (sel_ty :) (concatMapM con_tys cons)
877
878     con_tys (ConRepr _ r) = prod_tys r
879
880     prod_tys EmptyProd = return []
881     prod_tys (UnaryProd r) = liftM singleton (comp_ty r)
882     prod_tys (Prod { repr_comps = comps }) = mapM comp_ty comps
883
884     comp_ty r = mkPDataType (compOrigType r)
885
886
887 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr 
888                    -> VM Var
889 buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr
890   = do
891       vectDataConWorkers orig_tc vect_tc pdata_tc
892       buildPADict vect_tc prepr_tc pdata_tc repr
893
894 vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
895 vectDataConWorkers orig_tc vect_tc arr_tc
896   = do
897       bs <- sequence
898           . zipWith3 def_worker  (tyConDataCons orig_tc) rep_tys
899           $ zipWith4 mk_data_con (tyConDataCons vect_tc)
900                                  rep_tys
901                                  (inits rep_tys)
902                                  (tail $ tails rep_tys)
903       mapM_ (uncurry hoistBinding) bs
904   where
905     tyvars   = tyConTyVars vect_tc
906     var_tys  = mkTyVarTys tyvars
907     ty_args  = map Type var_tys
908     res_ty   = mkTyConApp vect_tc var_tys
909
910     cons     = tyConDataCons vect_tc
911     arity    = length cons
912     [arr_dc] = tyConDataCons arr_tc
913
914     rep_tys  = map dataConRepArgTys $ tyConDataCons vect_tc
915
916
917     mk_data_con con tys pre post
918       = liftM2 (,) (vect_data_con con)
919                    (lift_data_con tys pre post (mkDataConTag con))
920
921     sel_replicate len tag
922       | arity > 1 = do
923                       rep <- builtin (selReplicate arity)
924                       return [rep `mkApps` [len, tag]]
925
926       | otherwise = return []
927
928     vect_data_con con = return $ mkConApp con ty_args
929     lift_data_con tys pre_tys post_tys tag
930       = do
931           len  <- builtin liftingContext
932           args <- mapM (newLocalVar (fsLit "xs"))
933                   =<< mapM mkPDataType tys
934
935           sel  <- sel_replicate (Var len) tag
936
937           pre   <- mapM emptyPD (concat pre_tys)
938           post  <- mapM emptyPD (concat post_tys)
939
940           return . mkLams (len : args)
941                  . wrapFamInstBody arr_tc var_tys
942                  . mkConApp arr_dc
943                  $ ty_args ++ sel ++ pre ++ map Var args ++ post
944
945     def_worker data_con arg_tys mk_body
946       = do
947           arity <- polyArity tyvars
948           body <- closedV
949                 . inBind orig_worker
950                 . polyAbstract tyvars $ \args ->
951                   liftM (mkLams (tyvars ++ args) . vectorised)
952                 $ buildClosures tyvars [] arg_tys res_ty mk_body
953
954           raw_worker <- cloneId mkVectOcc orig_worker (exprType body)
955           let vect_worker = raw_worker `setIdUnfolding`
956                               mkInlineRule body (Just arity)
957           defGlobalVar orig_worker vect_worker
958           return (vect_worker, body)
959       where
960         orig_worker = dataConWorkId data_con
961
962 buildPADict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
963 buildPADict vect_tc prepr_tc arr_tc repr
964   = polyAbstract tvs $ \args ->
965     do
966       method_ids <- mapM (method args) paMethods
967
968       pa_tc  <- builtin paTyCon
969       pa_dc  <- builtin paDataCon
970       let dict = mkLams (tvs ++ args)
971                $ mkConApp pa_dc
972                $ Type inst_ty : map (method_call args) method_ids
973
974           dfun_ty = mkForAllTys tvs
975                   $ mkFunTys (map varType args) (mkTyConApp pa_tc [inst_ty])
976
977       raw_dfun <- newExportedVar dfun_name dfun_ty
978       let dfun = raw_dfun `setIdUnfolding` mkDFunUnfolding dfun_ty (map Var method_ids)
979                           `setInlinePragma` dfunInlinePragma
980
981       hoistBinding dfun dict
982       return dfun
983   where
984     tvs = tyConTyVars vect_tc
985     arg_tys = mkTyVarTys tvs
986     inst_ty = mkTyConApp vect_tc arg_tys
987
988     dfun_name = mkPADFunOcc (getOccName vect_tc)
989
990     method args (name, build)
991       = localV
992       $ do
993           expr <- build vect_tc prepr_tc arr_tc repr
994           let body = mkLams (tvs ++ args) expr
995           raw_var <- newExportedVar (method_name name) (exprType body)
996           let var = raw_var
997                       `setIdUnfolding` mkInlineRule body (Just (length args))
998                       `setInlinePragma` alwaysInlinePragma
999           hoistBinding var body
1000           return var
1001
1002     method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args)
1003
1004     method_name name = mkVarOcc $ occNameString dfun_name ++ ('$' : name)
1005
1006
1007 paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
1008 paMethods = [("dictPRepr",    buildPRDict),
1009              ("toPRepr",      buildToPRepr),
1010              ("fromPRepr",    buildFromPRepr),
1011              ("toArrPRepr",   buildToArrPRepr),
1012              ("fromArrPRepr", buildFromArrPRepr)]
1013
1014
1015 -- | Split the given tycons into two sets depending on whether they have to be
1016 --   converted (first list) or not (second list). The first argument contains
1017 --   information about the conversion status of external tycons:
1018 --
1019 --   * tycons which have converted versions are mapped to True
1020 --   * tycons which are not changed by vectorisation are mapped to False
1021 --   * tycons which can't be converted are not elements of the map
1022 --
1023 classifyTyCons :: UniqFM Bool -> [TyConGroup] -> ([TyCon], [TyCon])
1024 classifyTyCons = classify [] []
1025   where
1026     classify conv keep _  [] = (conv, keep)
1027     classify conv keep cs ((tcs, ds) : rs)
1028       | can_convert && must_convert
1029         = classify (tcs ++ conv) keep (cs `addListToUFM` [(tc,True) | tc <- tcs]) rs
1030       | can_convert
1031         = classify conv (tcs ++ keep) (cs `addListToUFM` [(tc,False) | tc <- tcs]) rs
1032       | otherwise
1033         = classify conv keep cs rs
1034       where
1035         refs = ds `delListFromUniqSet` tcs
1036
1037         can_convert  = isNullUFM (refs `minusUFM` cs) && all convertable tcs
1038         must_convert = foldUFM (||) False (intersectUFM_C const cs refs)
1039
1040         convertable tc = isDataTyCon tc && all isVanillaDataCon (tyConDataCons tc)
1041
1042 -- | Compute mutually recursive groups of tycons in topological order
1043 --
1044 tyConGroups :: [TyCon] -> [TyConGroup]
1045 tyConGroups tcs = map mk_grp (stronglyConnCompFromEdgedVertices edges)
1046   where
1047     edges = [((tc, ds), tc, uniqSetToList ds) | tc <- tcs
1048                                 , let ds = tyConsOfTyCon tc]
1049
1050     mk_grp (AcyclicSCC (tc, ds)) = ([tc], ds)
1051     mk_grp (CyclicSCC els)       = (tcs, unionManyUniqSets dss)
1052       where
1053         (tcs, dss) = unzip els
1054
1055 tyConsOfTyCon :: TyCon -> UniqSet TyCon
1056 tyConsOfTyCon
1057   = tyConsOfTypes . concatMap dataConRepArgTys . tyConDataCons
1058
1059 tyConsOfType :: Type -> UniqSet TyCon
1060 tyConsOfType ty
1061   | Just ty' <- coreView ty    = tyConsOfType ty'
1062 tyConsOfType (TyVarTy _)       = emptyUniqSet
1063 tyConsOfType (TyConApp tc tys) = extend (tyConsOfTypes tys)
1064   where
1065     extend | isUnLiftedTyCon tc
1066            || isTupleTyCon   tc = id
1067
1068            | otherwise          = (`addOneToUniqSet` tc)
1069
1070 tyConsOfType (AppTy a b)       = tyConsOfType a `unionUniqSets` tyConsOfType b
1071 tyConsOfType (FunTy a b)       = (tyConsOfType a `unionUniqSets` tyConsOfType b)
1072                                  `addOneToUniqSet` funTyCon
1073 tyConsOfType (ForAllTy _ ty)   = tyConsOfType ty
1074 tyConsOfType other             = pprPanic "ClosureConv.tyConsOfType" $ ppr other
1075
1076 tyConsOfTypes :: [Type] -> UniqSet TyCon
1077 tyConsOfTypes = unionManyUniqSets . map tyConsOfType
1078
1079
1080 -- ----------------------------------------------------------------------------
1081 -- Conversions
1082
1083 -- | Build an expression that calls the vectorised version of some 
1084 --   function from a `Closure`.
1085 --
1086 --   For example
1087 --   @   
1088 --      \(x :: Double) -> 
1089 --      \(y :: Double) -> 
1090 --      ($v_foo $: x) $: y
1091 --   @
1092 --
1093 --   We use the type of the original binding to work out how many
1094 --   outer lambdas to add.
1095 --
1096 fromVect 
1097         :: Type         -- ^ The type of the original binding.
1098         -> CoreExpr     -- ^ Expression giving the closure to use, eg @$v_foo@.
1099         -> VM CoreExpr
1100         
1101 -- Convert the type to the core view if it isn't already.
1102 fromVect ty expr 
1103         | Just ty' <- coreView ty 
1104         = fromVect ty' expr
1105
1106 -- For each function constructor in the original type we add an outer 
1107 -- lambda to bind the parameter variable, and an inner application of it.
1108 fromVect (FunTy arg_ty res_ty) expr
1109   = do
1110       arg     <- newLocalVar (fsLit "x") arg_ty
1111       varg    <- toVect arg_ty (Var arg)
1112       varg_ty <- vectType arg_ty
1113       vres_ty <- vectType res_ty
1114       apply   <- builtin applyVar
1115       body    <- fromVect res_ty
1116                $ Var apply `mkTyApps` [varg_ty, vres_ty] `mkApps` [expr, varg]
1117       return $ Lam arg body
1118
1119 -- If the type isn't a function then it's time to call on the closure.
1120 fromVect ty expr
1121   = identityConv ty >> return expr
1122
1123
1124 toVect :: Type -> CoreExpr -> VM CoreExpr
1125 toVect ty expr = identityConv ty >> return expr
1126
1127
1128 identityConv :: Type -> VM ()
1129 identityConv ty | Just ty' <- coreView ty = identityConv ty'
1130 identityConv (TyConApp tycon tys)
1131   = do
1132       mapM_ identityConv tys
1133       identityConvTyCon tycon
1134 identityConv _ = noV
1135
1136 identityConvTyCon :: TyCon -> VM ()
1137 identityConvTyCon tc
1138   | isBoxedTupleTyCon tc = return ()
1139   | isUnLiftedTyCon tc   = return ()
1140   | otherwise            = do
1141                              tc' <- maybeV (lookupTyCon tc)
1142                              if tc == tc' then return () else noV
1143