Rewrite vectorisation of product DataCon workers
authorRoman Leshchinskiy <rl@cse.unsw.edu.au>
Wed, 29 Aug 2007 14:54:46 +0000 (14:54 +0000)
committerRoman Leshchinskiy <rl@cse.unsw.edu.au>
Wed, 29 Aug 2007 14:54:46 +0000 (14:54 +0000)
compiler/vectorise/VectType.hs

index 0fe93eb..b238199 100644 (file)
@@ -295,19 +295,22 @@ arrShapeTys repr = return [intPrimTy]
 arrShapeVars :: Repr -> VM [Var]
 arrShapeVars repr = mapM (newLocalVar FSLIT("sh")) =<< arrShapeTys repr
 
-arrReprTys :: Repr -> VM [[Type]]
-arrReprTys (SumRepr { sum_components = prods })
-  = mapM arrProdTys prods
-arrReprTys prod
-  = do
-      tys <- arrProdTys prod
-      return [tys]
+replicateShape :: Repr -> CoreExpr -> CoreExpr -> VM [CoreExpr]
+replicateShape (ProdRepr {}) len _ = return [len]
 
-arrProdTys (ProdRepr { prod_components = tys })
-  = mapM mkPArrayType (mk_types tys)
-  where
-    mk_types []  = [unitTy]
-    mk_types tys = tys
+arrReprElemTys :: Repr -> [[Type]]
+arrReprElemTys (SumRepr { sum_components = prods })
+  = map arrProdElemTys prods
+arrReprElemTys prod@(ProdRepr {})
+  = [arrProdElemTys prod]
+
+arrProdElemTys (ProdRepr { prod_components = [] })
+  = [unitTy]
+arrProdElemTys (ProdRepr { prod_components = tys })
+  = tys
+
+arrReprTys :: Repr -> VM [[Type]]
+arrReprTys = mapM (mapM mkPArrayType) . arrReprElemTys
 
 arrReprVars :: Repr -> VM [[Var]]
 arrReprVars repr
@@ -658,11 +661,7 @@ buildTyConBindings orig_tc vect_tc prepr_tc arr_tc dfun
   = do
       shape <- tyConShape vect_tc
       repr  <- mkRepr vect_tc
-      sequence_ (zipWith4 (vectDataConWorker shape vect_tc arr_tc arr_dc)
-                          orig_dcs
-                          vect_dcs
-                          (inits repr_tys)
-                          (tails repr_tys))
+      vectDataConWorkers repr orig_tc vect_tc arr_tc
       dict <- buildPADict repr vect_tc prepr_tc arr_tc dfun
       binds <- takeHoisted
       return $ (dfun, dict) : binds
@@ -673,6 +672,75 @@ buildTyConBindings orig_tc vect_tc prepr_tc arr_tc dfun
 
     repr_tys = map dataConRepArgTys vect_dcs
 
+vectDataConWorkers :: Repr -> TyCon -> TyCon -> TyCon
+                   -> VM ()
+vectDataConWorkers repr orig_tc vect_tc arr_tc
+  = do
+      bs <- sequence
+          . zipWith3 def_worker  (tyConDataCons orig_tc) rep_tys
+          $ zipWith4 mk_data_con (tyConDataCons vect_tc)
+                                 rep_tys
+                                 (inits arr_tys)
+                                 (tail $ tails arr_tys)
+      mapM_ (uncurry hoistBinding) bs
+  where
+    tyvars   = tyConTyVars vect_tc
+    var_tys  = mkTyVarTys tyvars
+    ty_args  = map Type var_tys
+
+    res_ty   = mkTyConApp vect_tc var_tys
+
+    rep_tys  = map dataConRepArgTys $ tyConDataCons vect_tc
+    arr_tys  = arrReprElemTys repr
+
+    [arr_dc] = tyConDataCons arr_tc
+
+    mk_data_con con tys pre post
+      = liftM2 (,) (vect_data_con con)
+                   (lift_data_con tys (concat pre)
+                                      (concat post)
+                                      (mkDataConTag con))
+                   
+
+    vect_data_con con = return $ mkConApp con ty_args
+    lift_data_con tys pre_tys post_tys tag
+      = do
+          len  <- builtin liftingContext
+          args <- mapM (newLocalVar FSLIT("xs"))
+                  =<< mapM mkPArrayType tys
+      
+          shape <- replicateShape repr (Var len) tag
+          repr  <- mk_arr_repr (Var len) (map Var args)
+      
+          pre   <- mapM emptyPA pre_tys
+          post  <- mapM emptyPA post_tys
+
+          return . mkLams (len : args)
+                 . wrapFamInstBody arr_tc var_tys
+                 . mkConApp arr_dc
+                 $ ty_args ++ shape ++ pre ++ repr ++ post
+
+    mk_arr_repr len []
+      = do
+          units <- replicatePA len (Var unitDataConId)
+          return [units]
+
+    mk_arr_repr len arrs = return arrs
+
+    def_worker data_con arg_tys mk_body
+      = do
+          body <- closedV
+                . inBind orig_worker
+                . polyAbstract tyvars $ \abstract ->
+                  liftM (abstract . vectorised)
+                $ buildClosures tyvars [] arg_tys res_ty mk_body
+
+          vect_worker <- cloneId mkVectOcc orig_worker (exprType body)
+          defGlobalVar orig_worker vect_worker
+          return (vect_worker, body)
+      where
+        orig_worker = dataConWorkId data_con
+
 vectDataConWorker :: Shape -> TyCon -> TyCon -> DataCon
                   -> DataCon -> DataCon -> [[Type]] -> [[Type]]
                   -> VM ()