X-Git-Url: http://git.megacz.com/?a=blobdiff_plain;f=compiler%2Fvectorise%2FVectType.hs;h=37d65db91e29be98a950ed0499a13b8494ed66ac;hb=e499cbe9455b359e0325327fcdb57e2c9d621a0e;hp=e75c977ce868bb67a9837d2d5412e8b81e7b6f12;hpb=a139addf4890fc2167949680ead07ab809a9d98b;p=ghc-hetmet.git diff --git a/compiler/vectorise/VectType.hs b/compiler/vectorise/VectType.hs index e75c977..37d65db 100644 --- a/compiler/vectorise/VectType.hs +++ b/compiler/vectorise/VectType.hs @@ -11,6 +11,7 @@ import VectCore import HscTypes ( TypeEnv, extendTypeEnvList, typeEnvTyCons ) import CoreSyn import CoreUtils +import CoreUnfold import MkCore ( mkWildCase ) import BuildTyCl import DataCon @@ -20,9 +21,11 @@ import TypeRep import Coercion import FamInstEnv ( FamInst, mkLocalFamInst ) import OccName +import Id import MkId -import BasicTypes ( StrictnessMark(..), boolToRecFlag ) -import Var ( Var, TyVar ) +import BasicTypes ( HsBang(..), boolToRecFlag, + alwaysInlinePragma, dfunInlinePragma ) +import Var ( Var, TyVar, varType ) import Name ( Name, getOccName ) import NameEnv @@ -37,7 +40,7 @@ import FastString import MonadUtils ( zipWith3M, foldrM, concatMapM ) import Control.Monad ( liftM, liftM2, zipWithM, zipWithM_, mapAndUnzipM ) -import Data.List ( inits, tails, zipWith4, zipWith5, zipWith6 ) +import Data.List ( inits, tails, zipWith4, zipWith5 ) -- ---------------------------------------------------------------------------- -- Types @@ -119,26 +122,28 @@ vectTypeEnv env let orig_tcs = keep_tcs ++ conv_tcs vect_tcs = keep_tcs ++ new_tcs - dfuns <- mapM mkPADFun vect_tcs - defTyConPAs (zip vect_tcs dfuns) - reprs <- mapM tyConRepr vect_tcs - repr_tcs <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs - pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs - binds <- sequence (zipWith6 buildTyConBindings orig_tcs - vect_tcs - repr_tcs - pdata_tcs - dfuns - reprs) - - let all_new_tcs = new_tcs ++ repr_tcs ++ pdata_tcs + (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) -> + do + defTyConPAs (zipLazy vect_tcs dfuns') + reprs <- mapM tyConRepr vect_tcs + repr_tcs <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs + pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs + dfuns <- sequence $ zipWith5 buildTyConBindings orig_tcs + vect_tcs + repr_tcs + pdata_tcs + reprs + binds <- takeHoisted + return (dfuns, binds, repr_tcs ++ pdata_tcs) + + let all_new_tcs = new_tcs ++ inst_tcs let new_env = extendTypeEnvList env (map ATyCon all_new_tcs ++ [ADataCon dc | tc <- all_new_tcs , dc <- tyConDataCons tc]) - return (new_env, map mkLocalFamInst (repr_tcs ++ pdata_tcs), concat binds) + return (new_env, map mkLocalFamInst inst_tcs, binds) where tycons = typeEnvTyCons env groups = tyConGroups tycons @@ -197,7 +202,7 @@ vectDataCon dc liftDs $ buildDataCon name' False -- not infix - (map (const NotMarkedStrict) arg_tys) + (map (const HsNoBang) arg_tys) [] -- no labelled fields univ_tvs [] -- no existential tvs for now @@ -332,7 +337,7 @@ buildToPRepr vect_tc repr_tc _ repr wrap_repr_inst = wrapFamInstBody repr_tc ty_args - to_sum arg arg_ty res_ty EmptySum + to_sum _ _ _ EmptySum = do void <- builtin voidVar return $ wrap_repr_inst $ Var void @@ -348,8 +353,7 @@ buildToPRepr vect_tc repr_tc _ repr , repr_cons = cons }) = do alts <- mapM con_alt cons - let ty_args = map Type tys - alts' = [(pat, vars, wrap_repr_inst + let alts' = [(pat, vars, wrap_repr_inst $ mkConApp sum_con (map Type tys ++ [body])) | ((pat, vars, body), sum_con) <- zip alts (tyConDataCons sum_tc)] @@ -400,7 +404,7 @@ buildFromPRepr vect_tc repr_tc _ repr ty_args = mkTyVarTys (tyConTyVars vect_tc) res_ty = mkTyConApp vect_tc ty_args - from_sum expr EmptySum + from_sum _ EmptySum = do dummy <- builtin fromVoidVar return $ Var dummy `App` Type res_ty @@ -419,7 +423,7 @@ buildFromPRepr vect_tc repr_tc _ repr from_con expr (ConRepr con r) = from_prod expr (mkConApp con $ map Type ty_args) r - from_prod expr con EmptyProd = return con + from_prod _ con EmptyProd = return con from_prod expr con (UnaryProd r) = do e <- from_comp expr r @@ -562,7 +566,7 @@ buildFromArrPRepr vect_tc prepr_tc pdata_tc r [pdata_con] = tyConDataCons pdata_tc - from_sum res_ty res expr EmptySum = return (res, []) + from_sum _ res _ EmptySum = return (res, []) from_sum res_ty res expr (UnarySum r) = from_con res_ty res expr r from_sum res_ty res expr (Sum { repr_psum_tc = psum_tc , repr_sel_ty = sel_ty @@ -583,7 +587,7 @@ buildFromArrPRepr vect_tc prepr_tc pdata_tc r from_con res_ty res expr (ConRepr _ r) = from_prod res_ty res expr r - from_prod res_ty res expr EmptyProd = return (res, []) + from_prod _ res _ EmptyProd = return (res, []) from_prod res_ty res expr (UnaryProd r) = from_comp res_ty res expr r from_prod res_ty res expr (Prod { repr_ptup_tc = ptup_tc @@ -600,8 +604,8 @@ buildFromArrPRepr vect_tc prepr_tc pdata_tc r where [ptup_con] = tyConDataCons ptup_tc - from_comp res_ty res expr (Keep _ _) = return (res, [expr]) - from_comp res_ty res expr (Wrap ty) + from_comp _ res expr (Keep _ _) = return (res, [expr]) + from_comp _ res expr (Wrap ty) = do wrap_tc <- builtin wrapTyCon (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty]) @@ -689,7 +693,7 @@ buildPDataDataCon orig_name vect_tc repr_tc repr liftDs $ buildDataCon dc_name False -- not infix - (map (const NotMarkedStrict) comp_tys) + (map (const HsNoBang) comp_tys) [] -- no field labels tvs [] -- no existentials @@ -716,18 +720,12 @@ buildPDataDataCon orig_name vect_tc repr_tc repr comp_ty r = mkPDataType (compOrigType r) -mkPADFun :: TyCon -> VM Var -mkPADFun vect_tc - = newExportedVar (mkPADFunOcc $ getOccName vect_tc) =<< paDFunType vect_tc - -buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> Var -> SumRepr - -> VM [(Var, CoreExpr)] -buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc dfun repr +buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr + -> VM Var +buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr = do vectDataConWorkers orig_tc vect_tc pdata_tc - dict <- buildPADict vect_tc prepr_tc pdata_tc repr - binds <- takeHoisted - return $ (dfun, dict) : binds + buildPADict vect_tc prepr_tc pdata_tc repr vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM () vectDataConWorkers orig_tc vect_tc arr_tc @@ -782,46 +780,72 @@ vectDataConWorkers orig_tc vect_tc arr_tc def_worker data_con arg_tys mk_body = do + arity <- polyArity tyvars body <- closedV . inBind orig_worker - . polyAbstract tyvars $ \abstract -> - liftM (abstract . vectorised) + . polyAbstract tyvars $ \args -> + liftM (mkLams (tyvars ++ args) . vectorised) $ buildClosures tyvars [] arg_tys res_ty mk_body - vect_worker <- cloneId mkVectOcc orig_worker (exprType body) + raw_worker <- cloneId mkVectOcc orig_worker (exprType body) + let vect_worker = raw_worker `setIdUnfolding` + mkInlineRule body (Just arity) defGlobalVar orig_worker vect_worker return (vect_worker, body) where orig_worker = dataConWorkId data_con -buildPADict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr +buildPADict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM Var buildPADict vect_tc prepr_tc arr_tc repr - = polyAbstract tvs $ \abstract -> + = polyAbstract tvs $ \args -> do - meth_binds <- mapM mk_method paMethods - let meth_exprs = map (Var . fst) meth_binds + method_ids <- mapM (method args) paMethods + + pa_tc <- builtin paTyCon + pa_con <- builtin paDataCon + let dict = mkLams (tvs ++ args) + $ mkConApp pa_con + $ Type inst_ty : map (method_call args) method_ids - pa_dc <- builtin paDataCon - let dict = mkConApp pa_dc (Type (mkTyConApp vect_tc arg_tys) : meth_exprs) - body = Let (Rec meth_binds) dict - return . mkInlineMe $ abstract body + dfun_ty = mkForAllTys tvs + $ mkFunTys (map varType args) (mkTyConApp pa_tc [inst_ty]) + + raw_dfun <- newExportedVar dfun_name dfun_ty + let dfun = raw_dfun `setIdUnfolding` mkDFunUnfolding pa_con method_ids + `setInlinePragma` dfunInlinePragma + + hoistBinding dfun dict + return dfun where - tvs = tyConTyVars arr_tc + tvs = tyConTyVars vect_tc arg_tys = mkTyVarTys tvs + inst_ty = mkTyConApp vect_tc arg_tys + + dfun_name = mkPADFunOcc (getOccName vect_tc) - mk_method (name, build) + method args (name, build) = localV $ do - body <- build vect_tc prepr_tc arr_tc repr - var <- newLocalVar name (exprType body) - return (var, mkInlineMe body) - -paMethods :: [(FastString, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)] -paMethods = [(fsLit "dictPRepr", buildPRDict), - (fsLit "toPRepr", buildToPRepr), - (fsLit "fromPRepr", buildFromPRepr), - (fsLit "toArrPRepr", buildToArrPRepr), - (fsLit "fromArrPRepr", buildFromArrPRepr)] + expr <- build vect_tc prepr_tc arr_tc repr + let body = mkLams (tvs ++ args) expr + raw_var <- newExportedVar (method_name name) (exprType body) + let var = raw_var + `setIdUnfolding` mkInlineRule body (Just (length args)) + `setInlinePragma` alwaysInlinePragma + hoistBinding var body + return var + + method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args) + + method_name name = mkVarOcc $ occNameString dfun_name ++ ('$' : name) + + +paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)] +paMethods = [("dictPRepr", buildPRDict), + ("toPRepr", buildToPRepr), + ("fromPRepr", buildFromPRepr), + ("toArrPRepr", buildToArrPRepr), + ("fromArrPRepr", buildFromArrPRepr)] -- | Split the given tycons into two sets depending on whether they have to be -- converted (first list) or not (second list). The first argument contains