X-Git-Url: http://git.megacz.com/?a=blobdiff_plain;f=compiler%2Fvectorise%2FVectType.hs;h=aa0eae2f1035a7a3de25555af6be180b76a30ab7;hb=3b962ce87e2dbf6bdc1f3d1e083a74e5a9467665;hp=e528aae42030d8d76cc30066e4a3a0725481067e;hpb=742db8bde59c1175a50e5045332f05ec22d12e80;p=ghc-hetmet.git diff --git a/compiler/vectorise/VectType.hs b/compiler/vectorise/VectType.hs index e528aae..aa0eae2 100644 --- a/compiler/vectorise/VectType.hs +++ b/compiler/vectorise/VectType.hs @@ -37,7 +37,7 @@ import Digraph ( SCC(..), stronglyConnComp ) import Outputable import Control.Monad ( liftM, liftM2, zipWithM, zipWithM_ ) -import Data.List ( inits, tails, zipWith4 ) +import Data.List ( inits, tails, zipWith4, zipWith5 ) -- ---------------------------------------------------------------------------- -- Types @@ -101,8 +101,12 @@ vectTypeEnv env parr_tcs <- zipWithM buildPArrayTyCon orig_tcs vect_tcs dfuns <- mapM mkPADFun vect_tcs defTyConPAs (zip vect_tcs dfuns) - binds <- sequence (zipWith4 buildTyConBindings orig_tcs vect_tcs parr_tcs dfuns) - + binds <- sequence (zipWith5 buildTyConBindings orig_tcs + vect_tcs + repr_tcs + parr_tcs + dfuns) + let all_new_tcs = new_tcs ++ repr_tcs ++ parr_tcs let new_env = extendTypeEnvList env @@ -195,7 +199,7 @@ buildPReprTyCon :: TyCon -> TyCon -> VM TyCon buildPReprTyCon orig_tc vect_tc = do name <- cloneName mkPReprTyConOcc (tyConName orig_tc) - rhs_ty <- buildPReprRhsTy vect_tc + rhs_ty <- buildPReprType vect_tc prepr_tc <- builtin preprTyCon liftDs $ buildSynTyCon name tyvars @@ -204,13 +208,93 @@ buildPReprTyCon orig_tc vect_tc where tyvars = tyConTyVars vect_tc -buildPReprRhsTy :: TyCon -> VM Type -buildPReprRhsTy = buildPReprTy . map dataConRepArgTys . tyConDataCons +buildPReprType :: TyCon -> VM Type +buildPReprType = mkPRepr . map dataConRepArgTys . tyConDataCons + +buildToPRepr :: Shape -> TyCon -> TyCon -> TyCon -> VM CoreExpr +buildToPRepr _ vect_tc prepr_tc _ + = do + arg <- newLocalVar FSLIT("x") arg_ty + bndrss <- mapM (mapM (newLocalVar FSLIT("x"))) rep_tys + (alt_bodies, res_ty) <- mkToPRepr $ map (map Var) bndrss + + return . Lam arg + . wrapFamInstBody prepr_tc var_tys + . Case (Var arg) (mkWildId arg_ty) res_ty + $ zipWith3 mk_alt data_cons bndrss alt_bodies + where + var_tys = mkTyVarTys $ tyConTyVars vect_tc + arg_ty = mkTyConApp vect_tc var_tys + data_cons = tyConDataCons vect_tc + rep_tys = map dataConRepArgTys data_cons + + mk_alt data_con bndrs body = (DataAlt data_con, bndrs, body) + +buildToArrPRepr :: Shape -> TyCon -> TyCon -> TyCon -> VM CoreExpr +buildToArrPRepr _ vect_tc prepr_tc arr_tc + = do + arg_ty <- mkPArrayType el_ty + rep_tys <- mapM (mapM mkPArrayType) rep_el_tys + + arg <- newLocalVar FSLIT("xs") arg_ty + bndrss <- mapM (mapM (newLocalVar FSLIT("ys"))) rep_tys + len <- newLocalVar FSLIT("len") intPrimTy + sel <- newLocalVar FSLIT("sel") =<< mkPArrayType intTy + + let add_sel xs | has_selector = sel : xs + | otherwise = xs + + all_bndrs = len : add_sel (concat bndrss) -buildPReprTy :: [[Type]] -> VM Type -buildPReprTy tys = mkPlusTypes unitTy - =<< mapM (mkCrossTypes unitTy) - =<< mapM (mapM mkEmbedType) tys + res <- parrayCoerce prepr_tc var_tys + =<< mkToArrPRepr (Var len) (Var sel) (map (map Var) bndrss) + res_ty <- mkPArrayType =<< mkPReprType el_ty + + return . Lam arg + $ Case (unwrapFamInstScrut arr_tc var_tys (Var arg)) + (mkWildId (mkTyConApp arr_tc var_tys)) + res_ty + [(DataAlt arr_dc, all_bndrs, res)] + where + var_tys = mkTyVarTys $ tyConTyVars vect_tc + el_ty = mkTyConApp vect_tc var_tys + data_cons = tyConDataCons vect_tc + rep_el_tys = map dataConRepArgTys data_cons + + [arr_dc] = tyConDataCons arr_tc + + has_selector | [_] <- data_cons = False + | otherwise = True + + +buildFromPRepr :: Shape -> TyCon -> TyCon -> TyCon -> VM CoreExpr +buildFromPRepr _ vect_tc prepr_tc _ + = do + arg_ty <- mkPReprType res_ty + arg <- newLocalVar FSLIT("x") arg_ty + alts <- mapM mk_alt data_cons + body <- mkFromPRepr (unwrapFamInstScrut prepr_tc var_tys (Var arg)) + res_ty alts + return $ Lam arg body + where + var_tys = mkTyVarTys $ tyConTyVars vect_tc + res_ty = mkTyConApp vect_tc var_tys + data_cons = tyConDataCons vect_tc + + mk_alt dc = do + bndrs <- mapM (newLocalVar FSLIT("x")) $ dataConRepArgTys dc + return (bndrs, mkConApp dc (map Type var_tys ++ map Var bndrs)) + +buildFromArrPRepr :: Shape -> TyCon -> TyCon -> TyCon -> VM CoreExpr +buildFromArrPRepr _ vect_tc prepr_tc arr_tc + = mkFromArrPRepr undefined undefined undefined undefined undefined undefined + +buildPRDict :: Shape -> TyCon -> TyCon -> TyCon -> VM CoreExpr +buildPRDict _ vect_tc prepr_tc _ + = prCoerce prepr_tc var_tys + =<< prDictOfType (mkTyConApp prepr_tc var_tys) + where + var_tys = mkTyVarTys $ tyConTyVars vect_tc buildPArrayTyCon :: TyCon -> TyCon -> VM TyCon buildPArrayTyCon orig_tc vect_tc = fixV $ \repr_tc -> @@ -293,8 +377,9 @@ tyConShape vect_tc return [e] } -buildTyConBindings :: TyCon -> TyCon -> TyCon -> Var -> VM [(Var, CoreExpr)] -buildTyConBindings orig_tc vect_tc arr_tc dfun +buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> Var + -> VM [(Var, CoreExpr)] +buildTyConBindings orig_tc vect_tc prepr_tc arr_tc dfun = do shape <- tyConShape vect_tc sequence_ (zipWith4 (vectDataConWorker shape vect_tc arr_tc arr_dc) @@ -302,7 +387,7 @@ buildTyConBindings orig_tc vect_tc arr_tc dfun vect_dcs (inits repr_tys) (tails repr_tys)) - dict <- buildPADict shape vect_tc arr_tc dfun + dict <- buildPADict shape vect_tc prepr_tc arr_tc dfun binds <- takeHoisted return $ (dfun, dict) : binds where @@ -354,8 +439,8 @@ vectDataConWorker shape vect_tc arr_tc arr_dc orig_dc vect_dc pre (dc_tys : post ++ map Var args ++ empty_post -buildPADict :: Shape -> TyCon -> TyCon -> Var -> VM CoreExpr -buildPADict shape vect_tc arr_tc dfun +buildPADict :: Shape -> TyCon -> TyCon -> TyCon -> Var -> VM CoreExpr +buildPADict shape vect_tc prepr_tc arr_tc dfun = polyAbstract tvs $ \abstract -> do meth_binds <- mapM (mk_method shape) paMethods @@ -372,15 +457,18 @@ buildPADict shape vect_tc arr_tc dfun mk_method shape (name, build) = localV $ do - body <- build shape vect_tc arr_tc + body <- build shape vect_tc prepr_tc arr_tc var <- newLocalVar name (exprType body) return (var, mkInlineMe body) -paMethods = [(FSLIT("lengthPA"), buildLengthPA), - (FSLIT("replicatePA"), buildReplicatePA)] - -buildLengthPA :: Shape -> TyCon -> TyCon -> VM CoreExpr -buildLengthPA shape vect_tc arr_tc +paMethods = [(FSLIT("toPRepr"), buildToPRepr), + (FSLIT("fromPRepr"), buildFromPRepr), + (FSLIT("toArrPRepr"), buildToArrPRepr), + (FSLIT("fromArrPRepr"), buildFromArrPRepr), + (FSLIT("dictPRepr"), buildPRDict)] + +buildLengthPA :: Shape -> TyCon -> TyCon -> TyCon -> VM CoreExpr +buildLengthPA shape vect_tc _ arr_tc = do parr_ty <- mkPArrayType (mkTyConApp vect_tc arg_tys) arg <- newLocalVar FSLIT("xs") parr_ty @@ -428,8 +516,8 @@ buildLengthPA shape vect_tc arr_tc -- -- -buildReplicatePA :: Shape -> TyCon -> TyCon -> VM CoreExpr -buildReplicatePA shape vect_tc arr_tc +buildReplicatePA :: Shape -> TyCon -> TyCon -> TyCon -> VM CoreExpr +buildReplicatePA shape vect_tc _ arr_tc = do len_var <- newLocalVar FSLIT("n") intPrimTy val_var <- newLocalVar FSLIT("x") val_ty