Number data constructors from 0 when vectorising
[ghc-hetmet.git] / compiler / vectorise / VectType.hs
index 6e10dee..dc0b351 100644 (file)
@@ -32,6 +32,7 @@ import TysPrim           ( intPrimTy )
 import Unique
 import UniqFM
 import UniqSet
+import Util              ( singleton )
 import Digraph           ( SCC(..), stronglyConnComp )
 
 import Outputable
@@ -286,6 +287,10 @@ sumRepr reprs
   where
     arity = length reprs
 
+splitSumRepr :: Repr -> [Repr]
+splitSumRepr (SumRepr { sum_components = reprs }) = reprs
+splitSumRepr repr                                 = [repr]
+
 boxRepr :: Repr -> VM Repr
 boxRepr (VoidRepr {}) = boxedProductRepr []
 boxRepr (IdRepr ty)   = boxedProductRepr [ty]
@@ -324,33 +329,38 @@ replicateShape (SumRepr {})  len tag
 replicateShape (IdRepr _) _ _ = return []
 replicateShape (VoidRepr {}) len _ = return [len]
 
-arrReprElemTys :: Repr -> VM [[Type]]
-arrReprElemTys (SumRepr { sum_components = prods })
-  = mapM arrProdElemTys prods
-arrReprElemTys prod@(ProdRepr {})
-  = do
-      tys <- arrProdElemTys prod
-      return [tys]
-arrReprElemTys (IdRepr ty) = return [[ty]]
-arrReprElemTys (VoidRepr { void_tycon = tycon })
-  = return [[mkTyConApp tycon []]]
-
-arrProdElemTys (ProdRepr { prod_components = [] })
-  = do
-      void <- builtin voidTyCon
-      return [mkTyConApp void []]
-arrProdElemTys (ProdRepr { prod_components = tys })
-  = return tys
-arrProdElemTys (IdRepr ty) = return [ty]
-arrProdElemTys (VoidRepr { void_tycon = tycon })
-  = return [mkTyConApp tycon []]
-
-arrReprTys :: Repr -> VM [[Type]]
-arrReprTys repr = mapM (mapM mkPArrayType) =<< arrReprElemTys repr
+emptyArrRepr :: Repr -> VM [CoreExpr]
+emptyArrRepr (SumRepr { sum_components = prods })
+  = liftM concat $ mapM emptyArrRepr prods
+emptyArrRepr (ProdRepr { prod_components = [] })
+  = return [Var unitDataConId]
+emptyArrRepr (ProdRepr { prod_components = tys })
+  = mapM emptyPA tys
+emptyArrRepr (IdRepr ty)
+  = liftM singleton $ emptyPA ty
+emptyArrRepr (VoidRepr { void_tycon = tycon })
+  = liftM singleton $ emptyPA (mkTyConApp tycon [])
+
+arrReprTys :: Repr -> VM [Type]
+arrReprTys (SumRepr { sum_components = reprs })
+  = liftM concat $ mapM arrReprTys reprs
+arrReprTys (ProdRepr { prod_components = [] })
+  = return [unitTy]
+arrReprTys (ProdRepr { prod_components = tys })
+  = mapM mkPArrayType tys
+arrReprTys (IdRepr ty)
+  = liftM singleton $ mkPArrayType ty
+arrReprTys (VoidRepr { void_tycon = tycon })
+  = liftM singleton $ mkPArrayType (mkTyConApp tycon [])
+
+arrReprTys' :: Repr -> VM [[Type]]
+arrReprTys' (SumRepr { sum_components = reprs })
+  = mapM arrReprTys reprs
+arrReprTys' repr = liftM singleton $ arrReprTys repr
 
 arrReprVars :: Repr -> VM [[Var]]
 arrReprVars repr
-  = mapM (mapM (newLocalVar FSLIT("rs"))) =<< arrReprTys repr
+  = mapM (mapM (newLocalVar FSLIT("rs"))) =<< arrReprTys' repr
 
 mkRepr :: TyCon -> VM Repr
 mkRepr vect_tc
@@ -382,7 +392,7 @@ buildToPRepr repr vect_tc prepr_tc _
                      , sum_tycon      = tycon })
             expr
       = do
-          (vars, bodies) <- mapAndUnzipM prod_alt prods
+          (vars, bodies) <- mapAndUnzipM to_unboxed prods
           return . Case expr (mkWildId (exprType expr)) res_ty
                  $ zipWith4 mk_alt cons vars (tyConDataCons tycon) bodies
       where
@@ -393,22 +403,22 @@ buildToPRepr repr vect_tc prepr_tc _
 
     to_repr prod expr
       = do
-          (vars, body) <- prod_alt prod
+          (vars, body) <- to_unboxed prod
           return $ Case expr (mkWildId (exprType expr)) res_ty
                    [(DataAlt con, vars, body)]
 
-    prod_alt (ProdRepr { prod_components = tys
-                       , prod_data_con   = data_con })
+    to_unboxed (ProdRepr { prod_components = tys
+                         , prod_data_con   = data_con })
       = do
           vars <- mapM (newLocalVar FSLIT("r")) tys
           return (vars, mkConApp data_con (map Type tys ++ map Var vars))
 
-    prod_alt (IdRepr ty)
+    to_unboxed (IdRepr ty)
       = do
           var <- newLocalVar FSLIT("y") ty
           return ([var], Var var)
 
-    prod_alt (VoidRepr { void_bottom = bottom })
+    to_unboxed (VoidRepr { void_bottom = bottom })
       = return ([], bottom)
 
 
@@ -433,17 +443,17 @@ buildFromPRepr repr vect_tc prepr_tc _
               expr
       = do
           vars   <- mapM (newLocalVar FSLIT("x")) (map reprType prods)
-          bodies <- sequence . zipWith3 from_prod prods cons
+          bodies <- sequence . zipWith3 from_unboxed prods cons
                              $ map Var vars
           return . Case expr (mkWildId (reprType repr)) res_ty
                  $ zipWith3 sum_alt (tyConDataCons tycon) vars bodies
       where
         sum_alt data_con var body = (DataAlt data_con, [var], body)
 
-    from_repr repr expr = from_prod repr con expr
+    from_repr repr expr = from_unboxed repr con expr
 
-    from_prod prod@(ProdRepr { prod_components = tys
-                             , prod_data_con   = data_con })
+    from_unboxed prod@(ProdRepr { prod_components = tys
+                                , prod_data_con   = data_con })
               con
               expr
       = do
@@ -451,10 +461,10 @@ buildFromPRepr repr vect_tc prepr_tc _
           return $ Case expr (mkWildId (reprType prod)) res_ty
                    [(DataAlt data_con, vars, con `mkVarApps` vars)]
 
-    from_prod (IdRepr _) con expr
+    from_unboxed (IdRepr _) con expr
        = return $ con `App` expr
 
-    from_prod (VoidRepr {}) con expr
+    from_unboxed (VoidRepr {}) con expr
        = return con
 
 buildToArrPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
@@ -692,7 +702,7 @@ buildPArrayDataCon orig_name vect_tc repr_tc
       shape_tys <- arrShapeTys repr
       repr_tys  <- arrReprTys  repr
 
-      let tys = shape_tys ++ concat repr_tys
+      let tys = shape_tys ++ repr_tys
 
       liftDs $ buildDataCon dc_name
                             False                  -- not infix
@@ -729,13 +739,12 @@ vectDataConWorkers :: Repr -> TyCon -> TyCon -> TyCon
                    -> VM ()
 vectDataConWorkers repr orig_tc vect_tc arr_tc
   = do
-      arr_tys <- arrReprElemTys repr
       bs <- sequence
           . zipWith3 def_worker  (tyConDataCons orig_tc) rep_tys
           $ zipWith4 mk_data_con (tyConDataCons vect_tc)
                                  rep_tys
-                                 (inits arr_tys)
-                                 (tail $ tails arr_tys)
+                                 (inits reprs)
+                                 (tail $ tails reprs)
       mapM_ (uncurry hoistBinding) bs
   where
     tyvars   = tyConTyVars vect_tc
@@ -745,17 +754,16 @@ vectDataConWorkers repr orig_tc vect_tc arr_tc
     res_ty   = mkTyConApp vect_tc var_tys
 
     rep_tys  = map dataConRepArgTys $ tyConDataCons vect_tc
+    reprs    = splitSumRepr repr
 
     [arr_dc] = tyConDataCons arr_tc
 
     mk_data_con con tys pre post
       = liftM2 (,) (vect_data_con con)
-                   (lift_data_con tys (concat pre)
-                                      (concat post)
-                                      (mkDataConTag con))
+                   (lift_data_con tys pre post (mkDataConTag con))
 
     vect_data_con con = return $ mkConApp con ty_args
-    lift_data_con tys pre_tys post_tys tag
+    lift_data_con tys pre_reprs post_reprs tag
       = do
           len  <- builtin liftingContext
           args <- mapM (newLocalVar FSLIT("xs"))
@@ -764,8 +772,8 @@ vectDataConWorkers repr orig_tc vect_tc arr_tc
           shape <- replicateShape repr (Var len) tag
           repr  <- mk_arr_repr (Var len) (map Var args)
 
-          pre   <- mapM emptyPA pre_tys
-          post  <- mapM emptyPA post_tys
+          pre   <- liftM concat $ mapM emptyArrRepr pre_reprs
+          post  <- liftM concat $ mapM emptyArrRepr post_reprs
 
           return . mkLams (len : args)
                  . wrapFamInstBody arr_tc var_tys