X-Git-Url: http://git.megacz.com/?a=blobdiff_plain;f=compiler%2Fvectorise%2FVectType.hs;h=ca5f0c82795f939b55693566e170d20bcdd0dadc;hb=facf3d6c3a2eefb66ec0ecefb0e8b390ca59ac8c;hp=0fe93eb8480f507473db95e4d74734c565b95ff1;hpb=9f28e733dd1b7552cab788e593a2d64005c09f37;p=ghc-hetmet.git diff --git a/compiler/vectorise/VectType.hs b/compiler/vectorise/VectType.hs index 0fe93eb..ca5f0c8 100644 --- a/compiler/vectorise/VectType.hs +++ b/compiler/vectorise/VectType.hs @@ -224,6 +224,22 @@ data Repr = ProdRepr { , sum_arr_data_con :: DataCon } + | IdRepr Type + + | VoidRepr { + void_tycon :: TyCon + , void_bottom :: CoreExpr + } + +mkVoid :: VM Repr +mkVoid = do + tycon <- builtin voidTyCon + var <- builtin voidVar + return $ VoidRepr { + void_tycon = tycon + , void_bottom = Var var + } + mkProduct :: [Type] -> VM Repr mkProduct tys = do @@ -243,6 +259,11 @@ mkProduct tys where arity = length tys +mkSubProduct :: [Type] -> VM Repr +mkSubProduct [] = mkVoid +mkSubProduct [ty] = return $ IdRepr ty +mkSubProduct tys = mkProduct tys + mkSum :: [Repr] -> VM Repr mkSum [repr] = return repr mkSum reprs @@ -263,51 +284,62 @@ mkSum reprs where arity = length reprs -reprProducts :: Repr -> [Repr] -reprProducts (SumRepr { sum_components = rs }) = rs -reprProducts repr = [repr] - reprType :: Repr -> Type reprType (ProdRepr { prod_tycon = tycon, prod_components = tys }) = mkTyConApp tycon tys reprType (SumRepr { sum_tycon = tycon, sum_components = reprs }) = mkTyConApp tycon (map reprType reprs) +reprType (IdRepr ty) = ty +reprType (VoidRepr { void_tycon = tycon }) = mkTyConApp tycon [] arrReprType :: Repr -> VM Type arrReprType = mkPArrayType . reprType -reprTys :: Repr -> [[Type]] -reprTys (SumRepr { sum_components = prods }) = map prodTys prods -reprTys prod = [prodTys prod] - -prodTys (ProdRepr { prod_components = tys }) = tys - -reprVars :: Repr -> VM [[Var]] -reprVars = mapM (mapM (newLocalVar FSLIT("r"))) . reprTys - arrShapeTys :: Repr -> VM [Type] arrShapeTys (SumRepr {}) = do - uarr <- builtin uarrTyCon - return [intPrimTy, mkTyConApp uarr [intTy]] -arrShapeTys repr = return [intPrimTy] + int_arr <- builtin parrayIntPrimTyCon + return [intPrimTy, mkTyConApp int_arr [], mkTyConApp int_arr []] +arrShapeTys (ProdRepr {}) = return [intPrimTy] +arrShapeTys (IdRepr _) = return [] +arrShapeTys (VoidRepr {}) = return [intPrimTy] arrShapeVars :: Repr -> VM [Var] arrShapeVars repr = mapM (newLocalVar FSLIT("sh")) =<< arrShapeTys repr -arrReprTys :: Repr -> VM [[Type]] -arrReprTys (SumRepr { sum_components = prods }) - = mapM arrProdTys prods -arrReprTys prod +replicateShape :: Repr -> CoreExpr -> CoreExpr -> VM [CoreExpr] +replicateShape (ProdRepr {}) len _ = return [len] +replicateShape (SumRepr {}) len tag = do - tys <- arrProdTys prod + rep <- builtin replicatePAIntPrimVar + up <- builtin upToPAIntPrimVar + return [len, Var rep `mkApps` [len, tag], Var up `App` len] +replicateShape (IdRepr _) _ _ = return [] +replicateShape (VoidRepr {}) len _ = return [len] + +arrReprElemTys :: Repr -> VM [[Type]] +arrReprElemTys (SumRepr { sum_components = prods }) + = mapM arrProdElemTys prods +arrReprElemTys prod@(ProdRepr {}) + = do + tys <- arrProdElemTys prod return [tys] +arrReprElemTys (IdRepr ty) = return [[ty]] +arrReprElemTys (VoidRepr { void_tycon = tycon }) + = return [[mkTyConApp tycon []]] -arrProdTys (ProdRepr { prod_components = tys }) - = mapM mkPArrayType (mk_types tys) - where - mk_types [] = [unitTy] - mk_types tys = tys +arrProdElemTys (ProdRepr { prod_components = [] }) + = do + void <- builtin voidTyCon + return [mkTyConApp void []] +arrProdElemTys (ProdRepr { prod_components = tys }) + = return tys +arrProdElemTys (IdRepr ty) = return [ty] +arrProdElemTys (VoidRepr { void_tycon = tycon }) + = return [mkTyConApp tycon []] + +arrReprTys :: Repr -> VM [[Type]] +arrReprTys repr = mapM (mapM mkPArrayType) =<< arrReprElemTys repr arrReprVars :: Repr -> VM [[Var]] arrReprVars repr @@ -315,8 +347,10 @@ arrReprVars repr mkRepr :: TyCon -> VM Repr mkRepr vect_tc - = mkSum - =<< mapM mkProduct (map dataConRepArgTys $ tyConDataCons vect_tc) + | [tys] <- rep_tys = mkProduct tys + | otherwise = mkSum =<< mapM mkSubProduct rep_tys + where + rep_tys = map dataConRepArgTys $ tyConDataCons vect_tc buildPReprType :: TyCon -> VM Type buildPReprType = liftM reprType . mkRepr @@ -363,6 +397,15 @@ buildToPRepr repr vect_tc prepr_tc _ vars <- mapM (newLocalVar FSLIT("r")) tys return (vars, mkConApp data_con (map Type tys ++ map Var vars)) + prod_alt (IdRepr ty) + = do + var <- newLocalVar FSLIT("y") ty + return ([var], Var var) + + prod_alt (VoidRepr { void_bottom = bottom }) + = return ([], bottom) + + buildFromPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr buildFromPRepr repr vect_tc prepr_tc _ = do @@ -402,6 +445,12 @@ buildFromPRepr repr vect_tc prepr_tc _ return $ Case expr (mkWildId (reprType prod)) res_ty [(DataAlt data_con, vars, con `mkVarApps` vars)] + from_prod (IdRepr _) con expr + = return $ con `App` expr + + from_prod (VoidRepr {}) con expr + = return con + buildToArrPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr buildToArrPRepr repr vect_tc prepr_tc arr_tc = do @@ -440,7 +489,7 @@ buildToArrPRepr repr vect_tc prepr_tc arr_tc , sum_arr_tycon = tycon , sum_arr_data_con = data_con }) = do - exprs <- zipWithM (to_prod len_var) repr_vars prods + exprs <- zipWithM to_prod repr_vars prods return . wrapFamInstBody tycon tys . mkConApp data_con @@ -448,16 +497,28 @@ buildToArrPRepr repr vect_tc prepr_tc arr_tc where tys = map reprType prods - to_repr [len_var] [repr_vars] prod = to_prod len_var repr_vars prod + to_repr [len_var] + [repr_vars] + (ProdRepr { prod_components = tys + , prod_arr_tycon = tycon + , prod_arr_data_con = data_con }) + = return . wrapFamInstBody tycon tys + . mkConApp data_con + $ map Type tys ++ map Var (len_var : repr_vars) - to_prod len_var - repr_vars + to_prod repr_vars@(r : _) (ProdRepr { prod_components = tys , prod_arr_tycon = tycon , prod_arr_data_con = data_con }) - = return . wrapFamInstBody tycon tys - . mkConApp data_con - $ map Type tys ++ map Var (len_var : repr_vars) + = do + len <- lengthPA (Var r) + return . wrapFamInstBody tycon tys + . mkConApp data_con + $ map Type tys ++ len : map Var repr_vars + + to_prod [var] (IdRepr ty) = return (Var var) + to_prod [var] (VoidRepr {}) = return (Var var) + buildFromArrPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr buildFromArrPRepr repr vect_tc prepr_tc arr_tc @@ -536,7 +597,26 @@ buildFromArrPRepr repr vect_tc prepr_tc arr_tc return $ Case scrut (mkWildId scrut_ty) res_ty [(DataAlt data_con, shape_vars ++ repr_vars, body)] + from_prod (IdRepr ty) + expr + shape_vars + [repr_var] + res_ty + body + = return $ Let (NonRec repr_var expr) body + + from_prod (VoidRepr {}) + expr + shape_vars + [repr_var] + res_ty + body + = return $ Let (NonRec repr_var expr) body + buildPRDictRepr :: Repr -> VM CoreExpr +buildPRDictRepr (VoidRepr { void_tycon = tycon }) + = prDFunOfTyCon tycon +buildPRDictRepr (IdRepr ty) = mkPR ty buildPRDictRepr (ProdRepr { prod_components = tys , prod_tycon = tycon @@ -623,46 +703,12 @@ mkPADFun :: TyCon -> VM Var mkPADFun vect_tc = newExportedVar (mkPADFunOcc $ getOccName vect_tc) =<< paDFunType vect_tc -data Shape = Shape { - shapeReprTys :: [Type] - , shapeStrictness :: [StrictnessMark] - , shapeLength :: [CoreExpr] -> VM CoreExpr - , shapeReplicate :: CoreExpr -> CoreExpr -> VM [CoreExpr] - } - -tyConShape :: TyCon -> VM Shape -tyConShape vect_tc - | isProductTyCon vect_tc - = return $ Shape { - shapeReprTys = [intPrimTy] - , shapeStrictness = [NotMarkedStrict] - , shapeLength = \[len] -> return len - , shapeReplicate = \len _ -> return [len] - } - - | otherwise - = do - repr_ty <- mkPArrayType intTy -- FIXME: we want to unbox this - return $ Shape { - shapeReprTys = [repr_ty] - , shapeStrictness = [MarkedStrict] - , shapeLength = \[sel] -> lengthPA sel - , shapeReplicate = \len n -> do - e <- replicatePA len n - return [e] - } - buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> Var -> VM [(Var, CoreExpr)] buildTyConBindings orig_tc vect_tc prepr_tc arr_tc dfun = do - shape <- tyConShape vect_tc repr <- mkRepr vect_tc - sequence_ (zipWith4 (vectDataConWorker shape vect_tc arr_tc arr_dc) - orig_dcs - vect_dcs - (inits repr_tys) - (tails repr_tys)) + vectDataConWorkers repr orig_tc vect_tc arr_tc dict <- buildPADict repr vect_tc prepr_tc arr_tc dfun binds <- takeHoisted return $ (dfun, dict) : binds @@ -673,47 +719,73 @@ buildTyConBindings orig_tc vect_tc prepr_tc arr_tc dfun repr_tys = map dataConRepArgTys vect_dcs -vectDataConWorker :: Shape -> TyCon -> TyCon -> DataCon - -> DataCon -> DataCon -> [[Type]] -> [[Type]] - -> VM () -vectDataConWorker shape vect_tc arr_tc arr_dc orig_dc vect_dc pre (dc_tys : post) +vectDataConWorkers :: Repr -> TyCon -> TyCon -> TyCon + -> VM () +vectDataConWorkers repr orig_tc vect_tc arr_tc = do - clo <- closedV - . inBind orig_worker - . polyAbstract tvs $ \abstract -> - liftM (abstract . vectorised) - $ buildClosures tvs [] dc_tys res_ty (liftM2 (,) mk_vect mk_lift) - - worker <- cloneId mkVectOcc orig_worker (exprType clo) - hoistBinding worker clo - defGlobalVar orig_worker worker - return () + arr_tys <- arrReprElemTys repr + bs <- sequence + . zipWith3 def_worker (tyConDataCons orig_tc) rep_tys + $ zipWith4 mk_data_con (tyConDataCons vect_tc) + rep_tys + (inits arr_tys) + (tail $ tails arr_tys) + mapM_ (uncurry hoistBinding) bs where - tvs = tyConTyVars vect_tc - arg_tys = mkTyVarTys tvs - res_ty = mkTyConApp vect_tc arg_tys - - orig_worker = dataConWorkId orig_dc - - mk_vect = return . mkConApp vect_dc $ map Type arg_tys - mk_lift = do - len <- newLocalVar FSLIT("n") intPrimTy - arr_tys <- mapM mkPArrayType dc_tys - args <- mapM (newLocalVar FSLIT("xs")) arr_tys - shapes <- shapeReplicate shape - (Var len) - (mkDataConTag vect_dc) - - empty_pre <- mapM emptyPA (concat pre) - empty_post <- mapM emptyPA (concat post) - - return . mkLams (len : args) - . wrapFamInstBody arr_tc arg_tys - . mkConApp arr_dc - $ map Type arg_tys ++ shapes - ++ empty_pre - ++ map Var args - ++ empty_post + tyvars = tyConTyVars vect_tc + var_tys = mkTyVarTys tyvars + ty_args = map Type var_tys + + res_ty = mkTyConApp vect_tc var_tys + + rep_tys = map dataConRepArgTys $ tyConDataCons vect_tc + + [arr_dc] = tyConDataCons arr_tc + + mk_data_con con tys pre post + = liftM2 (,) (vect_data_con con) + (lift_data_con tys (concat pre) + (concat post) + (mkDataConTag con)) + + vect_data_con con = return $ mkConApp con ty_args + lift_data_con tys pre_tys post_tys tag + = do + len <- builtin liftingContext + args <- mapM (newLocalVar FSLIT("xs")) + =<< mapM mkPArrayType tys + + shape <- replicateShape repr (Var len) tag + repr <- mk_arr_repr (Var len) (map Var args) + + pre <- mapM emptyPA pre_tys + post <- mapM emptyPA post_tys + + return . mkLams (len : args) + . wrapFamInstBody arr_tc var_tys + . mkConApp arr_dc + $ ty_args ++ shape ++ pre ++ repr ++ post + + mk_arr_repr len [] + = do + units <- replicatePA len (Var unitDataConId) + return [units] + + mk_arr_repr len arrs = return arrs + + def_worker data_con arg_tys mk_body + = do + body <- closedV + . inBind orig_worker + . polyAbstract tyvars $ \abstract -> + liftM (abstract . vectorised) + $ buildClosures tyvars [] arg_tys res_ty mk_body + + vect_worker <- cloneId mkVectOcc orig_worker (exprType body) + defGlobalVar orig_worker vect_worker + return (vect_worker, body) + where + orig_worker = dataConWorkId data_con buildPADict :: Repr -> TyCon -> TyCon -> TyCon -> Var -> VM CoreExpr buildPADict repr vect_tc prepr_tc arr_tc dfun