X-Git-Url: http://git.megacz.com/?a=blobdiff_plain;f=compiler%2Fvectorise%2FVectUtils.hs;h=cbca02f49e54317c119b0efe3bf54fc085b7d7a7;hb=fede30f987edef37bc62a07bdcdf3e7b5b951475;hp=0df16722bd2df457796752f18b257075ad6e9238;hpb=a2cc9611495710ab00355ebbd8e3d526b1a7d617;p=ghc-hetmet.git diff --git a/compiler/vectorise/VectUtils.hs b/compiler/vectorise/VectUtils.hs index 0df1672..cbca02f 100644 --- a/compiler/vectorise/VectUtils.hs +++ b/compiler/vectorise/VectUtils.hs @@ -4,6 +4,7 @@ module VectUtils ( mkPADictType, mkPArrayType, paDictArgType, paDictOfType, paMethod, lengthPA, replicatePA, emptyPA, + polyAbstract, polyApply, lookupPArrayFamInst, hoistExpr, takeHoisted ) where @@ -12,18 +13,24 @@ module VectUtils ( import VectMonad +import DsUtils import CoreSyn import CoreUtils import Type import TypeRep import TyCon +import DataCon ( dataConWrapId ) import Var +import Id ( mkWildId ) +import MkId ( unwrapFamInstScrut ) import PrelNames +import TysWiredIn +import BasicTypes ( Boxity(..) ) import Outputable import FastString -import Control.Monad ( liftM ) +import Control.Monad ( liftM, zipWithM_ ) collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type]) collectAnnTypeArgs expr = go expr [] @@ -42,7 +49,7 @@ isAnnTypeArg (_, AnnType t) = True isAnnTypeArg _ = False isClosureTyCon :: TyCon -> Bool -isClosureTyCon tc = tyConUnique tc == closureTyConKey +isClosureTyCon tc = tyConName tc == closureTyConName splitClosureTy :: Type -> (Type, Type) splitClosureTy ty @@ -52,6 +59,17 @@ splitClosureTy ty | otherwise = pprPanic "splitClosureTy" (ppr ty) +isPArrayTyCon :: TyCon -> Bool +isPArrayTyCon tc = tyConName tc == parrayTyConName + +splitPArrayTy :: Type -> Type +splitPArrayTy ty + | Just (tc, [arg_ty]) <- splitTyConApp_maybe ty + , isPArrayTyCon tc + = arg_ty + + | otherwise = pprPanic "splitPArrayTy" (ppr ty) + mkPADictType :: Type -> VM Type mkPADictType ty = do @@ -117,7 +135,9 @@ paMethod method ty return $ mkApps (Var fn) [Type ty, dict] lengthPA :: CoreExpr -> VM CoreExpr -lengthPA x = liftM (`App` x) (paMethod lengthPAVar (exprType x)) +lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty) + where + ty = splitPArrayTy (exprType x) replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr replicatePA len x = liftM (`mkApps` [len,x]) @@ -126,6 +146,62 @@ replicatePA len x = liftM (`mkApps` [len,x]) emptyPA :: Type -> VM CoreExpr emptyPA = paMethod emptyPAVar +type Vect a = (a,a) +type VVar = Vect Var +type VExpr = Vect CoreExpr + +vectorised :: Vect a -> a +vectorised = fst + +lifted :: Vect a -> a +lifted = snd + +mapVect :: (a -> b) -> Vect a -> Vect b +mapVect f (x,y) = (f x, f y) + +newLocalVVar :: FastString -> Type -> VM VVar +newLocalVVar fs vty + = do + lty <- mkPArrayType vty + vv <- newLocalVar fs vty + lv <- newLocalVar fs lty + return (vv,lv) + +vVar :: VVar -> VExpr +vVar = mapVect Var + +mkVLams :: [VVar] -> VExpr -> VExpr +mkVLams vvs (ve,le) = (mkLams vs ve, mkLams ls le) + where + (vs,ls) = unzip vvs + +mkVVarApps :: Var -> VExpr -> [VVar] -> VExpr +mkVVarApps lc (ve, le) vvs = (ve `mkVarApps` vs, le `mkVarApps` (lc : ls)) + where + (vs,ls) = unzip vvs + +polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a +polyAbstract tvs p + = localV + $ do + mdicts <- mapM mk_dict_var tvs + zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts + p (mk_lams mdicts) + where + mk_dict_var tv = do + r <- paDictArgType tv + case r of + Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty) + Nothing -> return Nothing + + mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts]) + +polyApply :: CoreExpr -> [Type] -> VM CoreExpr +polyApply expr tys + = do + dicts <- mapM paDictOfType tys + return $ expr `mkTyApps` tys `mkApps` dicts + lookupPArrayFamInst :: Type -> VM (TyCon, [Type]) lookupPArrayFamInst ty = builtin parrayTyCon >>= (`lookupFamInst` [ty]) @@ -137,6 +213,20 @@ hoistExpr fs expr env { global_bindings = (var, expr) : global_bindings env } return var +hoistPolyExpr :: FastString -> [TyVar] -> CoreExpr -> VM CoreExpr +hoistPolyExpr fs tvs expr + = do + poly_expr <- closedV . polyAbstract tvs $ \abstract -> return (abstract expr) + fn <- hoistExpr fs poly_expr + polyApply (Var fn) (mkTyVarTys tvs) + +hoistPolyVExpr :: FastString -> [TyVar] -> VExpr -> VM VExpr +hoistPolyVExpr fs tvs (ve, le) + = do + ve' <- hoistPolyExpr ('v' `consFS` fs) tvs ve + le' <- hoistPolyExpr ('l' `consFS` fs) tvs le + return (ve',le') + takeHoisted :: VM [(Var, CoreExpr)] takeHoisted = do @@ -144,3 +234,81 @@ takeHoisted setGEnv $ env { global_bindings = [] } return $ global_bindings env + +mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr +mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv) + = do + dict <- paDictOfType env_ty + mkv <- builtin mkClosureVar + mkl <- builtin mkClosurePVar + return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv], + Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv]) + +-- (clo , aclo (Arr lc xs1 ... xsn) ) +-- where +-- f = \env v -> case env of -> e x1 ... xn v +-- f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v + +buildClosure :: [TyVar] -> Var -> [VVar] -> VVar -> VExpr -> VM VExpr +buildClosure tvs lv vars arg body + = do + (env_ty, env, bind) <- buildEnv lv vars + env_bndr <- newLocalVVar FSLIT("env") env_ty + + fn <- hoistPolyVExpr FSLIT("fn") tvs + . mkVLams [env_bndr, arg] + . bind (vVar env_bndr) + $ mkVVarApps lv body (vars ++ [arg]) + + mkClosure arg_ty res_ty env_ty fn env + + where + arg_ty = idType (vectorised arg) + res_ty = exprType (vectorised body) + + +buildEnv :: Var -> [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VExpr) +buildEnv lv vvs + = do + let (ty, venv, vbind) = mkVectEnv tys vs + (lenv, lbind) <- mkLiftEnv lv tys ls + return (ty, (venv, lenv), + \(venv,lenv) (vbody,lbody) -> (vbind venv vbody, lbind lenv lbody)) + where + (vs,ls) = unzip vvs + tys = map idType vs + +mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr) +mkVectEnv [] [] = (unitTy, Var unitDataConId, \env body -> body) +mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body) +mkVectEnv tys vs = (ty, mkCoreTup (map Var vs), + \env body -> Case env (mkWildId ty) (exprType body) + [(DataAlt (tupleCon Boxed (length vs)), vs, body)]) + where + ty = mkCoreTupTy tys + +mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> CoreExpr) +mkLiftEnv lv [ty] [v] + = do + len <- lengthPA (Var v) + return (Var v, \env body -> Let (NonRec v env) + $ Case len lv (exprType body) [(DEFAULT, [], body)]) + +-- NOTE: this transparently deals with empty environments +mkLiftEnv lv tys vs + = do + (env_tc, env_tyargs) <- lookupPArrayFamInst vty + let [env_con] = tyConDataCons env_tc + + env = Var (dataConWrapId env_con) + `mkTyApps` env_tyargs + `mkVarApps` (lv : vs) + + bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env + in + Case scrut (mkWildId (exprType scrut)) (exprType body) + [(DataAlt env_con, lv : vs, body)] + return (env, bind) + where + vty = mkCoreTupTy tys +