1 {-# OPTIONS -fno-warn-missing-signatures #-}
3 module VectType ( vectTyCon, vectAndLiftType, vectType, vectTypeEnv,
4 -- arrSumArity, pdataCompTys, pdataCompVars,
13 import HscTypes ( TypeEnv, extendTypeEnvList, typeEnvTyCons )
18 import MkCore ( mkWildCase )
26 import FamInstEnv ( FamInst, mkLocalFamInst )
30 import Var ( Var, TyVar, varType, varName )
31 import Name ( Name, getOccName )
38 import Digraph ( SCC(..), stronglyConnCompFromEdgedVertices )
43 import MonadUtils ( zipWith3M, foldrM, concatMapM )
44 import Control.Monad ( liftM, liftM2, zipWithM, zipWithM_, mapAndUnzipM )
-- | Emit a pretty-printed trace tagged "VectType" when 'debug' is switched
-- on; either way the second argument is returned as the result.
dtrace s x
  | debug     = pprTrace "VectType" s x
  | otherwise = x
51 -- ----------------------------------------------------------------------------
54 -- | Vectorise a type constructor.
55 vectTyCon :: TyCon -> VM TyCon
57 | isFunTyCon tc = builtin closureTyCon
58 | isBoxedTupleTyCon tc = return tc
59 | isUnLiftedTyCon tc = return tc
61 = maybeCantVectoriseM "Tycon not vectorised: " (ppr tc)
65 vectAndLiftType :: Type -> VM (Type, Type)
66 vectAndLiftType ty | Just ty' <- coreView ty = vectAndLiftType ty'
69 mdicts <- mapM paDictArgType tyvars
70 let dicts = [dict | Just dict <- mdicts]
71 vmono_ty <- vectType mono_ty
72 lmono_ty <- mkPDataType vmono_ty
73 return (abstractType tyvars dicts vmono_ty,
74 abstractType tyvars dicts lmono_ty)
76 (tyvars, mono_ty) = splitForAllTys ty
79 -- | Vectorise a type.
80 vectType :: Type -> VM Type
82 | Just ty' <- coreView ty
85 vectType (TyVarTy tv) = return $ TyVarTy tv
86 vectType (AppTy ty1 ty2) = liftM2 AppTy (vectType ty1) (vectType ty2)
87 vectType (TyConApp tc tys) = liftM2 TyConApp (vectTyCon tc) (mapM vectType tys)
88 vectType (FunTy ty1 ty2) = liftM2 TyConApp (builtin closureTyCon)
89 (mapM vectAndBoxType [ty1,ty2])
91 -- For each quantified var we need to add a PA dictionary at the front of the type.
92 -- So forall a. C a => a -> a
93 -- turns into forall a. Cv a => PA a => a :-> a
94 vectType ty@(ForAllTy _ _)
96 -- split the type into the quantified vars, its dictionaries and the body.
97 let (tyvars, tyBody) = splitForAllTys ty
98 let (tyArgs, tyResult) = splitFunTys tyBody
100 let (tyArgs_dict, tyArgs_regular)
101 = partition isDictType tyArgs
103 -- vectorise the body.
104 let tyBody' = mkFunTys tyArgs_regular tyResult
105 tyBody'' <- vectType tyBody'
107 -- vectorise the dictionary parameters.
108 dictsVect <- mapM vectType tyArgs_dict
110 -- make a PA dictionary for each of the type variables.
111 dictsPA <- liftM catMaybes $ mapM paDictArgType tyvars
113 -- pack it all back together.
114 return $ abstractType tyvars (dictsVect ++ dictsPA) tyBody''
116 vectType ty = cantVectorise "Can't vectorise type" (ppr ty)
-- | Add quantified type variables and dictionary parameters to the front of
-- a type, giving @forall tyvars. d1 -> ... -> dn -> body@.
abstractType :: [TyVar] -> [Type] -> Type -> Type
abstractType tyvars dicts body
  = mkForAllTys tyvars (mkFunTys dicts body)
124 -- | Check if some type is a type class dictionary.
125 isDictType :: Type -> Bool
127 = case splitTyConApp_maybe ty of
128 Just (tyCon, _) -> isClassTyCon tyCon
132 -- ----------------------------------------------------------------------------
135 boxType :: Type -> VM Type
137 | Just (tycon, []) <- splitTyConApp_maybe ty
138 , isUnLiftedTyCon tycon
140 r <- lookupBoxedTyCon tycon
142 Just tycon' -> return $ mkTyConApp tycon' []
145 boxType ty = return ty
-- | Vectorise a type, then pass the vectorised type through 'boxType'.
vectAndBoxType :: Type -> VM Type
vectAndBoxType ty = boxType =<< vectType ty
151 -- ----------------------------------------------------------------------------
-- | A group of type constructors, as produced by 'tyConGroups' (one strongly
-- connected component), paired with the set of tycons its members refer to,
-- as computed by 'tyConsOfTyCon'.
154 type TyConGroup = ([TyCon], UniqSet TyCon)
156 -- | Vectorise a type environment.
157 -- The type environment contains all the type things defined in a module.
158 vectTypeEnv :: TypeEnv -> VM (TypeEnv, [FamInst], [(Var, CoreExpr)])
162 cs <- readGEnv $ mk_map . global_tycons
164 -- Split the list of TyCons into the ones we have to vectorise vs the
165 -- ones we can pass through unchanged. We also pass through algebraic
166 -- types that use non Haskell98 features, as we don't handle those.
167 let (conv_tcs, keep_tcs) = classifyTyCons cs groups
168 keep_dcs = concatMap tyConDataCons keep_tcs
170 dtrace (text "conv_tcs = " <> ppr conv_tcs) $ return ()
172 zipWithM_ defTyCon keep_tcs keep_tcs
173 zipWithM_ defDataCon keep_dcs keep_dcs
175 new_tcs <- vectTyConDecls conv_tcs
177 dtrace (text "new_tcs = " <> ppr new_tcs) $ return ()
179 let orig_tcs = keep_tcs ++ conv_tcs
181 -- We don't need to make new representation types for dictionary
182 -- constructors. The constructors are always fully applied, and we don't
183 -- need to lift them to arrays as a dictionary of a particular type
184 -- always has the same value.
185 let vect_tcs = filter (not . isClassTyCon)
186 $ keep_tcs ++ new_tcs
188 dtrace (text "vect_tcs = " <> ppr vect_tcs) $ return ()
190 mapM_ dumpTycon $ new_tcs
193 (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
195 defTyConPAs (zipLazy vect_tcs dfuns')
196 reprs <- mapM tyConRepr vect_tcs
197 repr_tcs <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
198 pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
201 $ zipWith5 buildTyConBindings
209 return (dfuns, binds, repr_tcs ++ pdata_tcs)
211 let all_new_tcs = new_tcs ++ inst_tcs
213 let new_env = extendTypeEnvList env
214 (map ATyCon all_new_tcs
215 ++ [ADataCon dc | tc <- all_new_tcs
216 , dc <- tyConDataCons tc])
218 return (new_env, map mkLocalFamInst inst_tcs, binds)
220 tycons = typeEnvTyCons env
221 groups = tyConGroups tycons
223 mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
226 -- | Vectorise some (possibly recursively defined) type constructors.
227 vectTyConDecls :: [TyCon] -> VM [TyCon]
228 vectTyConDecls tcs = fixV $ \tcs' ->
230 mapM_ (uncurry defTyCon) (zipLazy tcs tcs')
231 mapM vectTyConDecl tcs
233 dumpTycon :: TyCon -> VM ()
235 | Just cls <- tyConClass_maybe tycon
236 = dtrace (vcat [ ppr tycon
237 , ppr [(m, varType m) | m <- classMethods cls ]])
244 -- | Vectorise a single type constructor.
245 vectTyConDecl :: TyCon -> VM TyCon
247 -- a type class constructor.
248 -- TODO: check for no stupid theta, fds, assoc types.
250 , Just cls <- tyConClass_maybe tycon
252 = do -- make the name of the vectorised class tycon.
253 name' <- cloneName mkVectTyConOcc (tyConName tycon)
255 -- vectorise the right-hand side of the definition.
256 rhs' <- vectAlgTyConRhs tycon (algTyConRhs tycon)
258 -- vectorise method selectors.
259 -- This also adds a mapping between the original and vectorised method selector
261 methods' <- mapM vectMethod
262 $ [(id, defMethSpecOfDefMeth meth)
263 | (id, meth) <- classOpItems cls]
265 -- keep the original recursiveness flag.
266 let rec_flag = boolToRecFlag (isRecursiveTyCon tycon)
268 -- Calling buildClass here attaches new quantifiers and dictionaries to the method types.
271 False -- include unfoldings on dictionary selectors.
272 name' -- new name V_T:Class
273 (tyConTyVars tycon) -- keep original type vars
274 [] -- no stupid theta
275 [] -- no functional dependencies
276 [] -- no associated types
277 methods' -- method info
278 rec_flag -- whether recursive
280 let tycon' = mkClassTyCon name'
289 -- a regular algebraic type constructor.
290 -- TODO: check for stupid theta, generics, GADTs, etc.
292 = do name' <- cloneName mkVectTyConOcc (tyConName tycon)
293 rhs' <- vectAlgTyConRhs tycon (algTyConRhs tycon)
294 let rec_flag = boolToRecFlag (isRecursiveTyCon tycon)
296 liftDs $ buildAlgTyCon
298 (tyConTyVars tycon) -- keep original type vars.
299 [] -- no stupid theta.
300 rhs' -- new constructor defs.
301 rec_flag -- FIXME: is this ok?
302 False -- FIXME: no generics
303 False -- not GADT syntax
304 Nothing -- not a family instance
306 -- some other crazy thing that we don't handle.
308 = cantVectorise "Can't vectorise type constructor: " (ppr tycon)
311 -- | Vectorise a class method.
312 vectMethod :: (Id, DefMethSpec) -> VM (Name, DefMethSpec, Type)
313 vectMethod (id, defMeth)
315 -- Vectorise the method type.
316 typ' <- vectType (varType id)
318 -- Create a name for the vectorised method.
319 id' <- cloneId mkVectOcc id typ'
322 -- When we call buildClass in vectTyConDecl, it adds foralls and dictionaries
323 -- to the types of each method. However, the types we get back from vectType
324 -- above already have these, so we need to chop them off here otherwise
325 -- we'll get two copies in the final version.
326 let (_tyvars, tyBody) = splitForAllTys typ'
327 let (_dict, tyRest) = splitFunTy tyBody
329 return (Var.varName id', defMeth, tyRest)
332 -- | Vectorise the RHS of an algebraic type.
333 vectAlgTyConRhs :: TyCon -> AlgTyConRhs -> VM AlgTyConRhs
334 vectAlgTyConRhs _ (DataTyCon { data_cons = data_cons
338 data_cons' <- mapM vectDataCon data_cons
339 zipWithM_ defDataCon data_cons data_cons'
340 return $ DataTyCon { data_cons = data_cons'
345 = cantVectorise "Can't vectorise type definition:" (ppr tc)
348 -- | Vectorise a data constructor.
349 -- Vectorises its argument and return types.
350 vectDataCon :: DataCon -> VM DataCon
352 | not . null $ dataConExTyVars dc
353 = cantVectorise "Can't vectorise constructor (existentials):" (ppr dc)
355 | not . null $ dataConEqSpec dc
356 = cantVectorise "Can't vectorise constructor (eq spec):" (ppr dc)
360 name' <- cloneName mkVectDataConOcc name
361 tycon' <- vectTyCon tycon
362 arg_tys <- mapM vectType rep_arg_tys
364 liftDs $ buildDataCon
367 (map (const HsNoBang) arg_tys) -- strictness annots on args.
368 [] -- no labelled fields
369 univ_tvs -- universally quantified vars
370 [] -- no existential tvs for now
371 [] -- no eq spec for now
373 arg_tys -- argument types
374 (mkFamilyTyConApp tycon' (mkTyVarTys univ_tvs)) -- return type
375 tycon' -- representation tycon
377 name = dataConName dc
378 univ_tvs = dataConUnivTyVars dc
379 rep_arg_tys = dataConRepArgTys dc
380 tycon = dataConTyCon dc
-- | Describe a family instance: the family tycon together with a single
-- instance argument, which is the argument tycon applied to its own type
-- variables.
mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
mk_fam_inst fam_tc arg_tc = (fam_tc, [applied_arg])
  where
    applied_arg = mkTyConApp arg_tc (mkTyVarTys (tyConTyVars arg_tc))
387 buildPReprTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
388 buildPReprTyCon orig_tc vect_tc repr
390 name <- cloneName mkPReprTyConOcc (tyConName orig_tc)
391 -- rhs_ty <- buildPReprType vect_tc
392 rhs_ty <- sumReprType repr
393 prepr_tc <- builtin preprTyCon
394 liftDs $ buildSynTyCon name
396 (SynonymTyCon rhs_ty)
398 (Just $ mk_fam_inst prepr_tc vect_tc)
400 tyvars = tyConTyVars vect_tc
402 data CompRepr = Keep Type
403 CoreExpr -- PR dictionary for the type
406 data ProdRepr = EmptyProd
408 | Prod { repr_tup_tc :: TyCon -- representation tuple tycon
409 , repr_ptup_tc :: TyCon -- PData representation tycon
410 , repr_comp_tys :: [Type] -- representation types of
411 , repr_comps :: [CompRepr] -- components
-- | Representation of a single data constructor: the constructor itself
-- paired with the representation of the product of its fields.
413 data ConRepr = ConRepr DataCon ProdRepr
415 data SumRepr = EmptySum
417 | Sum { repr_sum_tc :: TyCon -- representation sum tycon
418 , repr_psum_tc :: TyCon -- PData representation tycon
419 , repr_sel_ty :: Type -- type of selector
420 , repr_con_tys :: [Type] -- representation types of
421 , repr_cons :: [ConRepr] -- components
424 tyConRepr :: TyCon -> VM SumRepr
425 tyConRepr tc = sum_repr (tyConDataCons tc)
427 sum_repr [] = return EmptySum
428 sum_repr [con] = liftM UnarySum (con_repr con)
430 rs <- mapM con_repr cons
431 sum_tc <- builtin (sumTyCon arity)
432 tys <- mapM conReprType rs
433 (psum_tc, _) <- pdataReprTyCon (mkTyConApp sum_tc tys)
434 sel_ty <- builtin (selTy arity)
435 return $ Sum { repr_sum_tc = sum_tc
436 , repr_psum_tc = psum_tc
437 , repr_sel_ty = sel_ty
444 con_repr con = liftM (ConRepr con) (prod_repr (dataConRepArgTys con))
446 prod_repr [] = return EmptyProd
447 prod_repr [ty] = liftM UnaryProd (comp_repr ty)
449 rs <- mapM comp_repr tys
450 tup_tc <- builtin (prodTyCon arity)
451 tys' <- mapM compReprType rs
452 (ptup_tc, _) <- pdataReprTyCon (mkTyConApp tup_tc tys')
453 return $ Prod { repr_tup_tc = tup_tc
454 , repr_ptup_tc = ptup_tc
455 , repr_comp_tys = tys'
461 comp_repr ty = liftM (Keep ty) (prDictOfType ty)
462 `orElseV` return (Wrap ty)
-- | Representation type of a whole data type. 'EmptySum' maps to
-- 'voidType'. A single constructor uses that constructor's representation
-- type. Otherwise the sum tycon is applied to the constructors'
-- representation types.
sumReprType :: SumRepr -> VM Type
sumReprType repr = case repr of
  EmptySum      -> voidType
  UnarySum con  -> conReprType con
  Sum { repr_sum_tc = sum_tc, repr_con_tys = con_tys }
                -> return (mkTyConApp sum_tc con_tys)
-- | Representation type of one constructor: the representation type of its
-- field product (the 'DataCon' itself is not consulted).
conReprType :: ConRepr -> VM Type
conReprType con = case con of
  ConRepr _ prod -> prodReprType prod
-- | Representation type of a constructor's fields. 'EmptyProd' maps to
-- 'voidType'. A single field uses that field's representation type.
-- Otherwise the tuple tycon is applied to the component representation
-- types.
prodReprType :: ProdRepr -> VM Type
prodReprType prod = case prod of
  EmptyProd      -> voidType
  UnaryProd comp -> compReprType comp
  Prod { repr_tup_tc = tup_tc, repr_comp_tys = comp_tys }
                 -> return (mkTyConApp tup_tc comp_tys)
-- | Representation type of a single field. A kept field uses its own type.
-- A wrapped field uses the builtin wrap tycon applied to its type.
compReprType :: CompRepr -> VM Type
compReprType comp = case comp of
  Keep ty _ -> return ty
  Wrap ty   -> liftM (\wrap_tc -> mkTyConApp wrap_tc [ty]) (builtin wrapTyCon)
-- | The original field type stored in a component representation,
-- regardless of whether the field is kept or wrapped.
compOrigType :: CompRepr -> Type
compOrigType comp = case comp of
  Keep ty _ -> ty
  Wrap ty   -> ty
489 buildToPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
490 buildToPRepr vect_tc repr_tc _ repr
492 let arg_ty = mkTyConApp vect_tc ty_args
493 res_ty <- mkPReprType arg_ty
494 arg <- newLocalVar (fsLit "x") arg_ty
495 result <- to_sum (Var arg) arg_ty res_ty repr
496 return $ Lam arg result
498 ty_args = mkTyVarTys (tyConTyVars vect_tc)
500 wrap_repr_inst = wrapFamInstBody repr_tc ty_args
502 to_sum _ _ _ EmptySum
504 void <- builtin voidVar
505 return $ wrap_repr_inst $ Var void
507 to_sum arg arg_ty res_ty (UnarySum r)
509 (pat, vars, body) <- con_alt r
510 return $ mkWildCase arg arg_ty res_ty
511 [(pat, vars, wrap_repr_inst body)]
513 to_sum arg arg_ty res_ty (Sum { repr_sum_tc = sum_tc
515 , repr_cons = cons })
517 alts <- mapM con_alt cons
518 let alts' = [(pat, vars, wrap_repr_inst
519 $ mkConApp sum_con (map Type tys ++ [body]))
520 | ((pat, vars, body), sum_con)
521 <- zip alts (tyConDataCons sum_tc)]
522 return $ mkWildCase arg arg_ty res_ty alts'
524 con_alt (ConRepr con r)
526 (vars, body) <- to_prod r
527 return (DataAlt con, vars, body)
531 void <- builtin voidVar
532 return ([], Var void)
534 to_prod (UnaryProd comp)
536 var <- newLocalVar (fsLit "x") (compOrigType comp)
537 body <- to_comp (Var var) comp
540 to_prod(Prod { repr_tup_tc = tup_tc
541 , repr_comp_tys = tys
542 , repr_comps = comps })
544 vars <- newLocalVars (fsLit "x") (map compOrigType comps)
545 exprs <- zipWithM to_comp (map Var vars) comps
546 return (vars, mkConApp tup_con (map Type tys ++ exprs))
548 [tup_con] = tyConDataCons tup_tc
550 to_comp expr (Keep _ _) = return expr
551 to_comp expr (Wrap ty) = do
552 wrap_tc <- builtin wrapTyCon
553 return $ wrapNewTypeBody wrap_tc [ty] expr
556 buildFromPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
557 buildFromPRepr vect_tc repr_tc _ repr
559 arg_ty <- mkPReprType res_ty
560 arg <- newLocalVar (fsLit "x") arg_ty
562 result <- from_sum (unwrapFamInstScrut repr_tc ty_args (Var arg))
564 return $ Lam arg result
566 ty_args = mkTyVarTys (tyConTyVars vect_tc)
567 res_ty = mkTyConApp vect_tc ty_args
571 dummy <- builtin fromVoidVar
572 return $ Var dummy `App` Type res_ty
574 from_sum expr (UnarySum r) = from_con expr r
575 from_sum expr (Sum { repr_sum_tc = sum_tc
577 , repr_cons = cons })
579 vars <- newLocalVars (fsLit "x") tys
580 es <- zipWithM from_con (map Var vars) cons
581 return $ mkWildCase expr (exprType expr) res_ty
582 [(DataAlt con, [var], e)
583 | (con, var, e) <- zip3 (tyConDataCons sum_tc) vars es]
585 from_con expr (ConRepr con r)
586 = from_prod expr (mkConApp con $ map Type ty_args) r
588 from_prod _ con EmptyProd = return con
589 from_prod expr con (UnaryProd r)
591 e <- from_comp expr r
594 from_prod expr con (Prod { repr_tup_tc = tup_tc
595 , repr_comp_tys = tys
599 vars <- newLocalVars (fsLit "y") tys
600 es <- zipWithM from_comp (map Var vars) comps
601 return $ mkWildCase expr (exprType expr) res_ty
602 [(DataAlt tup_con, vars, con `mkApps` es)]
604 [tup_con] = tyConDataCons tup_tc
606 from_comp expr (Keep _ _) = return expr
607 from_comp expr (Wrap ty)
609 wrap <- builtin wrapTyCon
610 return $ unwrapNewTypeBody wrap [ty] expr
613 buildToArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
614 buildToArrPRepr vect_tc prepr_tc pdata_tc r
616 arg_ty <- mkPDataType el_ty
617 res_ty <- mkPDataType =<< mkPReprType el_ty
618 arg <- newLocalVar (fsLit "xs") arg_ty
620 pdata_co <- mkBuiltinCo pdataTyCon
621 let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
622 co = mkAppCoercion pdata_co
624 $ mkTyConApp repr_co ty_args
626 scrut = unwrapFamInstScrut pdata_tc ty_args (Var arg)
628 (vars, result) <- to_sum r
631 $ mkWildCase scrut (mkTyConApp pdata_tc ty_args) res_ty
632 [(DataAlt pdata_dc, vars, mkCoerce co result)]
634 ty_args = mkTyVarTys $ tyConTyVars vect_tc
635 el_ty = mkTyConApp vect_tc ty_args
637 [pdata_dc] = tyConDataCons pdata_tc
641 pvoid <- builtin pvoidVar
642 return ([], Var pvoid)
643 to_sum (UnarySum r) = to_con r
644 to_sum (Sum { repr_psum_tc = psum_tc
645 , repr_sel_ty = sel_ty
650 (vars, exprs) <- mapAndUnzipM to_con cons
651 sel <- newLocalVar (fsLit "sel") sel_ty
652 return (sel : concat vars, mk_result (Var sel) exprs)
654 [psum_con] = tyConDataCons psum_tc
655 mk_result sel exprs = wrapFamInstBody psum_tc tys
657 $ map Type tys ++ (sel : exprs)
659 to_con (ConRepr _ r) = to_prod r
661 to_prod EmptyProd = do
662 pvoid <- builtin pvoidVar
663 return ([], Var pvoid)
664 to_prod (UnaryProd r)
666 pty <- mkPDataType (compOrigType r)
667 var <- newLocalVar (fsLit "x") pty
668 expr <- to_comp (Var var) r
671 to_prod (Prod { repr_ptup_tc = ptup_tc
672 , repr_comp_tys = tys
673 , repr_comps = comps })
675 ptys <- mapM (mkPDataType . compOrigType) comps
676 vars <- newLocalVars (fsLit "x") ptys
677 es <- zipWithM to_comp (map Var vars) comps
678 return (vars, mk_result es)
680 [ptup_con] = tyConDataCons ptup_tc
681 mk_result exprs = wrapFamInstBody ptup_tc tys
683 $ map Type tys ++ exprs
685 to_comp expr (Keep _ _) = return expr
687 -- FIXME: this is bound to be wrong!
688 to_comp expr (Wrap ty)
690 wrap_tc <- builtin wrapTyCon
691 (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
692 return $ wrapNewTypeBody pwrap_tc [ty] expr
695 buildFromArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
696 buildFromArrPRepr vect_tc prepr_tc pdata_tc r
698 arg_ty <- mkPDataType =<< mkPReprType el_ty
699 res_ty <- mkPDataType el_ty
700 arg <- newLocalVar (fsLit "xs") arg_ty
702 pdata_co <- mkBuiltinCo pdataTyCon
703 let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
704 co = mkAppCoercion pdata_co
705 $ mkTyConApp repr_co var_tys
707 scrut = mkCoerce co (Var arg)
709 mk_result args = wrapFamInstBody pdata_tc var_tys
711 $ map Type var_tys ++ args
713 (expr, _) <- fixV $ \ ~(_, args) ->
714 from_sum res_ty (mk_result args) scrut r
716 return $ Lam arg expr
718 -- (args, mk) <- from_sum res_ty scrut r
720 -- let result = wrapFamInstBody pdata_tc var_tys
721 -- . mkConApp pdata_dc
722 -- $ map Type var_tys ++ args
724 -- return $ Lam arg (mk result)
726 var_tys = mkTyVarTys $ tyConTyVars vect_tc
727 el_ty = mkTyConApp vect_tc var_tys
729 [pdata_con] = tyConDataCons pdata_tc
731 from_sum _ res _ EmptySum = return (res, [])
732 from_sum res_ty res expr (UnarySum r) = from_con res_ty res expr r
733 from_sum res_ty res expr (Sum { repr_psum_tc = psum_tc
734 , repr_sel_ty = sel_ty
736 , repr_cons = cons })
738 sel <- newLocalVar (fsLit "sel") sel_ty
739 ptys <- mapM mkPDataType tys
740 vars <- newLocalVars (fsLit "xs") ptys
741 (res', args) <- fold from_con res_ty res (map Var vars) cons
742 let scrut = unwrapFamInstScrut psum_tc tys expr
743 body = mkWildCase scrut (exprType scrut) res_ty
744 [(DataAlt psum_con, sel : vars, res')]
745 return (body, Var sel : args)
747 [psum_con] = tyConDataCons psum_tc
750 from_con res_ty res expr (ConRepr _ r) = from_prod res_ty res expr r
752 from_prod _ res _ EmptyProd = return (res, [])
753 from_prod res_ty res expr (UnaryProd r)
754 = from_comp res_ty res expr r
755 from_prod res_ty res expr (Prod { repr_ptup_tc = ptup_tc
756 , repr_comp_tys = tys
757 , repr_comps = comps })
759 ptys <- mapM mkPDataType tys
760 vars <- newLocalVars (fsLit "ys") ptys
761 (res', args) <- fold from_comp res_ty res (map Var vars) comps
762 let scrut = unwrapFamInstScrut ptup_tc tys expr
763 body = mkWildCase scrut (exprType scrut) res_ty
764 [(DataAlt ptup_con, vars, res')]
767 [ptup_con] = tyConDataCons ptup_tc
769 from_comp _ res expr (Keep _ _) = return (res, [expr])
770 from_comp _ res expr (Wrap ty)
772 wrap_tc <- builtin wrapTyCon
773 (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
774 return (res, [unwrapNewTypeBody pwrap_tc [ty]
775 $ unwrapFamInstScrut pwrap_tc [ty] expr])
777 fold f res_ty res exprs rs = foldrM f' (res, []) (zip exprs rs)
779 f' (expr, r) (res, args) = do
780 (res', args') <- f res_ty res expr r
781 return (res', args' ++ args)
783 buildPRDict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
784 buildPRDict vect_tc prepr_tc _ r
787 pr_co <- mkBuiltinCo prTyCon
788 let co = mkAppCoercion pr_co
790 $ mkTyConApp arg_co ty_args
791 return (mkCoerce co dict)
793 ty_args = mkTyVarTys (tyConTyVars vect_tc)
794 Just arg_co = tyConFamilyCoercion_maybe prepr_tc
796 sum_dict EmptySum = prDFunOfTyCon =<< builtin voidTyCon
797 sum_dict (UnarySum r) = con_dict r
798 sum_dict (Sum { repr_sum_tc = sum_tc
803 dicts <- mapM con_dict cons
804 dfun <- prDFunOfTyCon sum_tc
805 return $ dfun `mkTyApps` tys `mkApps` dicts
807 con_dict (ConRepr _ r) = prod_dict r
809 prod_dict EmptyProd = prDFunOfTyCon =<< builtin voidTyCon
810 prod_dict (UnaryProd r) = comp_dict r
811 prod_dict (Prod { repr_tup_tc = tup_tc
812 , repr_comp_tys = tys
813 , repr_comps = comps })
815 dicts <- mapM comp_dict comps
816 dfun <- prDFunOfTyCon tup_tc
817 return $ dfun `mkTyApps` tys `mkApps` dicts
819 comp_dict (Keep _ pr) = return pr
820 comp_dict (Wrap ty) = wrapPR ty
823 buildPDataTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
824 buildPDataTyCon orig_tc vect_tc repr = fixV $ \repr_tc ->
826 name' <- cloneName mkPDataTyConOcc orig_name
827 rhs <- buildPDataTyConRhs orig_name vect_tc repr_tc repr
828 pdata <- builtin pdataTyCon
830 liftDs $ buildAlgTyCon name'
832 [] -- no stupid theta
834 rec_flag -- FIXME: is this ok?
835 False -- FIXME: no generics
836 False -- not GADT syntax
837 (Just $ mk_fam_inst pdata vect_tc)
839 orig_name = tyConName orig_tc
840 tyvars = tyConTyVars vect_tc
841 rec_flag = boolToRecFlag (isRecursiveTyCon vect_tc)
844 buildPDataTyConRhs :: Name -> TyCon -> TyCon -> SumRepr -> VM AlgTyConRhs
845 buildPDataTyConRhs orig_name vect_tc repr_tc repr
847 data_con <- buildPDataDataCon orig_name vect_tc repr_tc repr
848 return $ DataTyCon { data_cons = [data_con], is_enum = False }
850 buildPDataDataCon :: Name -> TyCon -> TyCon -> SumRepr -> VM DataCon
851 buildPDataDataCon orig_name vect_tc repr_tc repr
853 dc_name <- cloneName mkPDataDataConOcc orig_name
854 comp_tys <- sum_tys repr
856 liftDs $ buildDataCon dc_name
858 (map (const HsNoBang) comp_tys)
859 [] -- no field labels
861 [] -- no existentials
865 (mkFamilyTyConApp repr_tc (mkTyVarTys tvs))
868 tvs = tyConTyVars vect_tc
870 sum_tys EmptySum = return []
871 sum_tys (UnarySum r) = con_tys r
872 sum_tys (Sum { repr_sel_ty = sel_ty
873 , repr_cons = cons })
874 = liftM (sel_ty :) (concatMapM con_tys cons)
876 con_tys (ConRepr _ r) = prod_tys r
878 prod_tys EmptyProd = return []
879 prod_tys (UnaryProd r) = liftM singleton (comp_ty r)
880 prod_tys (Prod { repr_comps = comps }) = mapM comp_ty comps
882 comp_ty r = mkPDataType (compOrigType r)
885 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr
887 buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr
889 vectDataConWorkers orig_tc vect_tc pdata_tc
890 buildPADict vect_tc prepr_tc pdata_tc repr
892 vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
893 vectDataConWorkers orig_tc vect_tc arr_tc
896 . zipWith3 def_worker (tyConDataCons orig_tc) rep_tys
897 $ zipWith4 mk_data_con (tyConDataCons vect_tc)
900 (tail $ tails rep_tys)
901 mapM_ (uncurry hoistBinding) bs
903 tyvars = tyConTyVars vect_tc
904 var_tys = mkTyVarTys tyvars
905 ty_args = map Type var_tys
906 res_ty = mkTyConApp vect_tc var_tys
908 cons = tyConDataCons vect_tc
910 [arr_dc] = tyConDataCons arr_tc
912 rep_tys = map dataConRepArgTys $ tyConDataCons vect_tc
915 mk_data_con con tys pre post
916 = liftM2 (,) (vect_data_con con)
917 (lift_data_con tys pre post (mkDataConTag con))
919 sel_replicate len tag
921 rep <- builtin (selReplicate arity)
922 return [rep `mkApps` [len, tag]]
924 | otherwise = return []
926 vect_data_con con = return $ mkConApp con ty_args
927 lift_data_con tys pre_tys post_tys tag
929 len <- builtin liftingContext
930 args <- mapM (newLocalVar (fsLit "xs"))
931 =<< mapM mkPDataType tys
933 sel <- sel_replicate (Var len) tag
935 pre <- mapM emptyPD (concat pre_tys)
936 post <- mapM emptyPD (concat post_tys)
938 return . mkLams (len : args)
939 . wrapFamInstBody arr_tc var_tys
941 $ ty_args ++ sel ++ pre ++ map Var args ++ post
943 def_worker data_con arg_tys mk_body
945 arity <- polyArity tyvars
948 . polyAbstract tyvars $ \args ->
949 liftM (mkLams (tyvars ++ args) . vectorised)
950 $ buildClosures tyvars [] arg_tys res_ty mk_body
952 raw_worker <- cloneId mkVectOcc orig_worker (exprType body)
953 let vect_worker = raw_worker `setIdUnfolding`
954 mkInlineRule body (Just arity)
955 defGlobalVar orig_worker vect_worker
956 return (vect_worker, body)
958 orig_worker = dataConWorkId data_con
960 buildPADict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
961 buildPADict vect_tc prepr_tc arr_tc repr
962 = polyAbstract tvs $ \args ->
964 method_ids <- mapM (method args) paMethods
966 pa_tc <- builtin paTyCon
967 pa_dc <- builtin paDataCon
968 let dict = mkLams (tvs ++ args)
970 $ Type inst_ty : map (method_call args) method_ids
972 dfun_ty = mkForAllTys tvs
973 $ mkFunTys (map varType args) (mkTyConApp pa_tc [inst_ty])
975 raw_dfun <- newExportedVar dfun_name dfun_ty
976 let dfun = raw_dfun `setIdUnfolding` mkDFunUnfolding dfun_ty (map Var method_ids)
977 `setInlinePragma` dfunInlinePragma
979 hoistBinding dfun dict
982 tvs = tyConTyVars vect_tc
983 arg_tys = mkTyVarTys tvs
984 inst_ty = mkTyConApp vect_tc arg_tys
986 dfun_name = mkPADFunOcc (getOccName vect_tc)
988 method args (name, build)
991 expr <- build vect_tc prepr_tc arr_tc repr
992 let body = mkLams (tvs ++ args) expr
993 raw_var <- newExportedVar (method_name name) (exprType body)
995 `setIdUnfolding` mkInlineRule body (Just (length args))
996 `setInlinePragma` alwaysInlinePragma
997 hoistBinding var body
1000 method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args)
1002 method_name name = mkVarOcc $ occNameString dfun_name ++ ('$' : name)
-- | Methods of a generated PA dictionary, each paired with the builder that
-- produces its body.
--
-- 'buildPADict' uses the string to name the method binding (appended to the
-- dfun name after a '$'). It calls the builder with the vectorised tycon,
-- the PRepr tycon, the PData tycon and the sum representation.
--
-- NOTE(review): the list order presumably has to match the field order of
-- the PA data constructor; confirm before reordering.
paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
paMethods =
  [ ("dictPRepr",    buildPRDict)
  , ("toPRepr",      buildToPRepr)
  , ("fromPRepr",    buildFromPRepr)
  , ("toArrPRepr",   buildToArrPRepr)
  , ("fromArrPRepr", buildFromArrPRepr)
  ]
1013 -- | Split the given tycons into two sets depending on whether they have to be
1014 -- converted (first list) or not (second list). The first argument contains
1015 -- information about the conversion status of external tycons:
1017 -- * tycons which have converted versions are mapped to True
1018 -- * tycons which are not changed by vectorisation are mapped to False
1019 -- * tycons which can't be converted are not elements of the map
1021 classifyTyCons :: UniqFM Bool -> [TyConGroup] -> ([TyCon], [TyCon])
1022 classifyTyCons = classify [] []
1024 classify conv keep _ [] = (conv, keep)
1025 classify conv keep cs ((tcs, ds) : rs)
1026 | can_convert && must_convert
1027 = classify (tcs ++ conv) keep (cs `addListToUFM` [(tc,True) | tc <- tcs]) rs
1029 = classify conv (tcs ++ keep) (cs `addListToUFM` [(tc,False) | tc <- tcs]) rs
1031 = classify conv keep cs rs
1033 refs = ds `delListFromUniqSet` tcs
1035 can_convert = isNullUFM (refs `minusUFM` cs) && all convertable tcs
1036 must_convert = foldUFM (||) False (intersectUFM_C const cs refs)
1038 convertable tc = isDataTyCon tc && all isVanillaDataCon (tyConDataCons tc)
1040 -- | Compute mutually recursive groups of tycons in topological order
1042 tyConGroups :: [TyCon] -> [TyConGroup]
1043 tyConGroups tcs = map mk_grp (stronglyConnCompFromEdgedVertices edges)
1045 edges = [((tc, ds), tc, uniqSetToList ds) | tc <- tcs
1046 , let ds = tyConsOfTyCon tc]
1048 mk_grp (AcyclicSCC (tc, ds)) = ([tc], ds)
1049 mk_grp (CyclicSCC els) = (tcs, unionManyUniqSets dss)
1051 (tcs, dss) = unzip els
1053 tyConsOfTyCon :: TyCon -> UniqSet TyCon
1055 = tyConsOfTypes . concatMap dataConRepArgTys . tyConDataCons
1057 tyConsOfType :: Type -> UniqSet TyCon
1059 | Just ty' <- coreView ty = tyConsOfType ty'
1060 tyConsOfType (TyVarTy _) = emptyUniqSet
1061 tyConsOfType (TyConApp tc tys) = extend (tyConsOfTypes tys)
1063 extend | isUnLiftedTyCon tc
1064 || isTupleTyCon tc = id
1066 | otherwise = (`addOneToUniqSet` tc)
1068 tyConsOfType (AppTy a b) = tyConsOfType a `unionUniqSets` tyConsOfType b
1069 tyConsOfType (FunTy a b) = (tyConsOfType a `unionUniqSets` tyConsOfType b)
1070 `addOneToUniqSet` funTyCon
1071 tyConsOfType (ForAllTy _ ty) = tyConsOfType ty
1072 tyConsOfType other = pprPanic "ClosureConv.tyConsOfType" $ ppr other
-- | All tycons mentioned by any of the given types: the union of
-- 'tyConsOfType' over the list.
tyConsOfTypes :: [Type] -> UniqSet TyCon
tyConsOfTypes tys = unionManyUniqSets [tyConsOfType ty | ty <- tys]
1078 -- ----------------------------------------------------------------------------
1081 -- | Build an expression that calls the vectorised version of some
1082 -- function from a `Closure`.
1086 -- \(x :: Double) ->
1087 -- \(y :: Double) ->
1088 -- ($v_foo $: x) $: y
1091 -- We use the type of the original binding to work out how many
1092 -- outer lambdas to add.
1095 :: Type -- ^ The type of the original binding.
1096 -> CoreExpr -- ^ Expression giving the closure to use, eg @$v_foo@.
1099 -- Convert the type to the core view if it isn't already.
1101 | Just ty' <- coreView ty
1104 -- For each function constructor in the original type we add an outer
1105 -- lambda to bind the parameter variable, and an inner application of it.
1106 fromVect (FunTy arg_ty res_ty) expr
1108 arg <- newLocalVar (fsLit "x") arg_ty
1109 varg <- toVect arg_ty (Var arg)
1110 varg_ty <- vectType arg_ty
1111 vres_ty <- vectType res_ty
1112 apply <- builtin applyVar
1113 body <- fromVect res_ty
1114 $ Var apply `mkTyApps` [varg_ty, vres_ty] `mkApps` [expr, varg]
1115 return $ Lam arg body
1117 -- If the type isn't a function then it's time to call on the closure.
1119 = identityConv ty >> return expr
-- | Convert an expression of the given type into its vectorised form.
-- Only types accepted by 'identityConv' are supported, and for those the
-- expression is returned unchanged.
toVect :: Type -> CoreExpr -> VM CoreExpr
toVect ty expr = do
  identityConv ty
  return expr
1126 identityConv :: Type -> VM ()
1127 identityConv ty | Just ty' <- coreView ty = identityConv ty'
1128 identityConv (TyConApp tycon tys)
1130 mapM_ identityConv tys
1131 identityConvTyCon tycon
1132 identityConv _ = noV
1134 identityConvTyCon :: TyCon -> VM ()
1135 identityConvTyCon tc
1136 | isBoxedTupleTyCon tc = return ()
1137 | isUnLiftedTyCon tc = return ()
1139 tc' <- maybeV (lookupTyCon tc)
1140 if tc == tc' then return () else noV