X-Git-Url: http://git.megacz.com/?a=blobdiff_plain;f=compiler%2Fvectorise%2FVectUtils.hs;h=0d5585fc043160cbefed49b41fe14796db6f323f;hb=bee06bad431d372bd862b5c6e921d8fc87eaffc9;hp=3bf97fa7ffbe2f2cad7e29c38b90926c44c09cd7;hpb=3f6a74eafcabc1f8d496937a33ec92e7b416f989;p=ghc-hetmet.git diff --git a/compiler/vectorise/VectUtils.hs b/compiler/vectorise/VectUtils.hs index 3bf97fa..0d5585f 100644 --- a/compiler/vectorise/VectUtils.hs +++ b/compiler/vectorise/VectUtils.hs @@ -12,6 +12,7 @@ module VectUtils ( prDFunOfTyCon, paDictArgType, paDictOfType, paDFunType, paMethod, mkPR, lengthPA, replicatePA, emptyPA, packPA, combinePA, liftPA, + zipScalars, scalarClosure, polyAbstract, polyApply, polyVApply, hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted, buildClosure, buildClosures, @@ -30,7 +31,6 @@ import TypeRep import TyCon import DataCon import Var -import Id ( mkWildId ) import MkId ( unwrapFamInstScrut ) import TysWiredIn import BasicTypes ( Boxity(..) ) @@ -57,8 +57,8 @@ collectAnnTypeBinders expr = go [] expr collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann) collectAnnValBinders expr = go [] expr where - go bs (_, AnnLam b e) | isIdVar b = go (b:bs) e - go bs e = (reverse bs, e) + go bs (_, AnnLam b e) | isId b = go (b:bs) e + go bs e = (reverse bs, e) isAnnTypeArg :: AnnExpr b ann -> Bool isAnnTypeArg (_, AnnType _) = True @@ -271,6 +271,24 @@ liftPA x lc <- builtin liftingContext replicatePA (Var lc) x +zipScalars :: [Type] -> Type -> VM CoreExpr +zipScalars arg_tys res_ty + = do + scalar <- builtin scalarClass + (dfuns, _) <- mapAndUnzipM (\ty -> lookupInst scalar [ty]) ty_args + zipf <- builtin (scalarZip $ length arg_tys) + return $ Var zipf `mkTyApps` ty_args `mkApps` map Var dfuns + where + ty_args = arg_tys ++ [res_ty] + +scalarClosure :: [Type] -> Type -> CoreExpr -> CoreExpr -> VM CoreExpr +scalarClosure arg_tys res_ty scalar_fun array_fun + = do + ctr <- builtin (closureCtrFun $ length arg_tys) + pas <- mapM paDictOfType (init arg_tys) + return $ Var ctr `mkTyApps` (arg_tys ++ [res_ty]) + `mkApps` (pas ++ [scalar_fun, array_fun]) + newLocalVVar :: FastString -> Type -> VM VVar newLocalVVar fs vty = do @@ -376,12 +394,13 @@ buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr buildClosures _ _ [] _ mk_body = mk_body buildClosures tvs vars [arg_ty] res_ty mk_body - = buildClosure tvs vars arg_ty res_ty mk_body + = liftM vInlineMe (buildClosure tvs vars arg_ty res_ty mk_body) buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body = do res_ty' <- mkClosureTypes arg_tys res_ty arg <- newLocalVVar (fsLit "x") arg_ty - buildClosure tvs vars arg_ty res_ty' + liftM vInlineMe + . buildClosure tvs vars arg_ty res_ty' . hoistPolyVExpr tvs $ do lc <- builtin liftingContext @@ -406,7 +425,7 @@ buildClosure tvs vars arg_ty res_ty mk_body body <- mk_body body' <- bind (vVar env_bndr) (vVarApps lc body (vars ++ [arg_bndr])) - return (vLamsWithoutLC [env_bndr, arg_bndr] body') + return . vInlineMe $ vLamsWithoutLC [env_bndr, arg_bndr] body' mkClosure arg_ty res_ty env_ty fn env @@ -430,7 +449,7 @@ mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExp mkVectEnv [] [] = (unitTy, Var unitDataConId, \_ body -> body) mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body) mkVectEnv tys vs = (ty, mkCoreTup (map Var vs), - \env body -> Case env (mkWildId ty) (exprType body) + \env body -> mkWildCase env ty (exprType body) [(DataAlt (tupleCon Boxed (length vs)), vs, body)]) where ty = mkCoreTupTy tys @@ -460,7 +479,7 @@ mkLiftEnv lc tys vs bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env in - return $ Case scrut (mkWildId (exprType scrut)) + return $ mkWildCase scrut (exprType scrut) (exprType body) [(DataAlt env_con, lc : bndrs, body)] return (env, bind)