Move VectCore to Vectorise tree
[ghc-hetmet.git] / compiler / vectorise / VectType.hs
1 {-# OPTIONS -fno-warn-missing-signatures #-}
2
3 module VectType ( vectTyCon, vectAndLiftType, vectType, vectTypeEnv,
4                   -- arrSumArity, pdataCompTys, pdataCompVars,
5                   buildPADict,
6                   fromVect )
7 where
8
9 import VectMonad
10 import VectUtils
11 import Vectorise.Env
12 import Vectorise.Vect
13
14 import HscTypes          ( TypeEnv, extendTypeEnvList, typeEnvTyCons )
15 import BasicTypes
16 import CoreSyn
17 import CoreUtils
18 import CoreUnfold
19 import MkCore            ( mkWildCase )
20 import BuildTyCl
21 import DataCon
22 import TyCon
23 import Class
24 import Type
25 import TypeRep
26 import Coercion
27 import FamInstEnv        ( FamInst, mkLocalFamInst )
28 import OccName
29 import Id
30 import MkId
31 import Var               ( Var, TyVar, varType, varName )
32 import Name              ( Name, getOccName )
33 import NameEnv
34
35 import Unique
36 import UniqFM
37 import UniqSet
38 import Util
39 import Digraph           ( SCC(..), stronglyConnCompFromEdgedVertices )
40
41 import Outputable
42 import FastString
43
44 import MonadUtils     ( zipWith3M, foldrM, concatMapM )
45 import Control.Monad  ( liftM, liftM2, zipWithM, zipWithM_, mapAndUnzipM )
46 import Data.List
47 import Data.Maybe
48
49 debug           = False
50 dtrace s x      = if debug then pprTrace "VectType" s x else x
51
52 -- ----------------------------------------------------------------------------
53 -- Types
54
55 -- | Vectorise a type constructor.
56 vectTyCon :: TyCon -> VM TyCon
57 vectTyCon tc
58   | isFunTyCon tc        = builtin closureTyCon
59   | isBoxedTupleTyCon tc = return tc
60   | isUnLiftedTyCon tc   = return tc
61   | otherwise            
62   = maybeCantVectoriseM "Tycon not vectorised: " (ppr tc)
63         $ lookupTyCon tc
64
65
66 vectAndLiftType :: Type -> VM (Type, Type)
67 vectAndLiftType ty | Just ty' <- coreView ty = vectAndLiftType ty'
68 vectAndLiftType ty
69   = do
70       mdicts   <- mapM paDictArgType tyvars
71       let dicts = [dict | Just dict <- mdicts]
72       vmono_ty <- vectType mono_ty
73       lmono_ty <- mkPDataType vmono_ty
74       return (abstractType tyvars dicts vmono_ty,
75               abstractType tyvars dicts lmono_ty)
76   where
77     (tyvars, mono_ty) = splitForAllTys ty
78
79
80 -- | Vectorise a type.
81 vectType :: Type -> VM Type
82 vectType ty
83         | Just ty'      <- coreView ty
84         = vectType ty'
85         
86 vectType (TyVarTy tv)           = return $ TyVarTy tv
87 vectType (AppTy ty1 ty2)        = liftM2 AppTy    (vectType ty1) (vectType ty2)
88 vectType (TyConApp tc tys)      = liftM2 TyConApp (vectTyCon tc) (mapM vectType tys)
89 vectType (FunTy ty1 ty2)        = liftM2 TyConApp (builtin closureTyCon)
90                                                   (mapM vectAndBoxType [ty1,ty2])
91
92 -- For each quantified var we need to add a PA dictionary out the front of the type.
93 -- So          forall a. C  a => a -> a   
94 -- turns into  forall a. Cv a => PA a => a :-> a
95 vectType ty@(ForAllTy _ _)
96  = do
97       -- split the type into the quantified vars, its dictionaries and the body.
98       let (tyvars, tyBody)   = splitForAllTys ty
99       let (tyArgs, tyResult) = splitFunTys    tyBody
100
101       let (tyArgs_dict, tyArgs_regular) 
102                   = partition isDictType tyArgs
103
104       -- vectorise the body.
105       let tyBody' = mkFunTys tyArgs_regular tyResult
106       tyBody''    <- vectType tyBody'
107
108       -- vectorise the dictionary parameters.
109       dictsVect   <- mapM vectType tyArgs_dict
110
111       -- make a PA dictionary for each of the type variables.
112       dictsPA     <- liftM catMaybes $ mapM paDictArgType tyvars
113
114       -- pack it all back together.
115       return $ abstractType tyvars (dictsVect ++ dictsPA) tyBody''
116
117 vectType ty = cantVectorise "Can't vectorise type" (ppr ty)
118
119
120 -- | Add quantified vars and dictionary parameters to the front of a type.
121 abstractType :: [TyVar] -> [Type] -> Type -> Type
122 abstractType tyvars dicts = mkForAllTys tyvars . mkFunTys dicts
123
124
125 -- | Check if some type is a type class dictionary.
126 isDictType :: Type -> Bool
127 isDictType ty
128  = case splitTyConApp_maybe ty of
129         Just (tyCon, _)         -> isClassTyCon tyCon
130         _                       -> False
131
132         
133 -- ----------------------------------------------------------------------------
134 -- Boxing
135
136 boxType :: Type -> VM Type
137 boxType ty
138   | Just (tycon, []) <- splitTyConApp_maybe ty
139   , isUnLiftedTyCon tycon
140   = do
141       r <- lookupBoxedTyCon tycon
142       case r of
143         Just tycon' -> return $ mkTyConApp tycon' []
144         Nothing     -> return ty
145
146 boxType ty = return ty
147
148 vectAndBoxType :: Type -> VM Type
149 vectAndBoxType ty = vectType ty >>= boxType
150
151
152 -- ----------------------------------------------------------------------------
153 -- Type definitions
154
155 type TyConGroup = ([TyCon], UniqSet TyCon)
156
157 -- | Vectorise a type environment.
158 --   The type environment contains all the type things defined in a module.
159 vectTypeEnv :: TypeEnv -> VM (TypeEnv, [FamInst], [(Var, CoreExpr)])
160 vectTypeEnv env
161  = dtrace (ppr env)
162  $ do
163       cs <- readGEnv $ mk_map . global_tycons
164
165       -- Split the list of TyCons into the ones we have to vectorise vs the
166       -- ones we can pass through unchanged. We also pass through algebraic 
167       -- types that use non Haskell98 features, as we don't handle those.
168       let (conv_tcs, keep_tcs) = classifyTyCons cs groups
169           keep_dcs             = concatMap tyConDataCons keep_tcs
170
171       dtrace (text "conv_tcs = " <> ppr conv_tcs) $ return ()
172
173       zipWithM_ defTyCon   keep_tcs keep_tcs
174       zipWithM_ defDataCon keep_dcs keep_dcs
175
176       new_tcs <- vectTyConDecls conv_tcs
177
178       dtrace (text "new_tcs = " <> ppr new_tcs) $ return ()
179
180       let orig_tcs = keep_tcs ++ conv_tcs
181
182       -- We don't need to make new representation types for dictionary
183       -- constructors. The constructors are always fully applied, and we don't 
184       -- need to lift them to arrays as a dictionary of a particular type
185       -- always has the same value.
186       let vect_tcs = filter (not . isClassTyCon) 
187                    $ keep_tcs ++ new_tcs
188
189       dtrace (text "vect_tcs = " <> ppr vect_tcs) $ return ()
190
191       mapM_ dumpTycon $ new_tcs
192
193
194       (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
195         do
196           defTyConPAs (zipLazy vect_tcs dfuns')
197           reprs     <- mapM tyConRepr vect_tcs
198           repr_tcs  <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
199           pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
200
201           dfuns     <- sequence 
202                     $  zipWith5 buildTyConBindings
203                                orig_tcs
204                                vect_tcs
205                                repr_tcs
206                                pdata_tcs
207                                reprs
208
209           binds     <- takeHoisted
210           return (dfuns, binds, repr_tcs ++ pdata_tcs)
211
212       let all_new_tcs = new_tcs ++ inst_tcs
213
214       let new_env = extendTypeEnvList env
215                        (map ATyCon all_new_tcs
216                         ++ [ADataCon dc | tc <- all_new_tcs
217                                         , dc <- tyConDataCons tc])
218
219       return (new_env, map mkLocalFamInst inst_tcs, binds)
220   where
221     tycons = typeEnvTyCons env
222     groups = tyConGroups tycons
223
224     mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
225
226
227 -- | Vectorise some (possibly recursively defined) type constructors.
228 vectTyConDecls :: [TyCon] -> VM [TyCon]
229 vectTyConDecls tcs = fixV $ \tcs' ->
230   do
231     mapM_ (uncurry defTyCon) (zipLazy tcs tcs')
232     mapM vectTyConDecl tcs
233
234 dumpTycon :: TyCon -> VM ()
235 dumpTycon tycon
236         | Just cls      <- tyConClass_maybe tycon
237         = dtrace (vcat  [ ppr tycon
238                         , ppr [(m, varType m) | m <- classMethods cls ]])
239         $ return ()
240                 
241         | otherwise
242         = return ()
243
244
245 -- | Vectorise a single type construcrtor.
246 vectTyConDecl :: TyCon -> VM TyCon
247 vectTyConDecl tycon
248     -- a type class constructor.
249     -- TODO: check for no stupid theta, fds, assoc types. 
250     | isClassTyCon tycon
251     , Just cls          <- tyConClass_maybe tycon
252
253     = do    -- make the name of the vectorised class tycon.
254             name'       <- cloneName mkVectTyConOcc (tyConName tycon)
255
256             -- vectorise right of definition.
257             rhs'        <- vectAlgTyConRhs tycon (algTyConRhs tycon)
258
259             -- vectorise method selectors.
260             -- This also adds a mapping between the original and vectorised method selector
261             -- to the state.
262             methods'    <- mapM vectMethod
263                         $  [(id, defMethSpecOfDefMeth meth) 
264                                 | (id, meth)    <- classOpItems cls]
265
266             -- keep the original recursiveness flag.
267             let rec_flag = boolToRecFlag (isRecursiveTyCon tycon)
268         
269             -- Calling buildclass here attaches new quantifiers and dictionaries to the method types.
270             cls'     <- liftDs 
271                     $  buildClass
272                              False               -- include unfoldings on dictionary selectors.
273                              name'               -- new name  V_T:Class
274                              (tyConTyVars tycon) -- keep original type vars
275                              []                  -- no stupid theta
276                              []                  -- no functional dependencies
277                              []                  -- no associated types
278                              methods'            -- method info
279                              rec_flag            -- whether recursive
280
281             let tycon'  = mkClassTyCon name'
282                              (tyConKind tycon)
283                              (tyConTyVars tycon)
284                              rhs'
285                              cls'
286                              rec_flag
287
288             return $ tycon'
289                         
290     -- a regular algebraic type constructor.
291     -- TODO: check for stupid theta, generaics, GADTS etc
292     | isAlgTyCon tycon
293     = do    name'       <- cloneName mkVectTyConOcc (tyConName tycon)
294             rhs'        <- vectAlgTyConRhs tycon (algTyConRhs tycon)
295             let rec_flag =  boolToRecFlag (isRecursiveTyCon tycon)
296
297             liftDs $ buildAlgTyCon 
298                             name'               -- new name
299                             (tyConTyVars tycon) -- keep original type vars.
300                             []                  -- no stupid theta.
301                             rhs'                -- new constructor defs.
302                             rec_flag            -- FIXME: is this ok?
303                             False               -- FIXME: no generics
304                             False               -- not GADT syntax
305                             Nothing             -- not a family instance
306
307     -- some other crazy thing that we don't handle.
308     | otherwise
309     = cantVectorise "Can't vectorise type constructor: " (ppr tycon)
310
311
312 -- | Vectorise a class method.
313 vectMethod :: (Id, DefMethSpec) -> VM (Name, DefMethSpec, Type)
314 vectMethod (id, defMeth)
315  = do   
316         -- Vectorise the method type.
317         typ'    <- vectType (varType id)
318
319         -- Create a name for the vectorised method.
320         id'     <- cloneId mkVectOcc id typ'
321         defGlobalVar id id'
322
323         -- When we call buildClass in vectTyConDecl, it adds foralls and dictionaries
324         -- to the types of each method. However, the types we get back from vectType
325         -- above already already have these, so we need to chop them off here otherwise
326         -- we'll get two copies in the final version.
327         let (_tyvars, tyBody) = splitForAllTys typ'
328         let (_dict,   tyRest) = splitFunTy tyBody
329
330         return  (Var.varName id', defMeth, tyRest)
331
332
333 -- | Vectorise the RHS of an algebraic type.
334 vectAlgTyConRhs :: TyCon -> AlgTyConRhs -> VM AlgTyConRhs
335 vectAlgTyConRhs _ (DataTyCon { data_cons = data_cons
336                              , is_enum   = is_enum
337                              })
338   = do
339       data_cons' <- mapM vectDataCon data_cons
340       zipWithM_ defDataCon data_cons data_cons'
341       return $ DataTyCon { data_cons = data_cons'
342                          , is_enum   = is_enum
343                          }
344
345 vectAlgTyConRhs tc _ 
346         = cantVectorise "Can't vectorise type definition:" (ppr tc)
347
348
349 -- | Vectorise a data constructor.
350 --   Vectorises its argument and return types.
351 vectDataCon :: DataCon -> VM DataCon
352 vectDataCon dc
353   | not . null $ dataConExTyVars dc
354   = cantVectorise "Can't vectorise constructor (existentials):" (ppr dc)
355
356   | not . null $ dataConEqSpec   dc
357   = cantVectorise "Can't vectorise constructor (eq spec):" (ppr dc)
358
359   | otherwise
360   = do
361       name'    <- cloneName mkVectDataConOcc name
362       tycon'   <- vectTyCon tycon
363       arg_tys  <- mapM vectType rep_arg_tys
364
365       liftDs $ buildDataCon 
366                 name'
367                 False                          -- not infix
368                 (map (const HsNoBang) arg_tys) -- strictness annots on args.
369                 []                             -- no labelled fields
370                 univ_tvs                       -- universally quantified vars
371                 []                             -- no existential tvs for now
372                 []                             -- no eq spec for now
373                 []                             -- no context
374                 arg_tys                        -- argument types
375                 (mkFamilyTyConApp tycon' (mkTyVarTys univ_tvs)) -- return type
376                 tycon'                         -- representation tycon
377   where
378     name        = dataConName dc
379     univ_tvs    = dataConUnivTyVars dc
380     rep_arg_tys = dataConRepArgTys dc
381     tycon       = dataConTyCon dc
382
383 mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
384 mk_fam_inst fam_tc arg_tc
385   = (fam_tc, [mkTyConApp arg_tc . mkTyVarTys $ tyConTyVars arg_tc])
386
387
388 buildPReprTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
389 buildPReprTyCon orig_tc vect_tc repr
390   = do
391       name     <- cloneName mkPReprTyConOcc (tyConName orig_tc)
392       -- rhs_ty   <- buildPReprType vect_tc
393       rhs_ty   <- sumReprType repr
394       prepr_tc <- builtin preprTyCon
395       liftDs $ buildSynTyCon name
396                              tyvars
397                              (SynonymTyCon rhs_ty)
398                              (typeKind rhs_ty)
399                              (Just $ mk_fam_inst prepr_tc vect_tc)
400   where
401     tyvars = tyConTyVars vect_tc
402
403 data CompRepr = Keep Type
404                      CoreExpr     -- PR dictionary for the type
405               | Wrap Type
406
407 data ProdRepr = EmptyProd
408               | UnaryProd CompRepr
409               | Prod { repr_tup_tc   :: TyCon  -- representation tuple tycon
410                      , repr_ptup_tc  :: TyCon  -- PData representation tycon
411                      , repr_comp_tys :: [Type] -- representation types of
412                      , repr_comps    :: [CompRepr]          -- components
413                      }
414 data ConRepr  = ConRepr DataCon ProdRepr
415
416 data SumRepr  = EmptySum
417               | UnarySum ConRepr
418               | Sum  { repr_sum_tc   :: TyCon  -- representation sum tycon
419                      , repr_psum_tc  :: TyCon  -- PData representation tycon
420                      , repr_sel_ty   :: Type   -- type of selector
421                      , repr_con_tys :: [Type]  -- representation types of
422                      , repr_cons     :: [ConRepr]           -- components
423                      }
424
425 tyConRepr :: TyCon -> VM SumRepr
426 tyConRepr tc = sum_repr (tyConDataCons tc)
427   where
428     sum_repr []    = return EmptySum
429     sum_repr [con] = liftM UnarySum (con_repr con)
430     sum_repr cons  = do
431                        rs     <- mapM con_repr cons
432                        sum_tc <- builtin (sumTyCon arity)
433                        tys    <- mapM conReprType rs
434                        (psum_tc, _) <- pdataReprTyCon (mkTyConApp sum_tc tys)
435                        sel_ty <- builtin (selTy arity)
436                        return $ Sum { repr_sum_tc  = sum_tc
437                                     , repr_psum_tc = psum_tc
438                                     , repr_sel_ty  = sel_ty
439                                     , repr_con_tys = tys
440                                     , repr_cons    = rs
441                                     }
442       where
443         arity = length cons
444
445     con_repr con = liftM (ConRepr con) (prod_repr (dataConRepArgTys con))
446
447     prod_repr []   = return EmptyProd
448     prod_repr [ty] = liftM UnaryProd (comp_repr ty)
449     prod_repr tys  = do
450                        rs <- mapM comp_repr tys
451                        tup_tc <- builtin (prodTyCon arity)
452                        tys'    <- mapM compReprType rs
453                        (ptup_tc, _) <- pdataReprTyCon (mkTyConApp tup_tc tys')
454                        return $ Prod { repr_tup_tc   = tup_tc
455                                      , repr_ptup_tc  = ptup_tc
456                                      , repr_comp_tys = tys'
457                                      , repr_comps    = rs
458                                      }
459       where
460         arity = length tys
461     
462     comp_repr ty = liftM (Keep ty) (prDictOfType ty)
463                    `orElseV` return (Wrap ty)
464
465 sumReprType :: SumRepr -> VM Type
466 sumReprType EmptySum = voidType
467 sumReprType (UnarySum r) = conReprType r
468 sumReprType (Sum { repr_sum_tc  = sum_tc, repr_con_tys = tys })
469   = return $ mkTyConApp sum_tc tys
470
471 conReprType :: ConRepr -> VM Type
472 conReprType (ConRepr _ r) = prodReprType r
473
474 prodReprType :: ProdRepr -> VM Type
475 prodReprType EmptyProd = voidType
476 prodReprType (UnaryProd r) = compReprType r
477 prodReprType (Prod { repr_tup_tc = tup_tc, repr_comp_tys = tys })
478   = return $ mkTyConApp tup_tc tys
479
480 compReprType :: CompRepr -> VM Type
481 compReprType (Keep ty _) = return ty
482 compReprType (Wrap ty) = do
483                              wrap_tc <- builtin wrapTyCon
484                              return $ mkTyConApp wrap_tc [ty]
485
486 compOrigType :: CompRepr -> Type
487 compOrigType (Keep ty _) = ty
488 compOrigType (Wrap ty) = ty
489
490 buildToPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
491 buildToPRepr vect_tc repr_tc _ repr
492   = do
493       let arg_ty = mkTyConApp vect_tc ty_args
494       res_ty <- mkPReprType arg_ty
495       arg    <- newLocalVar (fsLit "x") arg_ty
496       result <- to_sum (Var arg) arg_ty res_ty repr
497       return $ Lam arg result
498   where
499     ty_args = mkTyVarTys (tyConTyVars vect_tc)
500
501     wrap_repr_inst = wrapFamInstBody repr_tc ty_args
502
503     to_sum _ _ _ EmptySum
504       = do
505           void <- builtin voidVar
506           return $ wrap_repr_inst $ Var void
507
508     to_sum arg arg_ty res_ty (UnarySum r)
509       = do
510           (pat, vars, body) <- con_alt r
511           return $ mkWildCase arg arg_ty res_ty
512                    [(pat, vars, wrap_repr_inst body)]
513
514     to_sum arg arg_ty res_ty (Sum { repr_sum_tc  = sum_tc
515                                   , repr_con_tys = tys
516                                   , repr_cons    =  cons })
517       = do
518           alts <- mapM con_alt cons
519           let alts' = [(pat, vars, wrap_repr_inst
520                                    $ mkConApp sum_con (map Type tys ++ [body]))
521                         | ((pat, vars, body), sum_con)
522                             <- zip alts (tyConDataCons sum_tc)]
523           return $ mkWildCase arg arg_ty res_ty alts'
524
525     con_alt (ConRepr con r)
526       = do
527           (vars, body) <- to_prod r
528           return (DataAlt con, vars, body)
529
530     to_prod EmptyProd
531       = do
532           void <- builtin voidVar
533           return ([], Var void)
534
535     to_prod (UnaryProd comp)
536       = do
537           var  <- newLocalVar (fsLit "x") (compOrigType comp)
538           body <- to_comp (Var var) comp
539           return ([var], body)
540
541     to_prod(Prod { repr_tup_tc   = tup_tc
542                  , repr_comp_tys = tys
543                  , repr_comps    = comps })
544       = do
545           vars  <- newLocalVars (fsLit "x") (map compOrigType comps)
546           exprs <- zipWithM to_comp (map Var vars) comps
547           return (vars, mkConApp tup_con (map Type tys ++ exprs))
548       where
549         [tup_con] = tyConDataCons tup_tc
550
551     to_comp expr (Keep _ _) = return expr
552     to_comp expr (Wrap ty)  = do
553                                 wrap_tc <- builtin wrapTyCon
554                                 return $ wrapNewTypeBody wrap_tc [ty] expr
555
556
557 buildFromPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
558 buildFromPRepr vect_tc repr_tc _ repr
559   = do
560       arg_ty <- mkPReprType res_ty
561       arg <- newLocalVar (fsLit "x") arg_ty
562
563       result <- from_sum (unwrapFamInstScrut repr_tc ty_args (Var arg))
564                          repr
565       return $ Lam arg result
566   where
567     ty_args = mkTyVarTys (tyConTyVars vect_tc)
568     res_ty  = mkTyConApp vect_tc ty_args
569
570     from_sum _ EmptySum
571       = do
572           dummy <- builtin fromVoidVar
573           return $ Var dummy `App` Type res_ty
574
575     from_sum expr (UnarySum r) = from_con expr r
576     from_sum expr (Sum { repr_sum_tc  = sum_tc
577                        , repr_con_tys = tys
578                        , repr_cons    = cons })
579       = do
580           vars  <- newLocalVars (fsLit "x") tys
581           es    <- zipWithM from_con (map Var vars) cons
582           return $ mkWildCase expr (exprType expr) res_ty
583                    [(DataAlt con, [var], e)
584                       | (con, var, e) <- zip3 (tyConDataCons sum_tc) vars es]
585
586     from_con expr (ConRepr con r)
587       = from_prod expr (mkConApp con $ map Type ty_args) r
588
589     from_prod _ con EmptyProd = return con
590     from_prod expr con (UnaryProd r)
591       = do
592           e <- from_comp expr r
593           return $ con `App` e
594      
595     from_prod expr con (Prod { repr_tup_tc   = tup_tc
596                              , repr_comp_tys = tys
597                              , repr_comps    = comps
598                              })
599       = do
600           vars <- newLocalVars (fsLit "y") tys
601           es   <- zipWithM from_comp (map Var vars) comps
602           return $ mkWildCase expr (exprType expr) res_ty
603                    [(DataAlt tup_con, vars, con `mkApps` es)]
604       where
605         [tup_con] = tyConDataCons tup_tc  
606
607     from_comp expr (Keep _ _) = return expr
608     from_comp expr (Wrap ty)
609       = do
610           wrap <- builtin wrapTyCon
611           return $ unwrapNewTypeBody wrap [ty] expr
612
613
614 buildToArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
615 buildToArrPRepr vect_tc prepr_tc pdata_tc r
616   = do
617       arg_ty <- mkPDataType el_ty
618       res_ty <- mkPDataType =<< mkPReprType el_ty
619       arg    <- newLocalVar (fsLit "xs") arg_ty
620
621       pdata_co <- mkBuiltinCo pdataTyCon
622       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
623           co           = mkAppCoercion pdata_co
624                        . mkSymCoercion
625                        $ mkTyConApp repr_co ty_args
626
627           scrut   = unwrapFamInstScrut pdata_tc ty_args (Var arg)
628
629       (vars, result) <- to_sum r
630
631       return . Lam arg
632              $ mkWildCase scrut (mkTyConApp pdata_tc ty_args) res_ty
633                [(DataAlt pdata_dc, vars, mkCoerce co result)]
634   where
635     ty_args = mkTyVarTys $ tyConTyVars vect_tc
636     el_ty   = mkTyConApp vect_tc ty_args
637
638     [pdata_dc] = tyConDataCons pdata_tc
639
640
641     to_sum EmptySum = do
642                         pvoid <- builtin pvoidVar
643                         return ([], Var pvoid)
644     to_sum (UnarySum r) = to_con r
645     to_sum (Sum { repr_psum_tc = psum_tc
646                 , repr_sel_ty  = sel_ty
647                 , repr_con_tys = tys
648                 , repr_cons    = cons
649                 })
650       = do
651           (vars, exprs) <- mapAndUnzipM to_con cons
652           sel <- newLocalVar (fsLit "sel") sel_ty
653           return (sel : concat vars, mk_result (Var sel) exprs)
654       where
655         [psum_con] = tyConDataCons psum_tc
656         mk_result sel exprs = wrapFamInstBody psum_tc tys
657                             $ mkConApp psum_con
658                             $ map Type tys ++ (sel : exprs)
659
660     to_con (ConRepr _ r) = to_prod r
661
662     to_prod EmptyProd = do
663                           pvoid <- builtin pvoidVar
664                           return ([], Var pvoid)
665     to_prod (UnaryProd r)
666       = do
667           pty  <- mkPDataType (compOrigType r)
668           var  <- newLocalVar (fsLit "x") pty
669           expr <- to_comp (Var var) r
670           return ([var], expr)
671
672     to_prod (Prod { repr_ptup_tc  = ptup_tc
673                   , repr_comp_tys = tys
674                   , repr_comps    = comps })
675       = do
676           ptys <- mapM (mkPDataType . compOrigType) comps
677           vars <- newLocalVars (fsLit "x") ptys
678           es   <- zipWithM to_comp (map Var vars) comps
679           return (vars, mk_result es)
680       where
681         [ptup_con] = tyConDataCons ptup_tc
682         mk_result exprs = wrapFamInstBody ptup_tc tys
683                         $ mkConApp ptup_con
684                         $ map Type tys ++ exprs
685
686     to_comp expr (Keep _ _) = return expr
687
688     -- FIXME: this is bound to be wrong!
689     to_comp expr (Wrap ty)
690       = do
691           wrap_tc  <- builtin wrapTyCon
692           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
693           return $ wrapNewTypeBody pwrap_tc [ty] expr
694
695
696 buildFromArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
697 buildFromArrPRepr vect_tc prepr_tc pdata_tc r
698   = do
699       arg_ty <- mkPDataType =<< mkPReprType el_ty
700       res_ty <- mkPDataType el_ty
701       arg    <- newLocalVar (fsLit "xs") arg_ty
702
703       pdata_co <- mkBuiltinCo pdataTyCon
704       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
705           co           = mkAppCoercion pdata_co
706                        $ mkTyConApp repr_co var_tys
707
708           scrut  = mkCoerce co (Var arg)
709
710           mk_result args = wrapFamInstBody pdata_tc var_tys
711                          $ mkConApp pdata_con
712                          $ map Type var_tys ++ args
713
714       (expr, _) <- fixV $ \ ~(_, args) ->
715                      from_sum res_ty (mk_result args) scrut r
716
717       return $ Lam arg expr
718     
719       -- (args, mk) <- from_sum res_ty scrut r
720       
721       -- let result = wrapFamInstBody pdata_tc var_tys
722       --           . mkConApp pdata_dc
723       --           $ map Type var_tys ++ args
724
725       -- return $ Lam arg (mk result)
726   where
727     var_tys = mkTyVarTys $ tyConTyVars vect_tc
728     el_ty   = mkTyConApp vect_tc var_tys
729
730     [pdata_con] = tyConDataCons pdata_tc
731
732     from_sum _ res _ EmptySum = return (res, [])
733     from_sum res_ty res expr (UnarySum r) = from_con res_ty res expr r
734     from_sum res_ty res expr (Sum { repr_psum_tc = psum_tc
735                                   , repr_sel_ty  = sel_ty
736                                   , repr_con_tys = tys
737                                   , repr_cons    = cons })
738       = do
739           sel  <- newLocalVar (fsLit "sel") sel_ty
740           ptys <- mapM mkPDataType tys
741           vars <- newLocalVars (fsLit "xs") ptys
742           (res', args) <- fold from_con res_ty res (map Var vars) cons
743           let scrut = unwrapFamInstScrut psum_tc tys expr
744               body  = mkWildCase scrut (exprType scrut) res_ty
745                       [(DataAlt psum_con, sel : vars, res')]
746           return (body, Var sel : args)
747       where
748         [psum_con] = tyConDataCons psum_tc
749
750
751     from_con res_ty res expr (ConRepr _ r) = from_prod res_ty res expr r
752
753     from_prod _ res _ EmptyProd = return (res, [])
754     from_prod res_ty res expr (UnaryProd r)
755       = from_comp res_ty res expr r
756     from_prod res_ty res expr (Prod { repr_ptup_tc  = ptup_tc
757                                     , repr_comp_tys = tys
758                                     , repr_comps    = comps })
759       = do
760           ptys <- mapM mkPDataType tys
761           vars <- newLocalVars (fsLit "ys") ptys
762           (res', args) <- fold from_comp res_ty res (map Var vars) comps
763           let scrut = unwrapFamInstScrut ptup_tc tys expr
764               body  = mkWildCase scrut (exprType scrut) res_ty
765                       [(DataAlt ptup_con, vars, res')]
766           return (body, args)
767       where
768         [ptup_con] = tyConDataCons ptup_tc
769
770     from_comp _ res expr (Keep _ _) = return (res, [expr])
771     from_comp _ res expr (Wrap ty)
772       = do
773           wrap_tc  <- builtin wrapTyCon
774           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
775           return (res, [unwrapNewTypeBody pwrap_tc [ty]
776                         $ unwrapFamInstScrut pwrap_tc [ty] expr])
777
778     fold f res_ty res exprs rs = foldrM f' (res, []) (zip exprs rs)
779       where
780         f' (expr, r) (res, args) = do
781                                      (res', args') <- f res_ty res expr r
782                                      return (res', args' ++ args)
783
784 buildPRDict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
785 buildPRDict vect_tc prepr_tc _ r
786   = do
787       dict <- sum_dict r
788       pr_co <- mkBuiltinCo prTyCon
789       let co = mkAppCoercion pr_co
790              . mkSymCoercion
791              $ mkTyConApp arg_co ty_args
792       return (mkCoerce co dict)
793   where
794     ty_args = mkTyVarTys (tyConTyVars vect_tc)
795     Just arg_co = tyConFamilyCoercion_maybe prepr_tc
796
797     sum_dict EmptySum = prDFunOfTyCon =<< builtin voidTyCon
798     sum_dict (UnarySum r) = con_dict r
799     sum_dict (Sum { repr_sum_tc  = sum_tc
800                   , repr_con_tys = tys
801                   , repr_cons    = cons
802                   })
803       = do
804           dicts <- mapM con_dict cons
805           dfun  <- prDFunOfTyCon sum_tc
806           return $ dfun `mkTyApps` tys `mkApps` dicts
807
808     con_dict (ConRepr _ r) = prod_dict r
809
810     prod_dict EmptyProd = prDFunOfTyCon =<< builtin voidTyCon
811     prod_dict (UnaryProd r) = comp_dict r
812     prod_dict (Prod { repr_tup_tc   = tup_tc
813                     , repr_comp_tys = tys
814                     , repr_comps    = comps })
815       = do
816           dicts <- mapM comp_dict comps
817           dfun <- prDFunOfTyCon tup_tc
818           return $ dfun `mkTyApps` tys `mkApps` dicts
819
820     comp_dict (Keep _ pr) = return pr
821     comp_dict (Wrap ty)   = wrapPR ty
822
823
824 buildPDataTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
825 buildPDataTyCon orig_tc vect_tc repr = fixV $ \repr_tc ->
826   do
827     name' <- cloneName mkPDataTyConOcc orig_name
828     rhs   <- buildPDataTyConRhs orig_name vect_tc repr_tc repr
829     pdata <- builtin pdataTyCon
830
831     liftDs $ buildAlgTyCon name'
832                            tyvars
833                            []          -- no stupid theta
834                            rhs
835                            rec_flag    -- FIXME: is this ok?
836                            False       -- FIXME: no generics
837                            False       -- not GADT syntax
838                            (Just $ mk_fam_inst pdata vect_tc)
839   where
840     orig_name = tyConName orig_tc
841     tyvars = tyConTyVars vect_tc
842     rec_flag = boolToRecFlag (isRecursiveTyCon vect_tc)
843
844
845 buildPDataTyConRhs :: Name -> TyCon -> TyCon -> SumRepr -> VM AlgTyConRhs
846 buildPDataTyConRhs orig_name vect_tc repr_tc repr
847   = do
848       data_con <- buildPDataDataCon orig_name vect_tc repr_tc repr
849       return $ DataTyCon { data_cons = [data_con], is_enum = False }
850
851 buildPDataDataCon :: Name -> TyCon -> TyCon -> SumRepr -> VM DataCon
852 buildPDataDataCon orig_name vect_tc repr_tc repr
853   = do
854       dc_name  <- cloneName mkPDataDataConOcc orig_name
855       comp_tys <- sum_tys repr
856
857       liftDs $ buildDataCon dc_name
858                             False                  -- not infix
859                             (map (const HsNoBang) comp_tys)
860                             []                     -- no field labels
861                             tvs
862                             []                     -- no existentials
863                             []                     -- no eq spec
864                             []                     -- no context
865                             comp_tys
866                             (mkFamilyTyConApp repr_tc (mkTyVarTys tvs))
867                             repr_tc
868   where
869     tvs   = tyConTyVars vect_tc
870
871     sum_tys EmptySum = return []
872     sum_tys (UnarySum r) = con_tys r
873     sum_tys (Sum { repr_sel_ty = sel_ty
874                  , repr_cons   = cons })
875       = liftM (sel_ty :) (concatMapM con_tys cons)
876
877     con_tys (ConRepr _ r) = prod_tys r
878
879     prod_tys EmptyProd = return []
880     prod_tys (UnaryProd r) = liftM singleton (comp_ty r)
881     prod_tys (Prod { repr_comps = comps }) = mapM comp_ty comps
882
883     comp_ty r = mkPDataType (compOrigType r)
884
885
886 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr 
887                    -> VM Var
888 buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr
889   = do
890       vectDataConWorkers orig_tc vect_tc pdata_tc
891       buildPADict vect_tc prepr_tc pdata_tc repr
892
893 vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
894 vectDataConWorkers orig_tc vect_tc arr_tc
895   = do
896       bs <- sequence
897           . zipWith3 def_worker  (tyConDataCons orig_tc) rep_tys
898           $ zipWith4 mk_data_con (tyConDataCons vect_tc)
899                                  rep_tys
900                                  (inits rep_tys)
901                                  (tail $ tails rep_tys)
902       mapM_ (uncurry hoistBinding) bs
903   where
904     tyvars   = tyConTyVars vect_tc
905     var_tys  = mkTyVarTys tyvars
906     ty_args  = map Type var_tys
907     res_ty   = mkTyConApp vect_tc var_tys
908
909     cons     = tyConDataCons vect_tc
910     arity    = length cons
911     [arr_dc] = tyConDataCons arr_tc
912
913     rep_tys  = map dataConRepArgTys $ tyConDataCons vect_tc
914
915
916     mk_data_con con tys pre post
917       = liftM2 (,) (vect_data_con con)
918                    (lift_data_con tys pre post (mkDataConTag con))
919
920     sel_replicate len tag
921       | arity > 1 = do
922                       rep <- builtin (selReplicate arity)
923                       return [rep `mkApps` [len, tag]]
924
925       | otherwise = return []
926
927     vect_data_con con = return $ mkConApp con ty_args
928     lift_data_con tys pre_tys post_tys tag
929       = do
930           len  <- builtin liftingContext
931           args <- mapM (newLocalVar (fsLit "xs"))
932                   =<< mapM mkPDataType tys
933
934           sel  <- sel_replicate (Var len) tag
935
936           pre   <- mapM emptyPD (concat pre_tys)
937           post  <- mapM emptyPD (concat post_tys)
938
939           return . mkLams (len : args)
940                  . wrapFamInstBody arr_tc var_tys
941                  . mkConApp arr_dc
942                  $ ty_args ++ sel ++ pre ++ map Var args ++ post
943
944     def_worker data_con arg_tys mk_body
945       = do
946           arity <- polyArity tyvars
947           body <- closedV
948                 . inBind orig_worker
949                 . polyAbstract tyvars $ \args ->
950                   liftM (mkLams (tyvars ++ args) . vectorised)
951                 $ buildClosures tyvars [] arg_tys res_ty mk_body
952
953           raw_worker <- cloneId mkVectOcc orig_worker (exprType body)
954           let vect_worker = raw_worker `setIdUnfolding`
955                               mkInlineRule body (Just arity)
956           defGlobalVar orig_worker vect_worker
957           return (vect_worker, body)
958       where
959         orig_worker = dataConWorkId data_con
960
961 buildPADict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
962 buildPADict vect_tc prepr_tc arr_tc repr
963   = polyAbstract tvs $ \args ->
964     do
965       method_ids <- mapM (method args) paMethods
966
967       pa_tc  <- builtin paTyCon
968       pa_dc  <- builtin paDataCon
969       let dict = mkLams (tvs ++ args)
970                $ mkConApp pa_dc
971                $ Type inst_ty : map (method_call args) method_ids
972
973           dfun_ty = mkForAllTys tvs
974                   $ mkFunTys (map varType args) (mkTyConApp pa_tc [inst_ty])
975
976       raw_dfun <- newExportedVar dfun_name dfun_ty
977       let dfun = raw_dfun `setIdUnfolding` mkDFunUnfolding dfun_ty (map Var method_ids)
978                           `setInlinePragma` dfunInlinePragma
979
980       hoistBinding dfun dict
981       return dfun
982   where
983     tvs = tyConTyVars vect_tc
984     arg_tys = mkTyVarTys tvs
985     inst_ty = mkTyConApp vect_tc arg_tys
986
987     dfun_name = mkPADFunOcc (getOccName vect_tc)
988
989     method args (name, build)
990       = localV
991       $ do
992           expr <- build vect_tc prepr_tc arr_tc repr
993           let body = mkLams (tvs ++ args) expr
994           raw_var <- newExportedVar (method_name name) (exprType body)
995           let var = raw_var
996                       `setIdUnfolding` mkInlineRule body (Just (length args))
997                       `setInlinePragma` alwaysInlinePragma
998           hoistBinding var body
999           return var
1000
1001     method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args)
1002
1003     method_name name = mkVarOcc $ occNameString dfun_name ++ ('$' : name)
1004
1005
1006 paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
1007 paMethods = [("dictPRepr",    buildPRDict),
1008              ("toPRepr",      buildToPRepr),
1009              ("fromPRepr",    buildFromPRepr),
1010              ("toArrPRepr",   buildToArrPRepr),
1011              ("fromArrPRepr", buildFromArrPRepr)]
1012
1013
1014 -- | Split the given tycons into two sets depending on whether they have to be
1015 --   converted (first list) or not (second list). The first argument contains
1016 --   information about the conversion status of external tycons:
1017 --
1018 --   * tycons which have converted versions are mapped to True
1019 --   * tycons which are not changed by vectorisation are mapped to False
1020 --   * tycons which can't be converted are not elements of the map
1021 --
1022 classifyTyCons :: UniqFM Bool -> [TyConGroup] -> ([TyCon], [TyCon])
1023 classifyTyCons = classify [] []
1024   where
1025     classify conv keep _  [] = (conv, keep)
1026     classify conv keep cs ((tcs, ds) : rs)
1027       | can_convert && must_convert
1028         = classify (tcs ++ conv) keep (cs `addListToUFM` [(tc,True) | tc <- tcs]) rs
1029       | can_convert
1030         = classify conv (tcs ++ keep) (cs `addListToUFM` [(tc,False) | tc <- tcs]) rs
1031       | otherwise
1032         = classify conv keep cs rs
1033       where
1034         refs = ds `delListFromUniqSet` tcs
1035
1036         can_convert  = isNullUFM (refs `minusUFM` cs) && all convertable tcs
1037         must_convert = foldUFM (||) False (intersectUFM_C const cs refs)
1038
1039         convertable tc = isDataTyCon tc && all isVanillaDataCon (tyConDataCons tc)
1040
1041 -- | Compute mutually recursive groups of tycons in topological order
1042 --
1043 tyConGroups :: [TyCon] -> [TyConGroup]
1044 tyConGroups tcs = map mk_grp (stronglyConnCompFromEdgedVertices edges)
1045   where
1046     edges = [((tc, ds), tc, uniqSetToList ds) | tc <- tcs
1047                                 , let ds = tyConsOfTyCon tc]
1048
1049     mk_grp (AcyclicSCC (tc, ds)) = ([tc], ds)
1050     mk_grp (CyclicSCC els)       = (tcs, unionManyUniqSets dss)
1051       where
1052         (tcs, dss) = unzip els
1053
1054 tyConsOfTyCon :: TyCon -> UniqSet TyCon
1055 tyConsOfTyCon
1056   = tyConsOfTypes . concatMap dataConRepArgTys . tyConDataCons
1057
1058 tyConsOfType :: Type -> UniqSet TyCon
1059 tyConsOfType ty
1060   | Just ty' <- coreView ty    = tyConsOfType ty'
1061 tyConsOfType (TyVarTy _)       = emptyUniqSet
1062 tyConsOfType (TyConApp tc tys) = extend (tyConsOfTypes tys)
1063   where
1064     extend | isUnLiftedTyCon tc
1065            || isTupleTyCon   tc = id
1066
1067            | otherwise          = (`addOneToUniqSet` tc)
1068
1069 tyConsOfType (AppTy a b)       = tyConsOfType a `unionUniqSets` tyConsOfType b
1070 tyConsOfType (FunTy a b)       = (tyConsOfType a `unionUniqSets` tyConsOfType b)
1071                                  `addOneToUniqSet` funTyCon
1072 tyConsOfType (ForAllTy _ ty)   = tyConsOfType ty
1073 tyConsOfType other             = pprPanic "ClosureConv.tyConsOfType" $ ppr other
1074
1075 tyConsOfTypes :: [Type] -> UniqSet TyCon
1076 tyConsOfTypes = unionManyUniqSets . map tyConsOfType
1077
1078
1079 -- ----------------------------------------------------------------------------
1080 -- Conversions
1081
1082 -- | Build an expression that calls the vectorised version of some 
1083 --   function from a `Closure`.
1084 --
1085 --   For example
1086 --   @   
1087 --      \(x :: Double) -> 
1088 --      \(y :: Double) -> 
1089 --      ($v_foo $: x) $: y
1090 --   @
1091 --
1092 --   We use the type of the original binding to work out how many
1093 --   outer lambdas to add.
1094 --
1095 fromVect 
1096         :: Type         -- ^ The type of the original binding.
1097         -> CoreExpr     -- ^ Expression giving the closure to use, eg @$v_foo@.
1098         -> VM CoreExpr
1099         
1100 -- Convert the type to the core view if it isn't already.
1101 fromVect ty expr 
1102         | Just ty' <- coreView ty 
1103         = fromVect ty' expr
1104
1105 -- For each function constructor in the original type we add an outer 
1106 -- lambda to bind the parameter variable, and an inner application of it.
1107 fromVect (FunTy arg_ty res_ty) expr
1108   = do
1109       arg     <- newLocalVar (fsLit "x") arg_ty
1110       varg    <- toVect arg_ty (Var arg)
1111       varg_ty <- vectType arg_ty
1112       vres_ty <- vectType res_ty
1113       apply   <- builtin applyVar
1114       body    <- fromVect res_ty
1115                $ Var apply `mkTyApps` [varg_ty, vres_ty] `mkApps` [expr, varg]
1116       return $ Lam arg body
1117
1118 -- If the type isn't a function then it's time to call on the closure.
1119 fromVect ty expr
1120   = identityConv ty >> return expr
1121
1122
1123 toVect :: Type -> CoreExpr -> VM CoreExpr
1124 toVect ty expr = identityConv ty >> return expr
1125
1126
1127 identityConv :: Type -> VM ()
1128 identityConv ty | Just ty' <- coreView ty = identityConv ty'
1129 identityConv (TyConApp tycon tys)
1130   = do
1131       mapM_ identityConv tys
1132       identityConvTyCon tycon
1133 identityConv _ = noV
1134
1135 identityConvTyCon :: TyCon -> VM ()
1136 identityConvTyCon tc
1137   | isBoxedTupleTyCon tc = return ()
1138   | isUnLiftedTyCon tc   = return ()
1139   | otherwise            = do
1140                              tc' <- maybeV (lookupTyCon tc)
1141                              if tc == tc' then return () else noV
1142