1 {-# OPTIONS -fno-warn-missing-signatures #-}
3 module VectType ( vectTyCon, vectAndLiftType, vectType, vectTypeEnv,
4 -- arrSumArity, pdataCompTys, pdataCompVars,
14 import HscTypes ( TypeEnv, extendTypeEnvList, typeEnvTyCons )
19 import MkCore ( mkWildCase )
27 import FamInstEnv ( FamInst, mkLocalFamInst )
31 import Var ( Var, TyVar, varType, varName )
32 import Name ( Name, getOccName )
39 import Digraph ( SCC(..), stronglyConnCompFromEdgedVertices )
44 import MonadUtils ( zipWith3M, foldrM, concatMapM )
45 import Control.Monad ( liftM, liftM2, zipWithM, zipWithM_, mapAndUnzipM )
-- | Emit a debugging trace tagged "VectType" when 'debug' is enabled;
-- otherwise return the second argument unchanged.
dtrace s x
  | debug     = pprTrace "VectType" s x
  | otherwise = x
52 -- ----------------------------------------------------------------------------
55 -- | Vectorise a type constructor.
56 vectTyCon :: TyCon -> VM TyCon
58 | isFunTyCon tc = builtin closureTyCon
59 | isBoxedTupleTyCon tc = return tc
60 | isUnLiftedTyCon tc = return tc
62 = maybeCantVectoriseM "Tycon not vectorised: " (ppr tc)
66 vectAndLiftType :: Type -> VM (Type, Type)
67 vectAndLiftType ty | Just ty' <- coreView ty = vectAndLiftType ty'
70 mdicts <- mapM paDictArgType tyvars
71 let dicts = [dict | Just dict <- mdicts]
72 vmono_ty <- vectType mono_ty
73 lmono_ty <- mkPDataType vmono_ty
74 return (abstractType tyvars dicts vmono_ty,
75 abstractType tyvars dicts lmono_ty)
77 (tyvars, mono_ty) = splitForAllTys ty
80 -- | Vectorise a type.
81 vectType :: Type -> VM Type
83 | Just ty' <- coreView ty
86 vectType (TyVarTy tv) = return $ TyVarTy tv
87 vectType (AppTy ty1 ty2) = liftM2 AppTy (vectType ty1) (vectType ty2)
88 vectType (TyConApp tc tys) = liftM2 TyConApp (vectTyCon tc) (mapM vectType tys)
89 vectType (FunTy ty1 ty2) = liftM2 TyConApp (builtin closureTyCon)
90 (mapM vectAndBoxType [ty1,ty2])
92 -- For each quantified var we need to add a PA dictionary at the front of the type.
93 -- So forall a. C a => a -> a
94 -- turns into forall a. Cv a => PA a => a :-> a
95 vectType ty@(ForAllTy _ _)
97 -- split the type into the quantified vars, its dictionaries and the body.
98 let (tyvars, tyBody) = splitForAllTys ty
99 let (tyArgs, tyResult) = splitFunTys tyBody
101 let (tyArgs_dict, tyArgs_regular)
102 = partition isDictType tyArgs
104 -- vectorise the body.
105 let tyBody' = mkFunTys tyArgs_regular tyResult
106 tyBody'' <- vectType tyBody'
108 -- vectorise the dictionary parameters.
109 dictsVect <- mapM vectType tyArgs_dict
111 -- make a PA dictionary for each of the type variables.
112 dictsPA <- liftM catMaybes $ mapM paDictArgType tyvars
114 -- pack it all back together.
115 return $ abstractType tyvars (dictsVect ++ dictsPA) tyBody''
117 vectType ty = cantVectorise "Can't vectorise type" (ppr ty)
-- | Add quantified vars and dictionary parameters to the front of a type.
--
-- @abstractType tvs dicts body@ builds
-- @forall tvs. dict_1 -> ... -> dict_n -> body@.
abstractType :: [TyVar] -> [Type] -> Type -> Type
abstractType tyvars dicts body
  = mkForAllTys tyvars (mkFunTys dicts body)
125 -- | Check if some type is a type class dictionary.
126 isDictType :: Type -> Bool
128 = case splitTyConApp_maybe ty of
129 Just (tyCon, _) -> isClassTyCon tyCon
133 -- ----------------------------------------------------------------------------
136 boxType :: Type -> VM Type
138 | Just (tycon, []) <- splitTyConApp_maybe ty
139 , isUnLiftedTyCon tycon
141 r <- lookupBoxedTyCon tycon
143 Just tycon' -> return $ mkTyConApp tycon' []
146 boxType ty = return ty
-- | Vectorise a type with 'vectType' and then pass the result through
-- 'boxType', which leaves every type other than an application of a
-- nullary unlifted tycon unchanged.
vectAndBoxType :: Type -> VM Type
vectAndBoxType ty = do
  vty <- vectType ty
  boxType vty
152 -- ----------------------------------------------------------------------------
-- | One strongly connected component of type constructors, as built by
-- 'tyConGroups'. It is either a single tycon or a mutually recursive group.
-- The component is paired with the set of tycons its members refer to,
-- as computed by 'tyConsOfTyCon' from the data constructors' representation
-- argument types.
155 type TyConGroup = ([TyCon], UniqSet TyCon)
157 -- | Vectorise a type environment.
158 -- The type environment contains all the type things defined in a module.
159 vectTypeEnv :: TypeEnv -> VM (TypeEnv, [FamInst], [(Var, CoreExpr)])
163 cs <- readGEnv $ mk_map . global_tycons
165 -- Split the list of TyCons into the ones we have to vectorise vs the
166 -- ones we can pass through unchanged. We also pass through algebraic
167 -- types that use non Haskell98 features, as we don't handle those.
168 let (conv_tcs, keep_tcs) = classifyTyCons cs groups
169 keep_dcs = concatMap tyConDataCons keep_tcs
171 dtrace (text "conv_tcs = " <> ppr conv_tcs) $ return ()
173 zipWithM_ defTyCon keep_tcs keep_tcs
174 zipWithM_ defDataCon keep_dcs keep_dcs
176 new_tcs <- vectTyConDecls conv_tcs
178 dtrace (text "new_tcs = " <> ppr new_tcs) $ return ()
180 let orig_tcs = keep_tcs ++ conv_tcs
182 -- We don't need to make new representation types for dictionary
183 -- constructors. The constructors are always fully applied, and we don't
184 -- need to lift them to arrays as a dictionary of a particular type
185 -- always has the same value.
186 let vect_tcs = filter (not . isClassTyCon)
187 $ keep_tcs ++ new_tcs
189 dtrace (text "vect_tcs = " <> ppr vect_tcs) $ return ()
191 mapM_ dumpTycon $ new_tcs
194 (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
196 defTyConPAs (zipLazy vect_tcs dfuns')
197 reprs <- mapM tyConRepr vect_tcs
198 repr_tcs <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
199 pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
202 $ zipWith5 buildTyConBindings
210 return (dfuns, binds, repr_tcs ++ pdata_tcs)
212 let all_new_tcs = new_tcs ++ inst_tcs
214 let new_env = extendTypeEnvList env
215 (map ATyCon all_new_tcs
216 ++ [ADataCon dc | tc <- all_new_tcs
217 , dc <- tyConDataCons tc])
219 return (new_env, map mkLocalFamInst inst_tcs, binds)
221 tycons = typeEnvTyCons env
222 groups = tyConGroups tycons
224 mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
227 -- | Vectorise some (possibly recursively defined) type constructors.
228 vectTyConDecls :: [TyCon] -> VM [TyCon]
229 vectTyConDecls tcs = fixV $ \tcs' ->
231 mapM_ (uncurry defTyCon) (zipLazy tcs tcs')
232 mapM vectTyConDecl tcs
234 dumpTycon :: TyCon -> VM ()
236 | Just cls <- tyConClass_maybe tycon
237 = dtrace (vcat [ ppr tycon
238 , ppr [(m, varType m) | m <- classMethods cls ]])
245 -- | Vectorise a single type constructor.
246 vectTyConDecl :: TyCon -> VM TyCon
248 -- a type class constructor.
249 -- TODO: check for no stupid theta, fds, assoc types.
251 , Just cls <- tyConClass_maybe tycon
253 = do -- make the name of the vectorised class tycon.
254 name' <- cloneName mkVectTyConOcc (tyConName tycon)
256 -- vectorise the right-hand side of the definition.
257 rhs' <- vectAlgTyConRhs tycon (algTyConRhs tycon)
259 -- vectorise method selectors.
260 -- This also adds a mapping between the original and vectorised method selector
262 methods' <- mapM vectMethod
263 $ [(id, defMethSpecOfDefMeth meth)
264 | (id, meth) <- classOpItems cls]
266 -- keep the original recursiveness flag.
267 let rec_flag = boolToRecFlag (isRecursiveTyCon tycon)
269 -- Calling buildclass here attaches new quantifiers and dictionaries to the method types.
272 False -- include unfoldings on dictionary selectors.
273 name' -- new name V_T:Class
274 (tyConTyVars tycon) -- keep original type vars
275 [] -- no stupid theta
276 [] -- no functional dependencies
277 [] -- no associated types
278 methods' -- method info
279 rec_flag -- whether recursive
281 let tycon' = mkClassTyCon name'
290 -- a regular algebraic type constructor.
291 -- TODO: check for stupid theta, generics, GADTs, etc.
293 = do name' <- cloneName mkVectTyConOcc (tyConName tycon)
294 rhs' <- vectAlgTyConRhs tycon (algTyConRhs tycon)
295 let rec_flag = boolToRecFlag (isRecursiveTyCon tycon)
297 liftDs $ buildAlgTyCon
299 (tyConTyVars tycon) -- keep original type vars.
300 [] -- no stupid theta.
301 rhs' -- new constructor defs.
302 rec_flag -- FIXME: is this ok?
303 False -- FIXME: no generics
304 False -- not GADT syntax
305 Nothing -- not a family instance
307 -- some other crazy thing that we don't handle.
309 = cantVectorise "Can't vectorise type constructor: " (ppr tycon)
312 -- | Vectorise a class method.
313 vectMethod :: (Id, DefMethSpec) -> VM (Name, DefMethSpec, Type)
314 vectMethod (id, defMeth)
316 -- Vectorise the method type.
317 typ' <- vectType (varType id)
319 -- Create a name for the vectorised method.
320 id' <- cloneId mkVectOcc id typ'
323 -- When we call buildClass in vectTyConDecl, it adds foralls and dictionaries
324 -- to the types of each method. However, the types we get back from vectType
325 -- above already already have these, so we need to chop them off here otherwise
326 -- we'll get two copies in the final version.
327 let (_tyvars, tyBody) = splitForAllTys typ'
328 let (_dict, tyRest) = splitFunTy tyBody
330 return (Var.varName id', defMeth, tyRest)
333 -- | Vectorise the RHS of an algebraic type.
334 vectAlgTyConRhs :: TyCon -> AlgTyConRhs -> VM AlgTyConRhs
335 vectAlgTyConRhs _ (DataTyCon { data_cons = data_cons
339 data_cons' <- mapM vectDataCon data_cons
340 zipWithM_ defDataCon data_cons data_cons'
341 return $ DataTyCon { data_cons = data_cons'
346 = cantVectorise "Can't vectorise type definition:" (ppr tc)
349 -- | Vectorise a data constructor.
350 -- Vectorises its argument and return types.
351 vectDataCon :: DataCon -> VM DataCon
353 | not . null $ dataConExTyVars dc
354 = cantVectorise "Can't vectorise constructor (existentials):" (ppr dc)
356 | not . null $ dataConEqSpec dc
357 = cantVectorise "Can't vectorise constructor (eq spec):" (ppr dc)
361 name' <- cloneName mkVectDataConOcc name
362 tycon' <- vectTyCon tycon
363 arg_tys <- mapM vectType rep_arg_tys
365 liftDs $ buildDataCon
368 (map (const HsNoBang) arg_tys) -- strictness annots on args.
369 [] -- no labelled fields
370 univ_tvs -- universally quantified vars
371 [] -- no existential tvs for now
372 [] -- no eq spec for now
374 arg_tys -- argument types
375 (mkFamilyTyConApp tycon' (mkTyVarTys univ_tvs)) -- return type
376 tycon' -- representation tycon
378 name = dataConName dc
379 univ_tvs = dataConUnivTyVars dc
380 rep_arg_tys = dataConRepArgTys dc
381 tycon = dataConTyCon dc
-- | Build the family-instance specification @(fam_tc, [arg_tc a1 .. an])@,
-- where @a1 .. an@ are the type variables of @arg_tc@.
--
-- Used when building the PRepr synonym instance and the PData data
-- instance for a vectorised tycon.
mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
mk_fam_inst fam_tc arg_tc = (fam_tc, [inst_ty])
  where
    inst_ty = mkTyConApp arg_tc (mkTyVarTys (tyConTyVars arg_tc))
388 buildPReprTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
389 buildPReprTyCon orig_tc vect_tc repr
391 name <- cloneName mkPReprTyConOcc (tyConName orig_tc)
392 -- rhs_ty <- buildPReprType vect_tc
393 rhs_ty <- sumReprType repr
394 prepr_tc <- builtin preprTyCon
395 liftDs $ buildSynTyCon name
397 (SynonymTyCon rhs_ty)
399 (Just $ mk_fam_inst prepr_tc vect_tc)
401 tyvars = tyConTyVars vect_tc
403 data CompRepr = Keep Type
404 CoreExpr -- PR dictionary for the type
407 data ProdRepr = EmptyProd
409 | Prod { repr_tup_tc :: TyCon -- representation tuple tycon
410 , repr_ptup_tc :: TyCon -- PData representation tycon
411 , repr_comp_tys :: [Type] -- representation types of
412 , repr_comps :: [CompRepr] -- components
414 data ConRepr = ConRepr DataCon ProdRepr
416 data SumRepr = EmptySum
418 | Sum { repr_sum_tc :: TyCon -- representation sum tycon
419 , repr_psum_tc :: TyCon -- PData representation tycon
420 , repr_sel_ty :: Type -- type of selector
421 , repr_con_tys :: [Type] -- representation types of
422 , repr_cons :: [ConRepr] -- components
425 tyConRepr :: TyCon -> VM SumRepr
426 tyConRepr tc = sum_repr (tyConDataCons tc)
428 sum_repr [] = return EmptySum
429 sum_repr [con] = liftM UnarySum (con_repr con)
431 rs <- mapM con_repr cons
432 sum_tc <- builtin (sumTyCon arity)
433 tys <- mapM conReprType rs
434 (psum_tc, _) <- pdataReprTyCon (mkTyConApp sum_tc tys)
435 sel_ty <- builtin (selTy arity)
436 return $ Sum { repr_sum_tc = sum_tc
437 , repr_psum_tc = psum_tc
438 , repr_sel_ty = sel_ty
445 con_repr con = liftM (ConRepr con) (prod_repr (dataConRepArgTys con))
447 prod_repr [] = return EmptyProd
448 prod_repr [ty] = liftM UnaryProd (comp_repr ty)
450 rs <- mapM comp_repr tys
451 tup_tc <- builtin (prodTyCon arity)
452 tys' <- mapM compReprType rs
453 (ptup_tc, _) <- pdataReprTyCon (mkTyConApp tup_tc tys')
454 return $ Prod { repr_tup_tc = tup_tc
455 , repr_ptup_tc = ptup_tc
456 , repr_comp_tys = tys'
462 comp_repr ty = liftM (Keep ty) (prDictOfType ty)
463 `orElseV` return (Wrap ty)
-- | The representation type of a sum of constructors:
--
--   * no constructors: the builtin void type;
--   * one constructor: that constructor's own representation type;
--   * several constructors: the sum tycon applied to the per-constructor
--     representation types.
sumReprType :: SumRepr -> VM Type
sumReprType repr
  = case repr of
      EmptySum     -> voidType
      UnarySum con -> conReprType con
      Sum { repr_sum_tc = sum_tc, repr_con_tys = con_tys } ->
        return (mkTyConApp sum_tc con_tys)
-- | The representation type of a single constructor is the representation
-- type of its product of fields; the 'DataCon' itself is not consulted.
conReprType :: ConRepr -> VM Type
conReprType repr = case repr of
  ConRepr _ prod -> prodReprType prod
-- | The representation type of a constructor's fields:
--
--   * no fields: the builtin void type;
--   * one field: that component's representation type;
--   * several fields: the product tycon applied to the component types.
prodReprType :: ProdRepr -> VM Type
prodReprType repr
  = case repr of
      EmptyProd      -> voidType
      UnaryProd comp -> compReprType comp
      Prod { repr_tup_tc = tup_tc, repr_comp_tys = comp_tys } ->
        return (mkTyConApp tup_tc comp_tys)
-- | The representation type of a single component.
--
-- 'tyConRepr' marks a component 'Keep' when a PR dictionary for its type is
-- available and 'Wrap' otherwise. Kept components use their own type;
-- wrapped ones are put under the builtin 'wrapTyCon'.
compReprType :: CompRepr -> VM Type
compReprType (Keep ty _) = return ty
compReprType (Wrap ty)
  = liftM (\wrap_tc -> mkTyConApp wrap_tc [ty]) (builtin wrapTyCon)
-- | The original (unwrapped) type of a component, whether it is kept
-- as-is or wrapped.
compOrigType :: CompRepr -> Type
compOrigType repr
  = case repr of
      Keep ty _ -> ty
      Wrap ty   -> ty
490 buildToPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
491 buildToPRepr vect_tc repr_tc _ repr
493 let arg_ty = mkTyConApp vect_tc ty_args
494 res_ty <- mkPReprType arg_ty
495 arg <- newLocalVar (fsLit "x") arg_ty
496 result <- to_sum (Var arg) arg_ty res_ty repr
497 return $ Lam arg result
499 ty_args = mkTyVarTys (tyConTyVars vect_tc)
501 wrap_repr_inst = wrapFamInstBody repr_tc ty_args
503 to_sum _ _ _ EmptySum
505 void <- builtin voidVar
506 return $ wrap_repr_inst $ Var void
508 to_sum arg arg_ty res_ty (UnarySum r)
510 (pat, vars, body) <- con_alt r
511 return $ mkWildCase arg arg_ty res_ty
512 [(pat, vars, wrap_repr_inst body)]
514 to_sum arg arg_ty res_ty (Sum { repr_sum_tc = sum_tc
516 , repr_cons = cons })
518 alts <- mapM con_alt cons
519 let alts' = [(pat, vars, wrap_repr_inst
520 $ mkConApp sum_con (map Type tys ++ [body]))
521 | ((pat, vars, body), sum_con)
522 <- zip alts (tyConDataCons sum_tc)]
523 return $ mkWildCase arg arg_ty res_ty alts'
525 con_alt (ConRepr con r)
527 (vars, body) <- to_prod r
528 return (DataAlt con, vars, body)
532 void <- builtin voidVar
533 return ([], Var void)
535 to_prod (UnaryProd comp)
537 var <- newLocalVar (fsLit "x") (compOrigType comp)
538 body <- to_comp (Var var) comp
541 to_prod(Prod { repr_tup_tc = tup_tc
542 , repr_comp_tys = tys
543 , repr_comps = comps })
545 vars <- newLocalVars (fsLit "x") (map compOrigType comps)
546 exprs <- zipWithM to_comp (map Var vars) comps
547 return (vars, mkConApp tup_con (map Type tys ++ exprs))
549 [tup_con] = tyConDataCons tup_tc
551 to_comp expr (Keep _ _) = return expr
552 to_comp expr (Wrap ty) = do
553 wrap_tc <- builtin wrapTyCon
554 return $ wrapNewTypeBody wrap_tc [ty] expr
557 buildFromPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
558 buildFromPRepr vect_tc repr_tc _ repr
560 arg_ty <- mkPReprType res_ty
561 arg <- newLocalVar (fsLit "x") arg_ty
563 result <- from_sum (unwrapFamInstScrut repr_tc ty_args (Var arg))
565 return $ Lam arg result
567 ty_args = mkTyVarTys (tyConTyVars vect_tc)
568 res_ty = mkTyConApp vect_tc ty_args
572 dummy <- builtin fromVoidVar
573 return $ Var dummy `App` Type res_ty
575 from_sum expr (UnarySum r) = from_con expr r
576 from_sum expr (Sum { repr_sum_tc = sum_tc
578 , repr_cons = cons })
580 vars <- newLocalVars (fsLit "x") tys
581 es <- zipWithM from_con (map Var vars) cons
582 return $ mkWildCase expr (exprType expr) res_ty
583 [(DataAlt con, [var], e)
584 | (con, var, e) <- zip3 (tyConDataCons sum_tc) vars es]
586 from_con expr (ConRepr con r)
587 = from_prod expr (mkConApp con $ map Type ty_args) r
589 from_prod _ con EmptyProd = return con
590 from_prod expr con (UnaryProd r)
592 e <- from_comp expr r
595 from_prod expr con (Prod { repr_tup_tc = tup_tc
596 , repr_comp_tys = tys
600 vars <- newLocalVars (fsLit "y") tys
601 es <- zipWithM from_comp (map Var vars) comps
602 return $ mkWildCase expr (exprType expr) res_ty
603 [(DataAlt tup_con, vars, con `mkApps` es)]
605 [tup_con] = tyConDataCons tup_tc
607 from_comp expr (Keep _ _) = return expr
608 from_comp expr (Wrap ty)
610 wrap <- builtin wrapTyCon
611 return $ unwrapNewTypeBody wrap [ty] expr
614 buildToArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
615 buildToArrPRepr vect_tc prepr_tc pdata_tc r
617 arg_ty <- mkPDataType el_ty
618 res_ty <- mkPDataType =<< mkPReprType el_ty
619 arg <- newLocalVar (fsLit "xs") arg_ty
621 pdata_co <- mkBuiltinCo pdataTyCon
622 let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
623 co = mkAppCoercion pdata_co
625 $ mkTyConApp repr_co ty_args
627 scrut = unwrapFamInstScrut pdata_tc ty_args (Var arg)
629 (vars, result) <- to_sum r
632 $ mkWildCase scrut (mkTyConApp pdata_tc ty_args) res_ty
633 [(DataAlt pdata_dc, vars, mkCoerce co result)]
635 ty_args = mkTyVarTys $ tyConTyVars vect_tc
636 el_ty = mkTyConApp vect_tc ty_args
638 [pdata_dc] = tyConDataCons pdata_tc
642 pvoid <- builtin pvoidVar
643 return ([], Var pvoid)
644 to_sum (UnarySum r) = to_con r
645 to_sum (Sum { repr_psum_tc = psum_tc
646 , repr_sel_ty = sel_ty
651 (vars, exprs) <- mapAndUnzipM to_con cons
652 sel <- newLocalVar (fsLit "sel") sel_ty
653 return (sel : concat vars, mk_result (Var sel) exprs)
655 [psum_con] = tyConDataCons psum_tc
656 mk_result sel exprs = wrapFamInstBody psum_tc tys
658 $ map Type tys ++ (sel : exprs)
660 to_con (ConRepr _ r) = to_prod r
662 to_prod EmptyProd = do
663 pvoid <- builtin pvoidVar
664 return ([], Var pvoid)
665 to_prod (UnaryProd r)
667 pty <- mkPDataType (compOrigType r)
668 var <- newLocalVar (fsLit "x") pty
669 expr <- to_comp (Var var) r
672 to_prod (Prod { repr_ptup_tc = ptup_tc
673 , repr_comp_tys = tys
674 , repr_comps = comps })
676 ptys <- mapM (mkPDataType . compOrigType) comps
677 vars <- newLocalVars (fsLit "x") ptys
678 es <- zipWithM to_comp (map Var vars) comps
679 return (vars, mk_result es)
681 [ptup_con] = tyConDataCons ptup_tc
682 mk_result exprs = wrapFamInstBody ptup_tc tys
684 $ map Type tys ++ exprs
686 to_comp expr (Keep _ _) = return expr
688 -- FIXME: this is bound to be wrong!
689 to_comp expr (Wrap ty)
691 wrap_tc <- builtin wrapTyCon
692 (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
693 return $ wrapNewTypeBody pwrap_tc [ty] expr
696 buildFromArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
697 buildFromArrPRepr vect_tc prepr_tc pdata_tc r
699 arg_ty <- mkPDataType =<< mkPReprType el_ty
700 res_ty <- mkPDataType el_ty
701 arg <- newLocalVar (fsLit "xs") arg_ty
703 pdata_co <- mkBuiltinCo pdataTyCon
704 let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
705 co = mkAppCoercion pdata_co
706 $ mkTyConApp repr_co var_tys
708 scrut = mkCoerce co (Var arg)
710 mk_result args = wrapFamInstBody pdata_tc var_tys
712 $ map Type var_tys ++ args
714 (expr, _) <- fixV $ \ ~(_, args) ->
715 from_sum res_ty (mk_result args) scrut r
717 return $ Lam arg expr
719 -- (args, mk) <- from_sum res_ty scrut r
721 -- let result = wrapFamInstBody pdata_tc var_tys
722 -- . mkConApp pdata_dc
723 -- $ map Type var_tys ++ args
725 -- return $ Lam arg (mk result)
727 var_tys = mkTyVarTys $ tyConTyVars vect_tc
728 el_ty = mkTyConApp vect_tc var_tys
730 [pdata_con] = tyConDataCons pdata_tc
732 from_sum _ res _ EmptySum = return (res, [])
733 from_sum res_ty res expr (UnarySum r) = from_con res_ty res expr r
734 from_sum res_ty res expr (Sum { repr_psum_tc = psum_tc
735 , repr_sel_ty = sel_ty
737 , repr_cons = cons })
739 sel <- newLocalVar (fsLit "sel") sel_ty
740 ptys <- mapM mkPDataType tys
741 vars <- newLocalVars (fsLit "xs") ptys
742 (res', args) <- fold from_con res_ty res (map Var vars) cons
743 let scrut = unwrapFamInstScrut psum_tc tys expr
744 body = mkWildCase scrut (exprType scrut) res_ty
745 [(DataAlt psum_con, sel : vars, res')]
746 return (body, Var sel : args)
748 [psum_con] = tyConDataCons psum_tc
751 from_con res_ty res expr (ConRepr _ r) = from_prod res_ty res expr r
753 from_prod _ res _ EmptyProd = return (res, [])
754 from_prod res_ty res expr (UnaryProd r)
755 = from_comp res_ty res expr r
756 from_prod res_ty res expr (Prod { repr_ptup_tc = ptup_tc
757 , repr_comp_tys = tys
758 , repr_comps = comps })
760 ptys <- mapM mkPDataType tys
761 vars <- newLocalVars (fsLit "ys") ptys
762 (res', args) <- fold from_comp res_ty res (map Var vars) comps
763 let scrut = unwrapFamInstScrut ptup_tc tys expr
764 body = mkWildCase scrut (exprType scrut) res_ty
765 [(DataAlt ptup_con, vars, res')]
768 [ptup_con] = tyConDataCons ptup_tc
770 from_comp _ res expr (Keep _ _) = return (res, [expr])
771 from_comp _ res expr (Wrap ty)
773 wrap_tc <- builtin wrapTyCon
774 (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
775 return (res, [unwrapNewTypeBody pwrap_tc [ty]
776 $ unwrapFamInstScrut pwrap_tc [ty] expr])
778 fold f res_ty res exprs rs = foldrM f' (res, []) (zip exprs rs)
780 f' (expr, r) (res, args) = do
781 (res', args') <- f res_ty res expr r
782 return (res', args' ++ args)
784 buildPRDict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
785 buildPRDict vect_tc prepr_tc _ r
788 pr_co <- mkBuiltinCo prTyCon
789 let co = mkAppCoercion pr_co
791 $ mkTyConApp arg_co ty_args
792 return (mkCoerce co dict)
794 ty_args = mkTyVarTys (tyConTyVars vect_tc)
795 Just arg_co = tyConFamilyCoercion_maybe prepr_tc
797 sum_dict EmptySum = prDFunOfTyCon =<< builtin voidTyCon
798 sum_dict (UnarySum r) = con_dict r
799 sum_dict (Sum { repr_sum_tc = sum_tc
804 dicts <- mapM con_dict cons
805 dfun <- prDFunOfTyCon sum_tc
806 return $ dfun `mkTyApps` tys `mkApps` dicts
808 con_dict (ConRepr _ r) = prod_dict r
810 prod_dict EmptyProd = prDFunOfTyCon =<< builtin voidTyCon
811 prod_dict (UnaryProd r) = comp_dict r
812 prod_dict (Prod { repr_tup_tc = tup_tc
813 , repr_comp_tys = tys
814 , repr_comps = comps })
816 dicts <- mapM comp_dict comps
817 dfun <- prDFunOfTyCon tup_tc
818 return $ dfun `mkTyApps` tys `mkApps` dicts
820 comp_dict (Keep _ pr) = return pr
821 comp_dict (Wrap ty) = wrapPR ty
824 buildPDataTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
825 buildPDataTyCon orig_tc vect_tc repr = fixV $ \repr_tc ->
827 name' <- cloneName mkPDataTyConOcc orig_name
828 rhs <- buildPDataTyConRhs orig_name vect_tc repr_tc repr
829 pdata <- builtin pdataTyCon
831 liftDs $ buildAlgTyCon name'
833 [] -- no stupid theta
835 rec_flag -- FIXME: is this ok?
836 False -- FIXME: no generics
837 False -- not GADT syntax
838 (Just $ mk_fam_inst pdata vect_tc)
840 orig_name = tyConName orig_tc
841 tyvars = tyConTyVars vect_tc
842 rec_flag = boolToRecFlag (isRecursiveTyCon vect_tc)
845 buildPDataTyConRhs :: Name -> TyCon -> TyCon -> SumRepr -> VM AlgTyConRhs
846 buildPDataTyConRhs orig_name vect_tc repr_tc repr
848 data_con <- buildPDataDataCon orig_name vect_tc repr_tc repr
849 return $ DataTyCon { data_cons = [data_con], is_enum = False }
851 buildPDataDataCon :: Name -> TyCon -> TyCon -> SumRepr -> VM DataCon
852 buildPDataDataCon orig_name vect_tc repr_tc repr
854 dc_name <- cloneName mkPDataDataConOcc orig_name
855 comp_tys <- sum_tys repr
857 liftDs $ buildDataCon dc_name
859 (map (const HsNoBang) comp_tys)
860 [] -- no field labels
862 [] -- no existentials
866 (mkFamilyTyConApp repr_tc (mkTyVarTys tvs))
869 tvs = tyConTyVars vect_tc
871 sum_tys EmptySum = return []
872 sum_tys (UnarySum r) = con_tys r
873 sum_tys (Sum { repr_sel_ty = sel_ty
874 , repr_cons = cons })
875 = liftM (sel_ty :) (concatMapM con_tys cons)
877 con_tys (ConRepr _ r) = prod_tys r
879 prod_tys EmptyProd = return []
880 prod_tys (UnaryProd r) = liftM singleton (comp_ty r)
881 prod_tys (Prod { repr_comps = comps }) = mapM comp_ty comps
883 comp_ty r = mkPDataType (compOrigType r)
886 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr
888 buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr
890 vectDataConWorkers orig_tc vect_tc pdata_tc
891 buildPADict vect_tc prepr_tc pdata_tc repr
893 vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
894 vectDataConWorkers orig_tc vect_tc arr_tc
897 . zipWith3 def_worker (tyConDataCons orig_tc) rep_tys
898 $ zipWith4 mk_data_con (tyConDataCons vect_tc)
901 (tail $ tails rep_tys)
902 mapM_ (uncurry hoistBinding) bs
904 tyvars = tyConTyVars vect_tc
905 var_tys = mkTyVarTys tyvars
906 ty_args = map Type var_tys
907 res_ty = mkTyConApp vect_tc var_tys
909 cons = tyConDataCons vect_tc
911 [arr_dc] = tyConDataCons arr_tc
913 rep_tys = map dataConRepArgTys $ tyConDataCons vect_tc
916 mk_data_con con tys pre post
917 = liftM2 (,) (vect_data_con con)
918 (lift_data_con tys pre post (mkDataConTag con))
920 sel_replicate len tag
922 rep <- builtin (selReplicate arity)
923 return [rep `mkApps` [len, tag]]
925 | otherwise = return []
927 vect_data_con con = return $ mkConApp con ty_args
928 lift_data_con tys pre_tys post_tys tag
930 len <- builtin liftingContext
931 args <- mapM (newLocalVar (fsLit "xs"))
932 =<< mapM mkPDataType tys
934 sel <- sel_replicate (Var len) tag
936 pre <- mapM emptyPD (concat pre_tys)
937 post <- mapM emptyPD (concat post_tys)
939 return . mkLams (len : args)
940 . wrapFamInstBody arr_tc var_tys
942 $ ty_args ++ sel ++ pre ++ map Var args ++ post
944 def_worker data_con arg_tys mk_body
946 arity <- polyArity tyvars
949 . polyAbstract tyvars $ \args ->
950 liftM (mkLams (tyvars ++ args) . vectorised)
951 $ buildClosures tyvars [] arg_tys res_ty mk_body
953 raw_worker <- cloneId mkVectOcc orig_worker (exprType body)
954 let vect_worker = raw_worker `setIdUnfolding`
955 mkInlineRule body (Just arity)
956 defGlobalVar orig_worker vect_worker
957 return (vect_worker, body)
959 orig_worker = dataConWorkId data_con
961 buildPADict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
962 buildPADict vect_tc prepr_tc arr_tc repr
963 = polyAbstract tvs $ \args ->
965 method_ids <- mapM (method args) paMethods
967 pa_tc <- builtin paTyCon
968 pa_dc <- builtin paDataCon
969 let dict = mkLams (tvs ++ args)
971 $ Type inst_ty : map (method_call args) method_ids
973 dfun_ty = mkForAllTys tvs
974 $ mkFunTys (map varType args) (mkTyConApp pa_tc [inst_ty])
976 raw_dfun <- newExportedVar dfun_name dfun_ty
977 let dfun = raw_dfun `setIdUnfolding` mkDFunUnfolding dfun_ty (map Var method_ids)
978 `setInlinePragma` dfunInlinePragma
980 hoistBinding dfun dict
983 tvs = tyConTyVars vect_tc
984 arg_tys = mkTyVarTys tvs
985 inst_ty = mkTyConApp vect_tc arg_tys
987 dfun_name = mkPADFunOcc (getOccName vect_tc)
989 method args (name, build)
992 expr <- build vect_tc prepr_tc arr_tc repr
993 let body = mkLams (tvs ++ args) expr
994 raw_var <- newExportedVar (method_name name) (exprType body)
996 `setIdUnfolding` mkInlineRule body (Just (length args))
997 `setInlinePragma` alwaysInlinePragma
998 hoistBinding var body
1001 method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args)
1003 method_name name = mkVarOcc $ occNameString dfun_name ++ ('$' : name)
-- | The methods of a generated PA dictionary.
--
-- Each entry pairs a name suffix with a builder. 'buildPADict' appends the
-- suffix (after a @$@) to the dictionary function's name to name the method
-- binding. The builder produces the method body from the vectorised tycon,
-- the PRepr tycon, the PData tycon and the sum representation.
--
-- NOTE(review): 'buildPADict' passes the method calls to the dictionary in
-- list order, so this order presumably has to match the fields of the PA
-- data constructor — confirm before reordering.
paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
paMethods
  = [ ("dictPRepr",    buildPRDict)
    , ("toPRepr",      buildToPRepr)
    , ("fromPRepr",    buildFromPRepr)
    , ("toArrPRepr",   buildToArrPRepr)
    , ("fromArrPRepr", buildFromArrPRepr)
    ]
1014 -- | Split the given tycons into two sets depending on whether they have to be
1015 -- converted (first list) or not (second list). The first argument contains
1016 -- information about the conversion status of external tycons:
1018 -- * tycons which have converted versions are mapped to True
1019 -- * tycons which are not changed by vectorisation are mapped to False
1020 -- * tycons which can't be converted are not elements of the map
1022 classifyTyCons :: UniqFM Bool -> [TyConGroup] -> ([TyCon], [TyCon])
1023 classifyTyCons = classify [] []
1025 classify conv keep _ [] = (conv, keep)
1026 classify conv keep cs ((tcs, ds) : rs)
1027 | can_convert && must_convert
1028 = classify (tcs ++ conv) keep (cs `addListToUFM` [(tc,True) | tc <- tcs]) rs
1030 = classify conv (tcs ++ keep) (cs `addListToUFM` [(tc,False) | tc <- tcs]) rs
1032 = classify conv keep cs rs
1034 refs = ds `delListFromUniqSet` tcs
1036 can_convert = isNullUFM (refs `minusUFM` cs) && all convertable tcs
1037 must_convert = foldUFM (||) False (intersectUFM_C const cs refs)
1039 convertable tc = isDataTyCon tc && all isVanillaDataCon (tyConDataCons tc)
1041 -- | Compute mutually recursive groups of tycons in topological order
1043 tyConGroups :: [TyCon] -> [TyConGroup]
1044 tyConGroups tcs = map mk_grp (stronglyConnCompFromEdgedVertices edges)
1046 edges = [((tc, ds), tc, uniqSetToList ds) | tc <- tcs
1047 , let ds = tyConsOfTyCon tc]
1049 mk_grp (AcyclicSCC (tc, ds)) = ([tc], ds)
1050 mk_grp (CyclicSCC els) = (tcs, unionManyUniqSets dss)
1052 (tcs, dss) = unzip els
1054 tyConsOfTyCon :: TyCon -> UniqSet TyCon
1056 = tyConsOfTypes . concatMap dataConRepArgTys . tyConDataCons
1058 tyConsOfType :: Type -> UniqSet TyCon
1060 | Just ty' <- coreView ty = tyConsOfType ty'
1061 tyConsOfType (TyVarTy _) = emptyUniqSet
1062 tyConsOfType (TyConApp tc tys) = extend (tyConsOfTypes tys)
1064 extend | isUnLiftedTyCon tc
1065 || isTupleTyCon tc = id
1067 | otherwise = (`addOneToUniqSet` tc)
1069 tyConsOfType (AppTy a b) = tyConsOfType a `unionUniqSets` tyConsOfType b
1070 tyConsOfType (FunTy a b) = (tyConsOfType a `unionUniqSets` tyConsOfType b)
1071 `addOneToUniqSet` funTyCon
1072 tyConsOfType (ForAllTy _ ty) = tyConsOfType ty
1073 tyConsOfType other = pprPanic "ClosureConv.tyConsOfType" $ ppr other
-- | The union of 'tyConsOfType' over a list of types.
tyConsOfTypes :: [Type] -> UniqSet TyCon
tyConsOfTypes tys = unionManyUniqSets [tyConsOfType ty | ty <- tys]
1079 -- ----------------------------------------------------------------------------
1082 -- | Build an expression that calls the vectorised version of some
1083 -- function from a `Closure`.
1087 -- \(x :: Double) ->
1088 -- \(y :: Double) ->
1089 -- ($v_foo $: x) $: y
1092 -- We use the type of the original binding to work out how many
1093 -- outer lambdas to add.
1096 :: Type -- ^ The type of the original binding.
1097 -> CoreExpr -- ^ Expression giving the closure to use, eg @$v_foo@.
1100 -- Convert the type to the core view if it isn't already.
1102 | Just ty' <- coreView ty
1105 -- For each function constructor in the original type we add an outer
1106 -- lambda to bind the parameter variable, and an inner application of it.
1107 fromVect (FunTy arg_ty res_ty) expr
1109 arg <- newLocalVar (fsLit "x") arg_ty
1110 varg <- toVect arg_ty (Var arg)
1111 varg_ty <- vectType arg_ty
1112 vres_ty <- vectType res_ty
1113 apply <- builtin applyVar
1114 body <- fromVect res_ty
1115 $ Var apply `mkTyApps` [varg_ty, vres_ty] `mkApps` [expr, varg]
1116 return $ Lam arg body
1118 -- If the type isn't a function then it's time to call on the closure.
1120 = identityConv ty >> return expr
-- | Convert an expression of the given original type to its vectorised
-- form.
--
-- Only types accepted by 'identityConv' are supported. For those types the
-- expression is returned unchanged; otherwise the computation fails
-- through 'identityConv'.
toVect :: Type -> CoreExpr -> VM CoreExpr
toVect ty expr = do
  identityConv ty
  return expr
1127 identityConv :: Type -> VM ()
1128 identityConv ty | Just ty' <- coreView ty = identityConv ty'
1129 identityConv (TyConApp tycon tys)
1131 mapM_ identityConv tys
1132 identityConvTyCon tycon
1133 identityConv _ = noV
1135 identityConvTyCon :: TyCon -> VM ()
1136 identityConvTyCon tc
1137 | isBoxedTupleTyCon tc = return ()
1138 | isUnLiftedTyCon tc = return ()
1140 tc' <- maybeV (lookupTyCon tc)
1141 if tc == tc' then return () else noV