Refactoring
[ghc-hetmet.git] / compiler / vectorise / VectType.hs
index 4ff1711..bef08f7 100644 (file)
@@ -32,6 +32,7 @@ import TysPrim           ( intPrimTy )
 import Unique
 import UniqFM
 import UniqSet
+import Util              ( singleton )
 import Digraph           ( SCC(..), stronglyConnComp )
 
 import Outputable
@@ -226,8 +227,28 @@ data Repr = ProdRepr {
 
           | IdRepr Type
 
-mkProduct :: [Type] -> VM Repr
-mkProduct tys
+          | VoidRepr {
+              void_tycon        :: TyCon
+            , void_bottom       :: CoreExpr
+            }
+
+voidRepr :: VM Repr
+voidRepr
+  = do
+      tycon <- builtin voidTyCon
+      var   <- builtin voidVar
+      return $ VoidRepr {
+                 void_tycon  = tycon
+               , void_bottom = Var var
+               }
+
+unboxedProductRepr :: [Type] -> VM Repr
+unboxedProductRepr []   = voidRepr
+unboxedProductRepr [ty] = return $ IdRepr ty
+unboxedProductRepr tys  = boxedProductRepr tys
+
+boxedProductRepr :: [Type] -> VM Repr
+boxedProductRepr tys
   = do
       tycon <- builtin (prodTyCon arity)
       let [data_con] = tyConDataCons tycon
@@ -245,13 +266,10 @@ mkProduct tys
   where
     arity = length tys
 
-mkSubProduct :: [Type] -> VM Repr
-mkSubProduct [ty] = return $ IdRepr ty
-mkSubProduct tys  = mkProduct tys
-
-mkSum :: [Repr] -> VM Repr
-mkSum [repr] = return repr
-mkSum reprs
+sumRepr :: [Repr] -> VM Repr
+sumRepr []     = voidRepr
+sumRepr [repr] = boxRepr repr
+sumRepr reprs
   = do
       tycon <- builtin (sumTyCon arity)
       (arr_tycon, _) <- parrayReprTyCon
@@ -269,12 +287,22 @@ mkSum reprs
   where
     arity = length reprs
 
+splitSumRepr :: Repr -> [Repr]
+splitSumRepr (SumRepr { sum_components = reprs }) = reprs
+splitSumRepr repr                                 = [repr]
+
+boxRepr :: Repr -> VM Repr
+boxRepr (VoidRepr {}) = boxedProductRepr []
+boxRepr (IdRepr ty)   = boxedProductRepr [ty]
+boxRepr repr          = return repr
+
 reprType :: Repr -> Type
 reprType (ProdRepr { prod_tycon = tycon, prod_components = tys })
   = mkTyConApp tycon tys
 reprType (SumRepr { sum_tycon = tycon, sum_components = reprs })
   = mkTyConApp tycon (map reprType reprs)
 reprType (IdRepr ty) = ty
+reprType (VoidRepr { void_tycon = tycon }) = mkTyConApp tycon []
 
 arrReprType :: Repr -> VM Type
 arrReprType = mkPArrayType . reprType
@@ -286,6 +314,7 @@ arrShapeTys (SumRepr  {})
       return [intPrimTy, mkTyConApp int_arr [], mkTyConApp int_arr []]
 arrShapeTys (ProdRepr {}) = return [intPrimTy]
 arrShapeTys (IdRepr _)    = return []
+arrShapeTys (VoidRepr {}) = return [intPrimTy]
 
 arrShapeVars :: Repr -> VM [Var]
 arrShapeVars repr = mapM (newLocalVar FSLIT("sh")) =<< arrShapeTys repr
@@ -298,31 +327,44 @@ replicateShape (SumRepr {})  len tag
       up  <- builtin upToPAIntPrimVar
       return [len, Var rep `mkApps` [len, tag], Var up `App` len]
 replicateShape (IdRepr _) _ _ = return []
-
-arrReprElemTys :: Repr -> [[Type]]
-arrReprElemTys (SumRepr { sum_components = prods })
-  = map arrProdElemTys prods
-arrReprElemTys prod@(ProdRepr {})
-  = [arrProdElemTys prod]
-arrReprElemTys (IdRepr ty) = [[ty]]
-
-arrProdElemTys (ProdRepr { prod_components = [] })
-  = [unitTy]
-arrProdElemTys (ProdRepr { prod_components = tys })
-  = tys
-arrProdElemTys (IdRepr ty) = [ty]
-
-arrReprTys :: Repr -> VM [[Type]]
-arrReprTys = mapM (mapM mkPArrayType) . arrReprElemTys
+replicateShape (VoidRepr {}) len _ = return [len]
+
+emptyArrRepr :: Repr -> VM [CoreExpr]
+emptyArrRepr (SumRepr { sum_components = prods })
+  = liftM concat $ mapM emptyArrRepr prods
+emptyArrRepr (ProdRepr { prod_components = [] })
+  = return [Var unitDataConId]
+emptyArrRepr (ProdRepr { prod_components = tys })
+  = mapM emptyPA tys
+emptyArrRepr (IdRepr ty)
+  = liftM singleton $ emptyPA ty
+emptyArrRepr (VoidRepr { void_tycon = tycon })
+  = liftM singleton $ emptyPA (mkTyConApp tycon [])
+
+arrReprTys :: Repr -> VM [Type]
+arrReprTys (SumRepr { sum_components = reprs })
+  = liftM concat $ mapM arrReprTys reprs
+arrReprTys (ProdRepr { prod_components = [] })
+  = return [unitTy]
+arrReprTys (ProdRepr { prod_components = tys })
+  = mapM mkPArrayType tys
+arrReprTys (IdRepr ty)
+  = liftM singleton $ mkPArrayType ty
+arrReprTys (VoidRepr { void_tycon = tycon })
+  = liftM singleton $ mkPArrayType (mkTyConApp tycon [])
+
+arrReprTys' :: Repr -> VM [[Type]]
+arrReprTys' (SumRepr { sum_components = reprs })
+  = mapM arrReprTys reprs
+arrReprTys' repr = liftM singleton $ arrReprTys repr
 
 arrReprVars :: Repr -> VM [[Var]]
 arrReprVars repr
-  = mapM (mapM (newLocalVar FSLIT("rs"))) =<< arrReprTys repr
+  = mapM (mapM (newLocalVar FSLIT("rs"))) =<< arrReprTys' repr
 
 mkRepr :: TyCon -> VM Repr
 mkRepr vect_tc
-  | [tys] <- rep_tys = mkProduct tys
-  | otherwise        = mkSum =<< mapM mkSubProduct rep_tys
+  = sumRepr =<< mapM unboxedProductRepr rep_tys
   where
     rep_tys = map dataConRepArgTys $ tyConDataCons vect_tc
 
@@ -376,6 +418,10 @@ buildToPRepr repr vect_tc prepr_tc _
           var <- newLocalVar FSLIT("y") ty
           return ([var], Var var)
 
+    prod_alt (VoidRepr { void_bottom = bottom })
+      = return ([], bottom)
+
+
 buildFromPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
 buildFromPRepr repr vect_tc prepr_tc _
   = do
@@ -418,6 +464,9 @@ buildFromPRepr repr vect_tc prepr_tc _
     from_prod (IdRepr _) con expr
        = return $ con `App` expr
 
+    from_prod (VoidRepr {}) con expr
+       = return con
+
 buildToArrPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
 buildToArrPRepr repr vect_tc prepr_tc arr_tc
   = do
@@ -483,8 +532,9 @@ buildToArrPRepr repr vect_tc prepr_tc arr_tc
                  . mkConApp data_con
                  $ map Type tys ++ len : map Var repr_vars
 
-    to_prod [var] (IdRepr ty)
-      = return (Var var)
+    to_prod [var] (IdRepr ty)   = return (Var var)
+    to_prod [var] (VoidRepr {}) = return (Var var)
+
 
 buildFromArrPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
 buildFromArrPRepr repr vect_tc prepr_tc arr_tc
@@ -571,7 +621,17 @@ buildFromArrPRepr repr vect_tc prepr_tc arr_tc
               body
       = return $ Let (NonRec repr_var expr) body
 
+    from_prod (VoidRepr {})
+              expr
+              shape_vars
+              [repr_var]
+              res_ty
+              body
+      = return $ Let (NonRec repr_var expr) body
+
 buildPRDictRepr :: Repr -> VM CoreExpr
+buildPRDictRepr (VoidRepr { void_tycon = tycon })
+  = prDFunOfTyCon tycon
 buildPRDictRepr (IdRepr ty) = mkPR ty
 buildPRDictRepr (ProdRepr {
                    prod_components = tys
@@ -642,7 +702,7 @@ buildPArrayDataCon orig_name vect_tc repr_tc
       shape_tys <- arrShapeTys repr
       repr_tys  <- arrReprTys  repr
 
-      let tys = shape_tys ++ concat repr_tys
+      let tys = shape_tys ++ repr_tys
 
       liftDs $ buildDataCon dc_name
                             False                  -- not infix
@@ -683,8 +743,8 @@ vectDataConWorkers repr orig_tc vect_tc arr_tc
           . zipWith3 def_worker  (tyConDataCons orig_tc) rep_tys
           $ zipWith4 mk_data_con (tyConDataCons vect_tc)
                                  rep_tys
-                                 (inits arr_tys)
-                                 (tail $ tails arr_tys)
+                                 (inits reprs)
+                                 (tail $ tails reprs)
       mapM_ (uncurry hoistBinding) bs
   where
     tyvars   = tyConTyVars vect_tc
@@ -694,18 +754,16 @@ vectDataConWorkers repr orig_tc vect_tc arr_tc
     res_ty   = mkTyConApp vect_tc var_tys
 
     rep_tys  = map dataConRepArgTys $ tyConDataCons vect_tc
-    arr_tys  = arrReprElemTys repr
+    reprs    = splitSumRepr repr
 
     [arr_dc] = tyConDataCons arr_tc
 
     mk_data_con con tys pre post
       = liftM2 (,) (vect_data_con con)
-                   (lift_data_con tys (concat pre)
-                                      (concat post)
-                                      (mkDataConTag con))
+                   (lift_data_con tys pre post (mkDataConTag con))
 
     vect_data_con con = return $ mkConApp con ty_args
-    lift_data_con tys pre_tys post_tys tag
+    lift_data_con tys pre_reprs post_reprs tag
       = do
           len  <- builtin liftingContext
           args <- mapM (newLocalVar FSLIT("xs"))
@@ -714,8 +772,8 @@ vectDataConWorkers repr orig_tc vect_tc arr_tc
           shape <- replicateShape repr (Var len) tag
           repr  <- mk_arr_repr (Var len) (map Var args)
 
-          pre   <- mapM emptyPA pre_tys
-          post  <- mapM emptyPA post_tys
+          pre   <- liftM concat $ mapM emptyArrRepr pre_reprs
+          post  <- liftM concat $ mapM emptyArrRepr post_reprs
 
           return . mkLams (len : args)
                  . wrapFamInstBody arr_tc var_tys