From 27cb0a02d3e4c7a166e2c991e6ad4c09f54a10bc Mon Sep 17 00:00:00 2001 From: Roman Leshchinskiy Date: Fri, 24 Aug 2007 07:19:25 +0000 Subject: [PATCH] Simplify generation of PR dictionaries for products --- compiler/vectorise/VectType.hs | 221 +++++++++++++++++++++++++++++++++------ compiler/vectorise/VectUtils.hs | 102 ------------------ 2 files changed, 190 insertions(+), 133 deletions(-) diff --git a/compiler/vectorise/VectType.hs b/compiler/vectorise/VectType.hs index 10c3bbf..455a8ad 100644 --- a/compiler/vectorise/VectType.hs +++ b/compiler/vectorise/VectType.hs @@ -208,15 +208,129 @@ buildPReprTyCon orig_tc vect_tc where tyvars = tyConTyVars vect_tc +data TyConRepr = ProdRepr { + repr_prod_arg_tys :: [Type] + , repr_prod_tycon :: TyCon + , repr_prod_data_con :: DataCon + , repr_type :: Type + } + | SumRepr { + repr_tys :: [[Type]] + , repr_prod_tycons :: [Maybe TyCon] + , repr_prod_data_cons :: [Maybe DataCon] + , repr_prod_tys :: [Type] + , repr_sum_tycon :: TyCon + , repr_type :: Type + } + +arrShapeTys :: TyConRepr -> VM [Type] +arrShapeTys (ProdRepr {}) = return [intPrimTy] +arrShapeTys (SumRepr {}) + = do + uarr <- builtin uarrTyCon + return [intPrimTy, mkTyConApp uarr [intTy]] + +arrReprTys :: TyConRepr -> VM [Type] +arrReprTys (ProdRepr { repr_prod_arg_tys = tys }) + = mapM mkPArrayType tys +arrReprTys (SumRepr { repr_tys = tys }) + = concat `liftM` mapM (mapM mkPArrayType) (map mk_prod tys) + where + mk_prod [] = [unitTy] + mk_prod tys = tys + + +mkTyConRepr :: TyCon -> VM TyConRepr +mkTyConRepr vect_tc + | is_product + = let + [prod_arg_tys] = repr_tys + in + do + prod_tycon <- builtin (prodTyCon $ length prod_arg_tys) + let [prod_data_con] = tyConDataCons prod_tycon + + return $ ProdRepr { + repr_prod_arg_tys = prod_arg_tys + , repr_prod_tycon = prod_tycon + , repr_prod_data_con = prod_data_con + , repr_type = mkTyConApp prod_tycon prod_arg_tys + } + + | otherwise + = do + uarr <- builtin uarrTyCon + prod_tycons <- mapM (mk_tycon prodTyCon) repr_tys + let prod_tys = zipWith mk_tc_app_maybe prod_tycons repr_tys + sum_tycon <- builtin (sumTyCon $ length repr_tys) + arr_repr_tys <- mapM (mapM mkPArrayType . arr_repr_elem_tys) repr_tys + + return $ SumRepr { + repr_tys = repr_tys + , repr_prod_tycons = prod_tycons + , repr_prod_data_cons = map (fmap mk_single_datacon) prod_tycons + , repr_prod_tys = prod_tys + , repr_sum_tycon = sum_tycon + , repr_type = mkTyConApp sum_tycon prod_tys + } + where + tyvars = tyConTyVars vect_tc + data_cons = tyConDataCons vect_tc + repr_tys = map dataConRepArgTys data_cons + + is_product | [_] <- data_cons = True + | otherwise = False + + mk_shape uarr = intPrimTy : mk_sel uarr + + mk_sel uarr | is_product = [] + | otherwise = [uarr_int, uarr_int] + where + uarr_int = mkTyConApp uarr [intTy] + + mk_tycon get_tc tys + | n > 1 = builtin (Just . get_tc n) + | otherwise = return Nothing + where n = length tys + + mk_single_datacon tc | [dc] <- tyConDataCons tc = dc + + mk_tc_app_maybe Nothing [] = unitTy + mk_tc_app_maybe Nothing [ty] = ty + mk_tc_app_maybe (Just tc) tys = mkTyConApp tc tys + + arr_repr_elem_tys [] = [unitTy] + arr_repr_elem_tys tys = tys + buildPReprType :: TyCon -> VM Type buildPReprType = liftM repr_type . mkTyConRepr buildToPRepr :: TyConRepr -> TyCon -> TyCon -> TyCon -> VM CoreExpr -buildToPRepr (TyConRepr { +buildToPRepr (ProdRepr { + repr_prod_arg_tys = prod_arg_tys + , repr_prod_data_con = prod_data_con + , repr_type = repr_type + }) + vect_tc prepr_tc _ + = do + arg <- newLocalVar FSLIT("x") arg_ty + vars <- mapM (newLocalVar FSLIT("x")) prod_arg_tys + + return . Lam arg + . wrapFamInstBody prepr_tc var_tys + $ Case (Var arg) (mkWildId arg_ty) repr_type + [(DataAlt data_con, vars, + mkConApp prod_data_con (map Type prod_arg_tys ++ map Var vars))] + where + var_tys = mkTyVarTys $ tyConTyVars vect_tc + arg_ty = mkTyConApp vect_tc var_tys + [data_con] = tyConDataCons vect_tc + +buildToPRepr (SumRepr { repr_tys = repr_tys , repr_prod_data_cons = prod_data_cons , repr_prod_tys = prod_tys - , repr_sum_data_cons = sum_data_cons + , repr_sum_tycon = sum_tycon , repr_type = repr_type }) vect_tc prepr_tc _ @@ -227,16 +341,14 @@ buildToPRepr (TyConRepr { return . Lam arg . wrapFamInstBody prepr_tc var_tys . Case (Var arg) (mkWildId arg_ty) repr_type - . mk_alts data_cons vars + . zipWith4 mk_alt data_cons vars sum_data_cons . zipWith3 mk_prod prod_data_cons repr_tys $ map (map Var) vars where var_tys = mkTyVarTys $ tyConTyVars vect_tc arg_ty = mkTyConApp vect_tc var_tys data_cons = tyConDataCons vect_tc - mk_alts _ _ [] = [(DEFAULT, [], Var unitDataConId)] - mk_alts [dc] [vars] [expr] = [(DataAlt dc, vars, expr)] - mk_alts dcs vars exprs = zipWith4 mk_alt dcs vars sum_data_cons exprs + sum_data_cons = tyConDataCons sum_tycon mk_alt dc vars sum_dc expr = (DataAlt dc, vars, mkConApp sum_dc (map Type prod_tys ++ [expr])) @@ -246,11 +358,34 @@ buildToPRepr (TyConRepr { mk_prod (Just dc) tys exprs = mkConApp dc (map Type tys ++ exprs) buildFromPRepr :: TyConRepr -> TyCon -> TyCon -> TyCon -> VM CoreExpr -buildFromPRepr (TyConRepr { +buildFromPRepr (ProdRepr { + repr_prod_arg_tys = prod_arg_tys + , repr_prod_data_con = prod_data_con + , repr_type = repr_type + }) + vect_tc prepr_tc _ + = do + arg_ty <- mkPReprType res_ty + arg <- newLocalVar FSLIT("x") arg_ty + vars <- mapM (newLocalVar FSLIT("x")) prod_arg_tys + + return . Lam arg + $ Case (unwrapFamInstScrut prepr_tc var_tys (Var arg)) + (mkWildId repr_type) + res_ty + [(DataAlt prod_data_con, vars, + mkConApp data_con (map Type var_tys ++ map Var vars))] + where + var_tys = mkTyVarTys $ tyConTyVars vect_tc + ty_args = map Type var_tys + res_ty = mkTyConApp vect_tc var_tys + [data_con] = tyConDataCons vect_tc + +buildFromPRepr (SumRepr { repr_tys = repr_tys , repr_prod_data_cons = prod_data_cons , repr_prod_tys = prod_tys - , repr_sum_data_cons = sum_data_cons + , repr_sum_tycon = sum_tycon , repr_type = repr_type }) vect_tc prepr_tc _ @@ -259,7 +394,10 @@ buildFromPRepr (TyConRepr { arg <- newLocalVar FSLIT("x") arg_ty liftM (Lam arg - . un_sum (unwrapFamInstScrut prepr_tc var_tys (Var arg))) + . Case (unwrapFamInstScrut prepr_tc var_tys (Var arg)) + (mkWildId repr_type) + res_ty + . zipWith mk_alt sum_data_cons) (sequence $ zipWith4 un_prod data_cons prod_data_cons prod_tys repr_tys) where @@ -268,6 +406,8 @@ buildFromPRepr (TyConRepr { res_ty = mkTyConApp vect_tc var_tys data_cons = tyConDataCons vect_tc + sum_data_cons = tyConDataCons sum_tycon + un_prod dc _ _ [] = do var <- newLocalVar FSLIT("u") unitTy @@ -288,15 +428,28 @@ buildFromPRepr (TyConRepr { return (pv, expr) - un_sum scrut [(var, expr)] = Let (NonRec var scrut) expr - un_sum scrut alts - = Case scrut (mkWildId repr_type) res_ty - $ zipWith mk_alt sum_data_cons alts - mk_alt sum_dc (var, expr) = (DataAlt sum_dc, [var], expr) buildToArrPRepr :: TyConRepr -> TyCon -> TyCon -> TyCon -> VM CoreExpr +{- +buildToArrPRepr (ProdRepr { + repr_prod_arg_tys = prod_arg_tys + , repr_prod_data_con = prod_data_con + , repr_type = repr_type + }) + vect_tc prepr_tc _ + = do + arg_ty <- mkPArratType el_ty + rep_tys <- mapM mkPArrayType prod_arg_tys + + + where + var_tys = mkTyVarTys $ tyConTyVars vect_tc + el_ty = mkTyConApp vect_tc var_tys +-} +buildToArrPRepr _ _ _ _ = return (Var unitDataConId) +{- buildToArrPRepr _ vect_tc prepr_tc arr_tc = do arg_ty <- mkPArrayType el_ty @@ -331,29 +484,38 @@ buildToArrPRepr _ vect_tc prepr_tc arr_tc has_selector | [_] <- data_cons = False | otherwise = True +-} buildFromArrPRepr :: TyConRepr -> TyCon -> TyCon -> TyCon -> VM CoreExpr -buildFromArrPRepr _ vect_tc prepr_tc arr_tc - = mkFromArrPRepr undefined undefined undefined undefined undefined undefined +buildFromArrPRepr _ _ _ _ = return (Var unitDataConId) buildPRDict :: TyConRepr -> TyCon -> TyCon -> TyCon -> VM CoreExpr -buildPRDict (TyConRepr { +buildPRDict (ProdRepr { + repr_prod_arg_tys = prod_arg_tys + , repr_prod_tycon = prod_tycon + }) + vect_tc prepr_tc _ + = do + prs <- mapM mkPR prod_arg_tys + dfun <- prDFunOfTyCon prod_tycon + return $ dfun `mkTyApps` prod_arg_tys `mkApps` prs + +buildPRDict (SumRepr { repr_tys = repr_tys , repr_prod_tycons = prod_tycons , repr_prod_tys = prod_tys - , repr_sum_tycon = repr_sum_tycon + , repr_sum_tycon = sum_tycon }) vect_tc prepr_tc _ = do prs <- mapM (mapM mkPR) repr_tys prod_prs <- sequence $ zipWith3 mk_prod_pr prod_tycons repr_tys prs - sum_pr <- mk_sum_pr prod_prs - prCoerce prepr_tc var_tys sum_pr + sum_dfun <- prDFunOfTyCon sum_tycon + prCoerce prepr_tc var_tys + $ sum_dfun `mkTyApps` prod_tys `mkApps` prod_prs where var_tys = mkTyVarTys $ tyConTyVars vect_tc - Just sum_tycon = repr_sum_tycon - mk_prod_pr _ _ [] = prDFunOfTyCon unitTyCon mk_prod_pr _ _ [pr] = return pr mk_prod_pr (Just tc) tys prs @@ -361,12 +523,6 @@ buildPRDict (TyConRepr { dfun <- prDFunOfTyCon tc return $ dfun `mkTyApps` tys `mkApps` prs - mk_sum_pr [pr] = return pr - mk_sum_pr prs - = do - dfun <- prDFunOfTyCon sum_tycon - return $ dfun `mkTyApps` prod_tys `mkApps` prs - buildPArrayTyCon :: TyCon -> TyCon -> VM TyCon buildPArrayTyCon orig_tc vect_tc = fixV $ \repr_tc -> do @@ -400,17 +556,20 @@ buildPArrayDataCon orig_name vect_tc repr_tc dc_name <- cloneName mkPArrayDataConOcc orig_name repr <- mkTyConRepr vect_tc - let all_tys = arr_shape_tys repr ++ concat (arr_repr_tys repr) + shape_tys <- arrShapeTys repr + repr_tys <- arrReprTys repr + + let tys = shape_tys ++ repr_tys liftDs $ buildDataCon dc_name False -- not infix - (map (const NotMarkedStrict) all_tys) + (map (const NotMarkedStrict) tys) [] -- no field labels (tyConTyVars vect_tc) [] -- no existentials [] -- no eq spec [] -- no context - all_tys + tys repr_tc mkPADFun :: TyCon -> VM Var diff --git a/compiler/vectorise/VectUtils.hs b/compiler/vectorise/VectUtils.hs index e71d2a6..a50b4de 100644 --- a/compiler/vectorise/VectUtils.hs +++ b/compiler/vectorise/VectUtils.hs @@ -4,8 +4,6 @@ module VectUtils ( mkDataConTag, splitClosureTy, - TyConRepr(..), mkTyConRepr, - mkToArrPRepr, mkFromArrPRepr, mkPADictType, mkPArrayType, mkPReprType, parrayCoerce, parrayReprTyCon, parrayReprDataCon, mkVScrut, @@ -42,7 +40,6 @@ import BasicTypes ( Boxity(..) ) import Outputable import FastString -import Maybes ( orElse ) import Data.List ( zipWith4 ) import Control.Monad ( liftM, liftM2, zipWithM_ ) @@ -127,105 +124,6 @@ mkBuiltinTyConApps1 get_tc dft tys where mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2] -data TyConRepr = TyConRepr { - repr_tyvars :: [TyVar] - , repr_tys :: [[Type]] - , arr_shape_tys :: [Type] - , arr_repr_tys :: [[Type]] - - , repr_prod_tycons :: [Maybe TyCon] - , repr_prod_data_cons :: [Maybe DataCon] - , repr_prod_tys :: [Type] - , repr_sum_tycon :: Maybe TyCon - , repr_sum_data_cons :: [DataCon] - , repr_type :: Type - } - -mkTyConRepr :: TyCon -> VM TyConRepr -mkTyConRepr vect_tc - = do - uarr <- builtin uarrTyCon - prod_tycons <- mapM (mk_tycon prodTyCon) rep_tys - let prod_tys = zipWith mk_tc_app_maybe prod_tycons rep_tys - sum_tycon <- mk_tycon sumTyCon prod_tys - arr_repr_tys <- mapM (mapM mkPArrayType . arr_repr_elem_tys) rep_tys - - return $ TyConRepr { - repr_tyvars = tyvars - , repr_tys = rep_tys - , arr_shape_tys = mk_shape uarr - , arr_repr_tys = arr_repr_tys - - , repr_prod_tycons = prod_tycons - , repr_prod_data_cons = map (fmap mk_single_datacon) prod_tycons - , repr_prod_tys = prod_tys - , repr_sum_tycon = sum_tycon - , repr_sum_data_cons = fmap tyConDataCons sum_tycon `orElse` [] - , repr_type = mk_tc_app_maybe sum_tycon prod_tys - } - where - tyvars = tyConTyVars vect_tc - data_cons = tyConDataCons vect_tc - rep_tys = map dataConRepArgTys data_cons - - is_product | [_] <- data_cons = True - | otherwise = False - - mk_shape uarr = intPrimTy : mk_sel uarr - - mk_sel uarr | is_product = [] - | otherwise = [uarr_int, uarr_int] - where - uarr_int = mkTyConApp uarr [intTy] - - mk_tycon get_tc tys - | n > 1 = builtin (Just . get_tc n) - | otherwise = return Nothing - where n = length tys - - mk_single_datacon tc | [dc] <- tyConDataCons tc = dc - - mk_tc_app_maybe Nothing [] = unitTy - mk_tc_app_maybe Nothing [ty] = ty - mk_tc_app_maybe (Just tc) tys = mkTyConApp tc tys - - arr_repr_elem_tys [] = [unitTy] - arr_repr_elem_tys tys = tys - -mkToArrPRepr :: CoreExpr -> CoreExpr -> [[CoreExpr]] -> VM CoreExpr -mkToArrPRepr len sel ess - = do - let mk_sum [(expr, ty)] = return (expr, ty) - mk_sum es - = do - sum_tc <- builtin . sumTyCon $ length es - (sum_rtc, _) <- parrayReprTyCon (mkTyConApp sum_tc tys) - let [sum_rdc] = tyConDataCons sum_rtc - - return (mkConApp sum_rdc (map Type tys ++ (len : sel : exprs)), - mkTyConApp sum_tc tys) - where - (exprs, tys) = unzip es - - mk_prod [expr] = return (expr, splitPArrayTy (exprType expr)) - mk_prod exprs - = do - prod_tc <- builtin . prodTyCon $ length exprs - (prod_rtc, _) <- parrayReprTyCon (mkTyConApp prod_tc tys) - let [prod_rdc] = tyConDataCons prod_rtc - - return (mkConApp prod_rdc (map Type tys ++ (len : exprs)), - mkTyConApp prod_tc tys) - where - tys = map (splitPArrayTy . exprType) exprs - - liftM fst (mk_sum =<< mapM mk_prod ess) - -mkFromArrPRepr :: CoreExpr -> Type -> Var -> Var -> [[Var]] -> CoreExpr - -> VM CoreExpr -mkFromArrPRepr scrut res_ty len sel vars res - = return (Var unitDataConId) - mkClosureType :: Type -> Type -> VM Type mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty] -- 1.7.10.4