1 {-# OPTIONS -fno-warn-missing-signatures #-}
3 module VectType ( vectTyCon, vectAndLiftType, vectType, vectTypeEnv,
4 -- arrSumArity, pdataCompTys, pdataCompVars,
12 import Vectorise.Monad
13 import Vectorise.Builtins
15 import HscTypes ( TypeEnv, extendTypeEnvList, typeEnvTyCons )
20 import MkCore ( mkWildCase )
28 import FamInstEnv ( FamInst, mkLocalFamInst )
32 import Var ( Var, TyVar, varType, varName )
33 import Name ( Name, getOccName )
40 import Digraph ( SCC(..), stronglyConnCompFromEdgedVertices )
45 import MonadUtils ( zipWith3M, foldrM, concatMapM )
46 import Control.Monad ( liftM, liftM2, zipWithM, zipWithM_, mapAndUnzipM )
-- | Emit a "VectType"-tagged debugging trace of @s@ when 'debug' is set;
--   in either case the second argument is returned.
dtrace s x
  | debug     = pprTrace "VectType" s x
  | otherwise = x
53 -- ----------------------------------------------------------------------------
56 -- | Vectorise a type constructor.
57 vectTyCon :: TyCon -> VM TyCon
59 | isFunTyCon tc = builtin closureTyCon
60 | isBoxedTupleTyCon tc = return tc
61 | isUnLiftedTyCon tc = return tc
63 = maybeCantVectoriseM "Tycon not vectorised: " (ppr tc)
67 vectAndLiftType :: Type -> VM (Type, Type)
68 vectAndLiftType ty | Just ty' <- coreView ty = vectAndLiftType ty'
71 mdicts <- mapM paDictArgType tyvars
72 let dicts = [dict | Just dict <- mdicts]
73 vmono_ty <- vectType mono_ty
74 lmono_ty <- mkPDataType vmono_ty
75 return (abstractType tyvars dicts vmono_ty,
76 abstractType tyvars dicts lmono_ty)
78 (tyvars, mono_ty) = splitForAllTys ty
81 -- | Vectorise a type.
82 vectType :: Type -> VM Type
84 | Just ty' <- coreView ty
87 vectType (TyVarTy tv) = return $ TyVarTy tv
88 vectType (AppTy ty1 ty2) = liftM2 AppTy (vectType ty1) (vectType ty2)
89 vectType (TyConApp tc tys) = liftM2 TyConApp (vectTyCon tc) (mapM vectType tys)
90 vectType (FunTy ty1 ty2) = liftM2 TyConApp (builtin closureTyCon)
91 (mapM vectAndBoxType [ty1,ty2])
93 -- For each quantified var we need to add a PA dictionary out the front of the type.
94 -- So forall a. C a => a -> a
95 -- turns into forall a. Cv a => PA a => a :-> a
96 vectType ty@(ForAllTy _ _)
98 -- split the type into the quantified vars, its dictionaries and the body.
99 let (tyvars, tyBody) = splitForAllTys ty
100 let (tyArgs, tyResult) = splitFunTys tyBody
102 let (tyArgs_dict, tyArgs_regular)
103 = partition isDictType tyArgs
105 -- vectorise the body.
106 let tyBody' = mkFunTys tyArgs_regular tyResult
107 tyBody'' <- vectType tyBody'
109 -- vectorise the dictionary parameters.
110 dictsVect <- mapM vectType tyArgs_dict
112 -- make a PA dictionary for each of the type variables.
113 dictsPA <- liftM catMaybes $ mapM paDictArgType tyvars
115 -- pack it all back together.
116 return $ abstractType tyvars (dictsVect ++ dictsPA) tyBody''
118 vectType ty = cantVectorise "Can't vectorise type" (ppr ty)
-- | Put quantified type variables and dictionary parameters in front of a
--   type, producing @forall tyvars. d1 -> .. -> dn -> ty@.
abstractType :: [TyVar] -> [Type] -> Type -> Type
abstractType tyvars dicts ty = mkForAllTys tyvars (mkFunTys dicts ty)
126 -- | Check if some type is a type class dictionary.
127 isDictType :: Type -> Bool
129 = case splitTyConApp_maybe ty of
130 Just (tyCon, _) -> isClassTyCon tyCon
134 -- ----------------------------------------------------------------------------
137 boxType :: Type -> VM Type
139 | Just (tycon, []) <- splitTyConApp_maybe ty
140 , isUnLiftedTyCon tycon
142 r <- lookupBoxedTyCon tycon
144 Just tycon' -> return $ mkTyConApp tycon' []
147 boxType ty = return ty
-- | Vectorise a type with 'vectType', then pass the result through 'boxType'
--   (which looks for a boxed replacement when the vectorised type is an
--   unlifted nullary type constructor).
vectAndBoxType :: Type -> VM Type
vectAndBoxType ty = boxType =<< vectType ty
153 -- ----------------------------------------------------------------------------
-- | A group of (possibly mutually recursive) type constructors, paired with
--   the set of type constructors referenced by the argument types of the
--   group members' data constructors.
type TyConGroup
  = ( [TyCon]          -- members of the group
    , UniqSet TyCon    -- tycons referenced from the members' constructors
    )
158 -- | Vectorise a type environment.
159 -- The type environment contains all the type things defined in a module.
160 vectTypeEnv :: TypeEnv -> VM (TypeEnv, [FamInst], [(Var, CoreExpr)])
164 cs <- readGEnv $ mk_map . global_tycons
166 -- Split the list of TyCons into the ones we have to vectorise vs the
167 -- ones we can pass through unchanged. We also pass through algebraic
168 -- types that use non-Haskell 98 features, as we don't handle those.
169 let (conv_tcs, keep_tcs) = classifyTyCons cs groups
170 keep_dcs = concatMap tyConDataCons keep_tcs
172 dtrace (text "conv_tcs = " <> ppr conv_tcs) $ return ()
174 zipWithM_ defTyCon keep_tcs keep_tcs
175 zipWithM_ defDataCon keep_dcs keep_dcs
177 new_tcs <- vectTyConDecls conv_tcs
179 dtrace (text "new_tcs = " <> ppr new_tcs) $ return ()
181 let orig_tcs = keep_tcs ++ conv_tcs
183 -- We don't need to make new representation types for dictionary
184 -- constructors. The constructors are always fully applied, and we don't
185 -- need to lift them to arrays as a dictionary of a particular type
186 -- always has the same value.
187 let vect_tcs = filter (not . isClassTyCon)
188 $ keep_tcs ++ new_tcs
190 dtrace (text "vect_tcs = " <> ppr vect_tcs) $ return ()
192 mapM_ dumpTycon $ new_tcs
195 (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
197 defTyConPAs (zipLazy vect_tcs dfuns')
198 reprs <- mapM tyConRepr vect_tcs
199 repr_tcs <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
200 pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
203 $ zipWith5 buildTyConBindings
211 return (dfuns, binds, repr_tcs ++ pdata_tcs)
213 let all_new_tcs = new_tcs ++ inst_tcs
215 let new_env = extendTypeEnvList env
216 (map ATyCon all_new_tcs
217 ++ [ADataCon dc | tc <- all_new_tcs
218 , dc <- tyConDataCons tc])
220 return (new_env, map mkLocalFamInst inst_tcs, binds)
222 tycons = typeEnvTyCons env
223 groups = tyConGroups tycons
225 mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
228 -- | Vectorise some (possibly recursively defined) type constructors.
229 vectTyConDecls :: [TyCon] -> VM [TyCon]
230 vectTyConDecls tcs = fixV $ \tcs' ->
232 mapM_ (uncurry defTyCon) (zipLazy tcs tcs')
233 mapM vectTyConDecl tcs
235 dumpTycon :: TyCon -> VM ()
237 | Just cls <- tyConClass_maybe tycon
238 = dtrace (vcat [ ppr tycon
239 , ppr [(m, varType m) | m <- classMethods cls ]])
246 -- | Vectorise a single type constructor.
247 vectTyConDecl :: TyCon -> VM TyCon
249 -- a type class constructor.
250 -- TODO: check for no stupid theta, fds, assoc types.
252 , Just cls <- tyConClass_maybe tycon
254 = do -- make the name of the vectorised class tycon.
255 name' <- cloneName mkVectTyConOcc (tyConName tycon)
257 -- vectorise right of definition.
258 rhs' <- vectAlgTyConRhs tycon (algTyConRhs tycon)
260 -- vectorise method selectors.
261 -- This also adds a mapping between the original and vectorised method selector
263 methods' <- mapM vectMethod
264 $ [(id, defMethSpecOfDefMeth meth)
265 | (id, meth) <- classOpItems cls]
267 -- keep the original recursiveness flag.
268 let rec_flag = boolToRecFlag (isRecursiveTyCon tycon)
270 -- Calling buildClass here attaches new quantifiers and dictionaries to the method types.
273 False -- include unfoldings on dictionary selectors.
274 name' -- new name V_T:Class
275 (tyConTyVars tycon) -- keep original type vars
276 [] -- no stupid theta
277 [] -- no functional dependencies
278 [] -- no associated types
279 methods' -- method info
280 rec_flag -- whether recursive
282 let tycon' = mkClassTyCon name'
291 -- a regular algebraic type constructor.
292 -- TODO: check for stupid theta, generics, GADTs etc.
294 = do name' <- cloneName mkVectTyConOcc (tyConName tycon)
295 rhs' <- vectAlgTyConRhs tycon (algTyConRhs tycon)
296 let rec_flag = boolToRecFlag (isRecursiveTyCon tycon)
298 liftDs $ buildAlgTyCon
300 (tyConTyVars tycon) -- keep original type vars.
301 [] -- no stupid theta.
302 rhs' -- new constructor defs.
303 rec_flag -- FIXME: is this ok?
304 False -- FIXME: no generics
305 False -- not GADT syntax
306 Nothing -- not a family instance
308 -- some other crazy thing that we don't handle.
310 = cantVectorise "Can't vectorise type constructor: " (ppr tycon)
313 -- | Vectorise a class method.
314 vectMethod :: (Id, DefMethSpec) -> VM (Name, DefMethSpec, Type)
315 vectMethod (id, defMeth)
317 -- Vectorise the method type.
318 typ' <- vectType (varType id)
320 -- Create a name for the vectorised method.
321 id' <- cloneId mkVectOcc id typ'
324 -- When we call buildClass in vectTyConDecl, it adds foralls and dictionaries
325 -- to the types of each method. However, the types we get back from vectType
326 -- above already have these, so we need to chop them off here, otherwise
327 -- we'll get two copies in the final version.
328 let (_tyvars, tyBody) = splitForAllTys typ'
329 let (_dict, tyRest) = splitFunTy tyBody
331 return (Var.varName id', defMeth, tyRest)
334 -- | Vectorise the RHS of an algebraic type.
335 vectAlgTyConRhs :: TyCon -> AlgTyConRhs -> VM AlgTyConRhs
336 vectAlgTyConRhs _ (DataTyCon { data_cons = data_cons
340 data_cons' <- mapM vectDataCon data_cons
341 zipWithM_ defDataCon data_cons data_cons'
342 return $ DataTyCon { data_cons = data_cons'
347 = cantVectorise "Can't vectorise type definition:" (ppr tc)
350 -- | Vectorise a data constructor.
351 -- Vectorises its argument and return types.
352 vectDataCon :: DataCon -> VM DataCon
354 | not . null $ dataConExTyVars dc
355 = cantVectorise "Can't vectorise constructor (existentials):" (ppr dc)
357 | not . null $ dataConEqSpec dc
358 = cantVectorise "Can't vectorise constructor (eq spec):" (ppr dc)
362 name' <- cloneName mkVectDataConOcc name
363 tycon' <- vectTyCon tycon
364 arg_tys <- mapM vectType rep_arg_tys
366 liftDs $ buildDataCon
369 (map (const HsNoBang) arg_tys) -- strictness annots on args.
370 [] -- no labelled fields
371 univ_tvs -- universally quantified vars
372 [] -- no existential tvs for now
373 [] -- no eq spec for now
375 arg_tys -- argument types
376 (mkFamilyTyConApp tycon' (mkTyVarTys univ_tvs)) -- return type
377 tycon' -- representation tycon
379 name = dataConName dc
380 univ_tvs = dataConUnivTyVars dc
381 rep_arg_tys = dataConRepArgTys dc
382 tycon = dataConTyCon dc
-- | Build a family-instance head: the family tycon together with the single
--   argument @arg_tc a1 .. an@, where @a1 .. an@ are @arg_tc@'s own type
--   variables.
mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
mk_fam_inst fam_tc arg_tc = (fam_tc, [saturated])
  where
    saturated = mkTyConApp arg_tc (mkTyVarTys (tyConTyVars arg_tc))
389 buildPReprTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
390 buildPReprTyCon orig_tc vect_tc repr
392 name <- cloneName mkPReprTyConOcc (tyConName orig_tc)
393 -- rhs_ty <- buildPReprType vect_tc
394 rhs_ty <- sumReprType repr
395 prepr_tc <- builtin preprTyCon
396 liftDs $ buildSynTyCon name
398 (SynonymTyCon rhs_ty)
400 (Just $ mk_fam_inst prepr_tc vect_tc)
402 tyvars = tyConTyVars vect_tc
404 data CompRepr = Keep Type
405 CoreExpr -- PR dictionary for the type
408 data ProdRepr = EmptyProd
410 | Prod { repr_tup_tc :: TyCon -- representation tuple tycon
411 , repr_ptup_tc :: TyCon -- PData representation tycon
412 , repr_comp_tys :: [Type] -- representation types of
413 , repr_comps :: [CompRepr] -- components
-- | Representation of one data constructor: the constructor itself together
--   with the product representation of its (representation) argument types.
data ConRepr
  = ConRepr DataCon    -- the data constructor being represented
            ProdRepr   -- representation of its fields
417 data SumRepr = EmptySum
419 | Sum { repr_sum_tc :: TyCon -- representation sum tycon
420 , repr_psum_tc :: TyCon -- PData representation tycon
421 , repr_sel_ty :: Type -- type of selector
422 , repr_con_tys :: [Type] -- representation types of
423 , repr_cons :: [ConRepr] -- components
426 tyConRepr :: TyCon -> VM SumRepr
427 tyConRepr tc = sum_repr (tyConDataCons tc)
429 sum_repr [] = return EmptySum
430 sum_repr [con] = liftM UnarySum (con_repr con)
432 rs <- mapM con_repr cons
433 sum_tc <- builtin (sumTyCon arity)
434 tys <- mapM conReprType rs
435 (psum_tc, _) <- pdataReprTyCon (mkTyConApp sum_tc tys)
436 sel_ty <- builtin (selTy arity)
437 return $ Sum { repr_sum_tc = sum_tc
438 , repr_psum_tc = psum_tc
439 , repr_sel_ty = sel_ty
446 con_repr con = liftM (ConRepr con) (prod_repr (dataConRepArgTys con))
448 prod_repr [] = return EmptyProd
449 prod_repr [ty] = liftM UnaryProd (comp_repr ty)
451 rs <- mapM comp_repr tys
452 tup_tc <- builtin (prodTyCon arity)
453 tys' <- mapM compReprType rs
454 (ptup_tc, _) <- pdataReprTyCon (mkTyConApp tup_tc tys')
455 return $ Prod { repr_tup_tc = tup_tc
456 , repr_ptup_tc = ptup_tc
457 , repr_comp_tys = tys'
463 comp_repr ty = liftM (Keep ty) (prDictOfType ty)
464 `orElseV` return (Wrap ty)
-- | Representation type of a sum of constructors:
--
--   * no constructors: the void type;
--
--   * a single constructor: that constructor's representation type;
--
--   * otherwise: the representation sum tycon applied to the constructors'
--     representation types.
sumReprType :: SumRepr -> VM Type
sumReprType repr
  = case repr of
      EmptySum   -> voidType
      UnarySum r -> conReprType r
      Sum { repr_sum_tc = sum_tc, repr_con_tys = tys }
                 -> return (mkTyConApp sum_tc tys)
-- | Representation type of a data constructor: the representation type of
--   the product of its fields.
conReprType :: ConRepr -> VM Type
conReprType (ConRepr _ fields) = prodReprType fields
-- | Representation type of a product of fields:
--
--   * no fields: the void type;
--
--   * a single field: that field's representation type;
--
--   * otherwise: the representation tuple tycon applied to the component
--     representation types.
prodReprType :: ProdRepr -> VM Type
prodReprType prod
  = case prod of
      EmptyProd   -> voidType
      UnaryProd c -> compReprType c
      Prod { repr_tup_tc = tup_tc, repr_comp_tys = tys }
                  -> return (mkTyConApp tup_tc tys)
-- | Representation type of a single field. A 'Keep' field (one for which a
--   PR dictionary was found) is used as is; a 'Wrap' field is wrapped in the
--   builtin 'wrapTyCon'.
compReprType :: CompRepr -> VM Type
compReprType (Keep ty _) = return ty
compReprType (Wrap ty)   = liftM (\wrap_tc -> mkTyConApp wrap_tc [ty])
                                 (builtin wrapTyCon)
-- | The original field type, independent of whether the field is kept or
--   wrapped in its representation.
compOrigType :: CompRepr -> Type
compOrigType repr
  = case repr of
      Keep ty _ -> ty
      Wrap ty   -> ty
491 buildToPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
492 buildToPRepr vect_tc repr_tc _ repr
494 let arg_ty = mkTyConApp vect_tc ty_args
495 res_ty <- mkPReprType arg_ty
496 arg <- newLocalVar (fsLit "x") arg_ty
497 result <- to_sum (Var arg) arg_ty res_ty repr
498 return $ Lam arg result
500 ty_args = mkTyVarTys (tyConTyVars vect_tc)
502 wrap_repr_inst = wrapFamInstBody repr_tc ty_args
504 to_sum _ _ _ EmptySum
506 void <- builtin voidVar
507 return $ wrap_repr_inst $ Var void
509 to_sum arg arg_ty res_ty (UnarySum r)
511 (pat, vars, body) <- con_alt r
512 return $ mkWildCase arg arg_ty res_ty
513 [(pat, vars, wrap_repr_inst body)]
515 to_sum arg arg_ty res_ty (Sum { repr_sum_tc = sum_tc
517 , repr_cons = cons })
519 alts <- mapM con_alt cons
520 let alts' = [(pat, vars, wrap_repr_inst
521 $ mkConApp sum_con (map Type tys ++ [body]))
522 | ((pat, vars, body), sum_con)
523 <- zip alts (tyConDataCons sum_tc)]
524 return $ mkWildCase arg arg_ty res_ty alts'
526 con_alt (ConRepr con r)
528 (vars, body) <- to_prod r
529 return (DataAlt con, vars, body)
533 void <- builtin voidVar
534 return ([], Var void)
536 to_prod (UnaryProd comp)
538 var <- newLocalVar (fsLit "x") (compOrigType comp)
539 body <- to_comp (Var var) comp
542 to_prod(Prod { repr_tup_tc = tup_tc
543 , repr_comp_tys = tys
544 , repr_comps = comps })
546 vars <- newLocalVars (fsLit "x") (map compOrigType comps)
547 exprs <- zipWithM to_comp (map Var vars) comps
548 return (vars, mkConApp tup_con (map Type tys ++ exprs))
550 [tup_con] = tyConDataCons tup_tc
552 to_comp expr (Keep _ _) = return expr
553 to_comp expr (Wrap ty) = do
554 wrap_tc <- builtin wrapTyCon
555 return $ wrapNewTypeBody wrap_tc [ty] expr
558 buildFromPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
559 buildFromPRepr vect_tc repr_tc _ repr
561 arg_ty <- mkPReprType res_ty
562 arg <- newLocalVar (fsLit "x") arg_ty
564 result <- from_sum (unwrapFamInstScrut repr_tc ty_args (Var arg))
566 return $ Lam arg result
568 ty_args = mkTyVarTys (tyConTyVars vect_tc)
569 res_ty = mkTyConApp vect_tc ty_args
573 dummy <- builtin fromVoidVar
574 return $ Var dummy `App` Type res_ty
576 from_sum expr (UnarySum r) = from_con expr r
577 from_sum expr (Sum { repr_sum_tc = sum_tc
579 , repr_cons = cons })
581 vars <- newLocalVars (fsLit "x") tys
582 es <- zipWithM from_con (map Var vars) cons
583 return $ mkWildCase expr (exprType expr) res_ty
584 [(DataAlt con, [var], e)
585 | (con, var, e) <- zip3 (tyConDataCons sum_tc) vars es]
587 from_con expr (ConRepr con r)
588 = from_prod expr (mkConApp con $ map Type ty_args) r
590 from_prod _ con EmptyProd = return con
591 from_prod expr con (UnaryProd r)
593 e <- from_comp expr r
596 from_prod expr con (Prod { repr_tup_tc = tup_tc
597 , repr_comp_tys = tys
601 vars <- newLocalVars (fsLit "y") tys
602 es <- zipWithM from_comp (map Var vars) comps
603 return $ mkWildCase expr (exprType expr) res_ty
604 [(DataAlt tup_con, vars, con `mkApps` es)]
606 [tup_con] = tyConDataCons tup_tc
608 from_comp expr (Keep _ _) = return expr
609 from_comp expr (Wrap ty)
611 wrap <- builtin wrapTyCon
612 return $ unwrapNewTypeBody wrap [ty] expr
615 buildToArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
616 buildToArrPRepr vect_tc prepr_tc pdata_tc r
618 arg_ty <- mkPDataType el_ty
619 res_ty <- mkPDataType =<< mkPReprType el_ty
620 arg <- newLocalVar (fsLit "xs") arg_ty
622 pdata_co <- mkBuiltinCo pdataTyCon
623 let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
624 co = mkAppCoercion pdata_co
626 $ mkTyConApp repr_co ty_args
628 scrut = unwrapFamInstScrut pdata_tc ty_args (Var arg)
630 (vars, result) <- to_sum r
633 $ mkWildCase scrut (mkTyConApp pdata_tc ty_args) res_ty
634 [(DataAlt pdata_dc, vars, mkCoerce co result)]
636 ty_args = mkTyVarTys $ tyConTyVars vect_tc
637 el_ty = mkTyConApp vect_tc ty_args
639 [pdata_dc] = tyConDataCons pdata_tc
643 pvoid <- builtin pvoidVar
644 return ([], Var pvoid)
645 to_sum (UnarySum r) = to_con r
646 to_sum (Sum { repr_psum_tc = psum_tc
647 , repr_sel_ty = sel_ty
652 (vars, exprs) <- mapAndUnzipM to_con cons
653 sel <- newLocalVar (fsLit "sel") sel_ty
654 return (sel : concat vars, mk_result (Var sel) exprs)
656 [psum_con] = tyConDataCons psum_tc
657 mk_result sel exprs = wrapFamInstBody psum_tc tys
659 $ map Type tys ++ (sel : exprs)
661 to_con (ConRepr _ r) = to_prod r
663 to_prod EmptyProd = do
664 pvoid <- builtin pvoidVar
665 return ([], Var pvoid)
666 to_prod (UnaryProd r)
668 pty <- mkPDataType (compOrigType r)
669 var <- newLocalVar (fsLit "x") pty
670 expr <- to_comp (Var var) r
673 to_prod (Prod { repr_ptup_tc = ptup_tc
674 , repr_comp_tys = tys
675 , repr_comps = comps })
677 ptys <- mapM (mkPDataType . compOrigType) comps
678 vars <- newLocalVars (fsLit "x") ptys
679 es <- zipWithM to_comp (map Var vars) comps
680 return (vars, mk_result es)
682 [ptup_con] = tyConDataCons ptup_tc
683 mk_result exprs = wrapFamInstBody ptup_tc tys
685 $ map Type tys ++ exprs
687 to_comp expr (Keep _ _) = return expr
689 -- FIXME: this is bound to be wrong!
690 to_comp expr (Wrap ty)
692 wrap_tc <- builtin wrapTyCon
693 (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
694 return $ wrapNewTypeBody pwrap_tc [ty] expr
697 buildFromArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
698 buildFromArrPRepr vect_tc prepr_tc pdata_tc r
700 arg_ty <- mkPDataType =<< mkPReprType el_ty
701 res_ty <- mkPDataType el_ty
702 arg <- newLocalVar (fsLit "xs") arg_ty
704 pdata_co <- mkBuiltinCo pdataTyCon
705 let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
706 co = mkAppCoercion pdata_co
707 $ mkTyConApp repr_co var_tys
709 scrut = mkCoerce co (Var arg)
711 mk_result args = wrapFamInstBody pdata_tc var_tys
713 $ map Type var_tys ++ args
715 (expr, _) <- fixV $ \ ~(_, args) ->
716 from_sum res_ty (mk_result args) scrut r
718 return $ Lam arg expr
720 -- (args, mk) <- from_sum res_ty scrut r
722 -- let result = wrapFamInstBody pdata_tc var_tys
723 -- . mkConApp pdata_dc
724 -- $ map Type var_tys ++ args
726 -- return $ Lam arg (mk result)
728 var_tys = mkTyVarTys $ tyConTyVars vect_tc
729 el_ty = mkTyConApp vect_tc var_tys
731 [pdata_con] = tyConDataCons pdata_tc
733 from_sum _ res _ EmptySum = return (res, [])
734 from_sum res_ty res expr (UnarySum r) = from_con res_ty res expr r
735 from_sum res_ty res expr (Sum { repr_psum_tc = psum_tc
736 , repr_sel_ty = sel_ty
738 , repr_cons = cons })
740 sel <- newLocalVar (fsLit "sel") sel_ty
741 ptys <- mapM mkPDataType tys
742 vars <- newLocalVars (fsLit "xs") ptys
743 (res', args) <- fold from_con res_ty res (map Var vars) cons
744 let scrut = unwrapFamInstScrut psum_tc tys expr
745 body = mkWildCase scrut (exprType scrut) res_ty
746 [(DataAlt psum_con, sel : vars, res')]
747 return (body, Var sel : args)
749 [psum_con] = tyConDataCons psum_tc
752 from_con res_ty res expr (ConRepr _ r) = from_prod res_ty res expr r
754 from_prod _ res _ EmptyProd = return (res, [])
755 from_prod res_ty res expr (UnaryProd r)
756 = from_comp res_ty res expr r
757 from_prod res_ty res expr (Prod { repr_ptup_tc = ptup_tc
758 , repr_comp_tys = tys
759 , repr_comps = comps })
761 ptys <- mapM mkPDataType tys
762 vars <- newLocalVars (fsLit "ys") ptys
763 (res', args) <- fold from_comp res_ty res (map Var vars) comps
764 let scrut = unwrapFamInstScrut ptup_tc tys expr
765 body = mkWildCase scrut (exprType scrut) res_ty
766 [(DataAlt ptup_con, vars, res')]
769 [ptup_con] = tyConDataCons ptup_tc
771 from_comp _ res expr (Keep _ _) = return (res, [expr])
772 from_comp _ res expr (Wrap ty)
774 wrap_tc <- builtin wrapTyCon
775 (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
776 return (res, [unwrapNewTypeBody pwrap_tc [ty]
777 $ unwrapFamInstScrut pwrap_tc [ty] expr])
779 fold f res_ty res exprs rs = foldrM f' (res, []) (zip exprs rs)
781 f' (expr, r) (res, args) = do
782 (res', args') <- f res_ty res expr r
783 return (res', args' ++ args)
785 buildPRDict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
786 buildPRDict vect_tc prepr_tc _ r
789 pr_co <- mkBuiltinCo prTyCon
790 let co = mkAppCoercion pr_co
792 $ mkTyConApp arg_co ty_args
793 return (mkCoerce co dict)
795 ty_args = mkTyVarTys (tyConTyVars vect_tc)
796 Just arg_co = tyConFamilyCoercion_maybe prepr_tc
798 sum_dict EmptySum = prDFunOfTyCon =<< builtin voidTyCon
799 sum_dict (UnarySum r) = con_dict r
800 sum_dict (Sum { repr_sum_tc = sum_tc
805 dicts <- mapM con_dict cons
806 dfun <- prDFunOfTyCon sum_tc
807 return $ dfun `mkTyApps` tys `mkApps` dicts
809 con_dict (ConRepr _ r) = prod_dict r
811 prod_dict EmptyProd = prDFunOfTyCon =<< builtin voidTyCon
812 prod_dict (UnaryProd r) = comp_dict r
813 prod_dict (Prod { repr_tup_tc = tup_tc
814 , repr_comp_tys = tys
815 , repr_comps = comps })
817 dicts <- mapM comp_dict comps
818 dfun <- prDFunOfTyCon tup_tc
819 return $ dfun `mkTyApps` tys `mkApps` dicts
821 comp_dict (Keep _ pr) = return pr
822 comp_dict (Wrap ty) = wrapPR ty
825 buildPDataTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
826 buildPDataTyCon orig_tc vect_tc repr = fixV $ \repr_tc ->
828 name' <- cloneName mkPDataTyConOcc orig_name
829 rhs <- buildPDataTyConRhs orig_name vect_tc repr_tc repr
830 pdata <- builtin pdataTyCon
832 liftDs $ buildAlgTyCon name'
834 [] -- no stupid theta
836 rec_flag -- FIXME: is this ok?
837 False -- FIXME: no generics
838 False -- not GADT syntax
839 (Just $ mk_fam_inst pdata vect_tc)
841 orig_name = tyConName orig_tc
842 tyvars = tyConTyVars vect_tc
843 rec_flag = boolToRecFlag (isRecursiveTyCon vect_tc)
846 buildPDataTyConRhs :: Name -> TyCon -> TyCon -> SumRepr -> VM AlgTyConRhs
847 buildPDataTyConRhs orig_name vect_tc repr_tc repr
849 data_con <- buildPDataDataCon orig_name vect_tc repr_tc repr
850 return $ DataTyCon { data_cons = [data_con], is_enum = False }
852 buildPDataDataCon :: Name -> TyCon -> TyCon -> SumRepr -> VM DataCon
853 buildPDataDataCon orig_name vect_tc repr_tc repr
855 dc_name <- cloneName mkPDataDataConOcc orig_name
856 comp_tys <- sum_tys repr
858 liftDs $ buildDataCon dc_name
860 (map (const HsNoBang) comp_tys)
861 [] -- no field labels
863 [] -- no existentials
867 (mkFamilyTyConApp repr_tc (mkTyVarTys tvs))
870 tvs = tyConTyVars vect_tc
872 sum_tys EmptySum = return []
873 sum_tys (UnarySum r) = con_tys r
874 sum_tys (Sum { repr_sel_ty = sel_ty
875 , repr_cons = cons })
876 = liftM (sel_ty :) (concatMapM con_tys cons)
878 con_tys (ConRepr _ r) = prod_tys r
880 prod_tys EmptyProd = return []
881 prod_tys (UnaryProd r) = liftM singleton (comp_ty r)
882 prod_tys (Prod { repr_comps = comps }) = mapM comp_ty comps
884 comp_ty r = mkPDataType (compOrigType r)
887 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr
889 buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr
891 vectDataConWorkers orig_tc vect_tc pdata_tc
892 buildPADict vect_tc prepr_tc pdata_tc repr
894 vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
895 vectDataConWorkers orig_tc vect_tc arr_tc
898 . zipWith3 def_worker (tyConDataCons orig_tc) rep_tys
899 $ zipWith4 mk_data_con (tyConDataCons vect_tc)
902 (tail $ tails rep_tys)
903 mapM_ (uncurry hoistBinding) bs
905 tyvars = tyConTyVars vect_tc
906 var_tys = mkTyVarTys tyvars
907 ty_args = map Type var_tys
908 res_ty = mkTyConApp vect_tc var_tys
910 cons = tyConDataCons vect_tc
912 [arr_dc] = tyConDataCons arr_tc
914 rep_tys = map dataConRepArgTys $ tyConDataCons vect_tc
917 mk_data_con con tys pre post
918 = liftM2 (,) (vect_data_con con)
919 (lift_data_con tys pre post (mkDataConTag con))
921 sel_replicate len tag
923 rep <- builtin (selReplicate arity)
924 return [rep `mkApps` [len, tag]]
926 | otherwise = return []
928 vect_data_con con = return $ mkConApp con ty_args
929 lift_data_con tys pre_tys post_tys tag
931 len <- builtin liftingContext
932 args <- mapM (newLocalVar (fsLit "xs"))
933 =<< mapM mkPDataType tys
935 sel <- sel_replicate (Var len) tag
937 pre <- mapM emptyPD (concat pre_tys)
938 post <- mapM emptyPD (concat post_tys)
940 return . mkLams (len : args)
941 . wrapFamInstBody arr_tc var_tys
943 $ ty_args ++ sel ++ pre ++ map Var args ++ post
945 def_worker data_con arg_tys mk_body
947 arity <- polyArity tyvars
950 . polyAbstract tyvars $ \args ->
951 liftM (mkLams (tyvars ++ args) . vectorised)
952 $ buildClosures tyvars [] arg_tys res_ty mk_body
954 raw_worker <- cloneId mkVectOcc orig_worker (exprType body)
955 let vect_worker = raw_worker `setIdUnfolding`
956 mkInlineRule body (Just arity)
957 defGlobalVar orig_worker vect_worker
958 return (vect_worker, body)
960 orig_worker = dataConWorkId data_con
962 buildPADict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
963 buildPADict vect_tc prepr_tc arr_tc repr
964 = polyAbstract tvs $ \args ->
966 method_ids <- mapM (method args) paMethods
968 pa_tc <- builtin paTyCon
969 pa_dc <- builtin paDataCon
970 let dict = mkLams (tvs ++ args)
972 $ Type inst_ty : map (method_call args) method_ids
974 dfun_ty = mkForAllTys tvs
975 $ mkFunTys (map varType args) (mkTyConApp pa_tc [inst_ty])
977 raw_dfun <- newExportedVar dfun_name dfun_ty
978 let dfun = raw_dfun `setIdUnfolding` mkDFunUnfolding dfun_ty (map Var method_ids)
979 `setInlinePragma` dfunInlinePragma
981 hoistBinding dfun dict
984 tvs = tyConTyVars vect_tc
985 arg_tys = mkTyVarTys tvs
986 inst_ty = mkTyConApp vect_tc arg_tys
988 dfun_name = mkPADFunOcc (getOccName vect_tc)
990 method args (name, build)
993 expr <- build vect_tc prepr_tc arr_tc repr
994 let body = mkLams (tvs ++ args) expr
995 raw_var <- newExportedVar (method_name name) (exprType body)
997 `setIdUnfolding` mkInlineRule body (Just (length args))
998 `setInlinePragma` alwaysInlinePragma
999 hoistBinding var body
1002 method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args)
1004 method_name name = mkVarOcc $ occNameString dfun_name ++ ('$' : name)
-- | The PA dictionary methods: each name is paired with the builder that
--   generates the method's body from the original/vectorised/PRepr/PData
--   tycons and the sum representation. 'buildPADict' turns this list, in
--   order, into the arguments of the dictionary it builds, so the order is
--   significant (presumably it must match the PA data constructor's fields
--   -- confirm against the PA class definition).
paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
paMethods
  = [ ("dictPRepr",    buildPRDict)
    , ("toPRepr",      buildToPRepr)
    , ("fromPRepr",    buildFromPRepr)
    , ("toArrPRepr",   buildToArrPRepr)
    , ("fromArrPRepr", buildFromArrPRepr)
    ]
1015 -- | Split the given tycons into two sets depending on whether they have to be
1016 -- converted (first list) or not (second list). The first argument contains
1017 -- information about the conversion status of external tycons:
1019 -- * tycons which have converted versions are mapped to True
1020 -- * tycons which are not changed by vectorisation are mapped to False
1021 -- * tycons which can't be converted are not elements of the map
1023 classifyTyCons :: UniqFM Bool -> [TyConGroup] -> ([TyCon], [TyCon])
1024 classifyTyCons = classify [] []
1026 classify conv keep _ [] = (conv, keep)
1027 classify conv keep cs ((tcs, ds) : rs)
1028 | can_convert && must_convert
1029 = classify (tcs ++ conv) keep (cs `addListToUFM` [(tc,True) | tc <- tcs]) rs
1031 = classify conv (tcs ++ keep) (cs `addListToUFM` [(tc,False) | tc <- tcs]) rs
1033 = classify conv keep cs rs
1035 refs = ds `delListFromUniqSet` tcs
1037 can_convert = isNullUFM (refs `minusUFM` cs) && all convertable tcs
1038 must_convert = foldUFM (||) False (intersectUFM_C const cs refs)
1040 convertable tc = isDataTyCon tc && all isVanillaDataCon (tyConDataCons tc)
1042 -- | Compute mutually recursive groups of tycons in topological order
1044 tyConGroups :: [TyCon] -> [TyConGroup]
1045 tyConGroups tcs = map mk_grp (stronglyConnCompFromEdgedVertices edges)
1047 edges = [((tc, ds), tc, uniqSetToList ds) | tc <- tcs
1048 , let ds = tyConsOfTyCon tc]
1050 mk_grp (AcyclicSCC (tc, ds)) = ([tc], ds)
1051 mk_grp (CyclicSCC els) = (tcs, unionManyUniqSets dss)
1053 (tcs, dss) = unzip els
1055 tyConsOfTyCon :: TyCon -> UniqSet TyCon
1057 = tyConsOfTypes . concatMap dataConRepArgTys . tyConDataCons
1059 tyConsOfType :: Type -> UniqSet TyCon
1061 | Just ty' <- coreView ty = tyConsOfType ty'
1062 tyConsOfType (TyVarTy _) = emptyUniqSet
1063 tyConsOfType (TyConApp tc tys) = extend (tyConsOfTypes tys)
1065 extend | isUnLiftedTyCon tc
1066 || isTupleTyCon tc = id
1068 | otherwise = (`addOneToUniqSet` tc)
1070 tyConsOfType (AppTy a b) = tyConsOfType a `unionUniqSets` tyConsOfType b
1071 tyConsOfType (FunTy a b) = (tyConsOfType a `unionUniqSets` tyConsOfType b)
1072 `addOneToUniqSet` funTyCon
1073 tyConsOfType (ForAllTy _ ty) = tyConsOfType ty
1074 tyConsOfType other = pprPanic "ClosureConv.tyConsOfType" $ ppr other
-- | Union of the type-constructor sets that 'tyConsOfType' collects from
--   each of the given types.
tyConsOfTypes :: [Type] -> UniqSet TyCon
tyConsOfTypes tys = unionManyUniqSets [tyConsOfType ty | ty <- tys]
1080 -- ----------------------------------------------------------------------------
1083 -- | Build an expression that calls the vectorised version of some
1084 -- function from a `Closure`.
1088 -- \(x :: Double) ->
1089 -- \(y :: Double) ->
1090 -- ($v_foo $: x) $: y
1093 -- We use the type of the original binding to work out how many
1094 -- outer lambdas to add.
1097 :: Type -- ^ The type of the original binding.
1098 -> CoreExpr -- ^ Expression giving the closure to use, eg @$v_foo@.
1101 -- Convert the type to the core view if it isn't already.
1103 | Just ty' <- coreView ty
1106 -- For each function constructor in the original type we add an outer
1107 -- lambda to bind the parameter variable, and an inner application of it.
1108 fromVect (FunTy arg_ty res_ty) expr
1110 arg <- newLocalVar (fsLit "x") arg_ty
1111 varg <- toVect arg_ty (Var arg)
1112 varg_ty <- vectType arg_ty
1113 vres_ty <- vectType res_ty
1114 apply <- builtin applyVar
1115 body <- fromVect res_ty
1116 $ Var apply `mkTyApps` [varg_ty, vres_ty] `mkApps` [expr, varg]
1117 return $ Lam arg body
1119 -- If the type isn't a function then it's time to call on the closure.
1121 = identityConv ty >> return expr
-- | Produce the vectorised counterpart of an expression of the given type.
--   Only the identity conversion is supported: if 'identityConv' accepts the
--   type, the expression is returned unchanged; otherwise the failure of
--   'identityConv' propagates.
toVect :: Type -> CoreExpr -> VM CoreExpr
toVect ty expr
  = do identityConv ty
       return expr
1128 identityConv :: Type -> VM ()
1129 identityConv ty | Just ty' <- coreView ty = identityConv ty'
1130 identityConv (TyConApp tycon tys)
1132 mapM_ identityConv tys
1133 identityConvTyCon tycon
1134 identityConv _ = noV
1136 identityConvTyCon :: TyCon -> VM ()
1137 identityConvTyCon tc
1138 | isBoxedTupleTyCon tc = return ()
1139 | isUnLiftedTyCon tc = return ()
1141 tc' <- maybeV (lookupTyCon tc)
1142 if tc == tc' then return () else noV