From bfddbe303f56f1e96b0e4820986699768738beb4 Mon Sep 17 00:00:00 2001 From: Roman Leshchinskiy Date: Wed, 29 Aug 2007 14:54:46 +0000 Subject: [PATCH] Rewrite vectorisation of product DataCon workers --- compiler/vectorise/VectType.hs | 102 +++++++++++++++++++++++++++++++++------- 1 file changed, 85 insertions(+), 17 deletions(-) diff --git a/compiler/vectorise/VectType.hs b/compiler/vectorise/VectType.hs index 0fe93eb..b238199 100644 --- a/compiler/vectorise/VectType.hs +++ b/compiler/vectorise/VectType.hs @@ -295,19 +295,22 @@ arrShapeTys repr = return [intPrimTy] arrShapeVars :: Repr -> VM [Var] arrShapeVars repr = mapM (newLocalVar FSLIT("sh")) =<< arrShapeTys repr -arrReprTys :: Repr -> VM [[Type]] -arrReprTys (SumRepr { sum_components = prods }) - = mapM arrProdTys prods -arrReprTys prod - = do - tys <- arrProdTys prod - return [tys] +replicateShape :: Repr -> CoreExpr -> CoreExpr -> VM [CoreExpr] +replicateShape (ProdRepr {}) len _ = return [len] -arrProdTys (ProdRepr { prod_components = tys }) - = mapM mkPArrayType (mk_types tys) - where - mk_types [] = [unitTy] - mk_types tys = tys +arrReprElemTys :: Repr -> [[Type]] +arrReprElemTys (SumRepr { sum_components = prods }) + = map arrProdElemTys prods +arrReprElemTys prod@(ProdRepr {}) + = [arrProdElemTys prod] + +arrProdElemTys (ProdRepr { prod_components = [] }) + = [unitTy] +arrProdElemTys (ProdRepr { prod_components = tys }) + = tys + +arrReprTys :: Repr -> VM [[Type]] +arrReprTys = mapM (mapM mkPArrayType) . arrReprElemTys arrReprVars :: Repr -> VM [[Var]] arrReprVars repr @@ -658,11 +661,7 @@ buildTyConBindings orig_tc vect_tc prepr_tc arr_tc dfun = do shape <- tyConShape vect_tc repr <- mkRepr vect_tc - sequence_ (zipWith4 (vectDataConWorker shape vect_tc arr_tc arr_dc) - orig_dcs - vect_dcs - (inits repr_tys) - (tails repr_tys)) + vectDataConWorkers repr orig_tc vect_tc arr_tc dict <- buildPADict repr vect_tc prepr_tc arr_tc dfun binds <- takeHoisted return $ (dfun, dict) : binds @@ -673,6 +672,75 @@ buildTyConBindings orig_tc vect_tc prepr_tc arr_tc dfun repr_tys = map dataConRepArgTys vect_dcs +vectDataConWorkers :: Repr -> TyCon -> TyCon -> TyCon + -> VM () +vectDataConWorkers repr orig_tc vect_tc arr_tc + = do + bs <- sequence + . zipWith3 def_worker (tyConDataCons orig_tc) rep_tys + $ zipWith4 mk_data_con (tyConDataCons vect_tc) + rep_tys + (inits arr_tys) + (tail $ tails arr_tys) + mapM_ (uncurry hoistBinding) bs + where + tyvars = tyConTyVars vect_tc + var_tys = mkTyVarTys tyvars + ty_args = map Type var_tys + + res_ty = mkTyConApp vect_tc var_tys + + rep_tys = map dataConRepArgTys $ tyConDataCons vect_tc + arr_tys = arrReprElemTys repr + + [arr_dc] = tyConDataCons arr_tc + + mk_data_con con tys pre post + = liftM2 (,) (vect_data_con con) + (lift_data_con tys (concat pre) + (concat post) + (mkDataConTag con)) + + + vect_data_con con = return $ mkConApp con ty_args + lift_data_con tys pre_tys post_tys tag + = do + len <- builtin liftingContext + args <- mapM (newLocalVar FSLIT("xs")) + =<< mapM mkPArrayType tys + + shape <- replicateShape repr (Var len) tag + repr <- mk_arr_repr (Var len) (map Var args) + + pre <- mapM emptyPA pre_tys + post <- mapM emptyPA post_tys + + return . mkLams (len : args) + . wrapFamInstBody arr_tc var_tys + . mkConApp arr_dc + $ ty_args ++ shape ++ pre ++ repr ++ post + + mk_arr_repr len [] + = do + units <- replicatePA len (Var unitDataConId) + return [units] + + mk_arr_repr len arrs = return arrs + + def_worker data_con arg_tys mk_body + = do + body <- closedV + . inBind orig_worker + . polyAbstract tyvars $ \abstract -> + liftM (abstract . vectorised) + $ buildClosures tyvars [] arg_tys res_ty mk_body + + vect_worker <- cloneId mkVectOcc orig_worker (exprType body) + defGlobalVar orig_worker vect_worker + return (vect_worker, body) + where + orig_worker = dataConWorkId data_con + vectDataConWorker :: Shape -> TyCon -> TyCon -> DataCon -> DataCon -> DataCon -> [[Type]] -> [[Type]] -> VM () -- 1.7.10.4