Fix vectorisation of nullary data constructors
[ghc-hetmet.git] / compiler / vectorise / VectType.hs
index b238199..ca5f0c8 100644 (file)
@@ -224,6 +224,22 @@ data Repr = ProdRepr {
             , sum_arr_data_con  :: DataCon
             }
 
+          | IdRepr Type
+
+          | VoidRepr {
+              void_tycon        :: TyCon
+            , void_bottom       :: CoreExpr
+            }
+
+mkVoid :: VM Repr
+mkVoid = do
+           tycon <- builtin voidTyCon
+           var   <- builtin voidVar
+           return $ VoidRepr {
+                      void_tycon  = tycon
+                    , void_bottom = Var var
+                    }
+
 mkProduct :: [Type] -> VM Repr
 mkProduct tys
   = do
@@ -243,6 +259,11 @@ mkProduct tys
   where
     arity = length tys
 
+mkSubProduct :: [Type] -> VM Repr
+mkSubProduct []   = mkVoid
+mkSubProduct [ty] = return $ IdRepr ty
+mkSubProduct tys  = mkProduct tys
+
 mkSum :: [Repr] -> VM Repr
 mkSum [repr] = return repr
 mkSum reprs
@@ -263,54 +284,62 @@ mkSum reprs
   where
     arity = length reprs
 
-reprProducts :: Repr -> [Repr]
-reprProducts (SumRepr { sum_components = rs }) = rs
-reprProducts repr                              = [repr]
-
 reprType :: Repr -> Type
 reprType (ProdRepr { prod_tycon = tycon, prod_components = tys })
   = mkTyConApp tycon tys
 reprType (SumRepr { sum_tycon = tycon, sum_components = reprs })
   = mkTyConApp tycon (map reprType reprs)
+reprType (IdRepr ty) = ty
+reprType (VoidRepr { void_tycon = tycon }) = mkTyConApp tycon []
 
 arrReprType :: Repr -> VM Type
 arrReprType = mkPArrayType . reprType
 
-reprTys :: Repr -> [[Type]]
-reprTys (SumRepr { sum_components = prods }) = map prodTys prods
-reprTys prod                                 = [prodTys prod]
-
-prodTys (ProdRepr { prod_components = tys }) = tys
-
-reprVars :: Repr -> VM [[Var]]
-reprVars = mapM (mapM (newLocalVar FSLIT("r"))) . reprTys
-
 arrShapeTys :: Repr -> VM [Type]
 arrShapeTys (SumRepr  {})
   = do
-      uarr <- builtin uarrTyCon
-      return [intPrimTy, mkTyConApp uarr [intTy]]
-arrShapeTys repr = return [intPrimTy]
+      int_arr <- builtin parrayIntPrimTyCon
+      return [intPrimTy, mkTyConApp int_arr [], mkTyConApp int_arr []]
+arrShapeTys (ProdRepr {}) = return [intPrimTy]
+arrShapeTys (IdRepr _)    = return []
+arrShapeTys (VoidRepr {}) = return [intPrimTy]
 
 arrShapeVars :: Repr -> VM [Var]
 arrShapeVars repr = mapM (newLocalVar FSLIT("sh")) =<< arrShapeTys repr
 
 replicateShape :: Repr -> CoreExpr -> CoreExpr -> VM [CoreExpr]
 replicateShape (ProdRepr {}) len _ = return [len]
+replicateShape (SumRepr {})  len tag
+  = do
+      rep <- builtin replicatePAIntPrimVar
+      up  <- builtin upToPAIntPrimVar
+      return [len, Var rep `mkApps` [len, tag], Var up `App` len]
+replicateShape (IdRepr _) _ _ = return []
+replicateShape (VoidRepr {}) len _ = return [len]
 
-arrReprElemTys :: Repr -> [[Type]]
+arrReprElemTys :: Repr -> VM [[Type]]
 arrReprElemTys (SumRepr { sum_components = prods })
-  = map arrProdElemTys prods
+  = mapM arrProdElemTys prods
 arrReprElemTys prod@(ProdRepr {})
-  = [arrProdElemTys prod]
+  = do
+      tys <- arrProdElemTys prod
+      return [tys]
+arrReprElemTys (IdRepr ty) = return [[ty]]
+arrReprElemTys (VoidRepr { void_tycon = tycon })
+  = return [[mkTyConApp tycon []]]
 
 arrProdElemTys (ProdRepr { prod_components = [] })
-  = [unitTy]
+  = do
+      void <- builtin voidTyCon
+      return [mkTyConApp void []]
 arrProdElemTys (ProdRepr { prod_components = tys })
-  = tys
+  = return tys
+arrProdElemTys (IdRepr ty) = return [ty]
+arrProdElemTys (VoidRepr { void_tycon = tycon })
+  = return [mkTyConApp tycon []]
 
 arrReprTys :: Repr -> VM [[Type]]
-arrReprTys = mapM (mapM mkPArrayType) . arrReprElemTys
+arrReprTys repr = mapM (mapM mkPArrayType) =<< arrReprElemTys repr
 
 arrReprVars :: Repr -> VM [[Var]]
 arrReprVars repr
@@ -318,8 +347,10 @@ arrReprVars repr
 
 mkRepr :: TyCon -> VM Repr
 mkRepr vect_tc
-  = mkSum
-  =<< mapM mkProduct (map dataConRepArgTys $ tyConDataCons vect_tc)
+  | [tys] <- rep_tys = mkProduct tys
+  | otherwise        = mkSum =<< mapM mkSubProduct rep_tys
+  where
+    rep_tys = map dataConRepArgTys $ tyConDataCons vect_tc
 
 buildPReprType :: TyCon -> VM Type
 buildPReprType = liftM reprType . mkRepr
@@ -366,6 +397,15 @@ buildToPRepr repr vect_tc prepr_tc _
           vars <- mapM (newLocalVar FSLIT("r")) tys
           return (vars, mkConApp data_con (map Type tys ++ map Var vars))
 
+    prod_alt (IdRepr ty)
+      = do
+          var <- newLocalVar FSLIT("y") ty
+          return ([var], Var var)
+
+    prod_alt (VoidRepr { void_bottom = bottom })
+      = return ([], bottom)
+
+
 buildFromPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
 buildFromPRepr repr vect_tc prepr_tc _
   = do
@@ -405,6 +445,12 @@ buildFromPRepr repr vect_tc prepr_tc _
           return $ Case expr (mkWildId (reprType prod)) res_ty
                    [(DataAlt data_con, vars, con `mkVarApps` vars)]
 
+    from_prod (IdRepr _) con expr
+       = return $ con `App` expr
+
+    from_prod (VoidRepr {}) con expr
+       = return con
+
 buildToArrPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
 buildToArrPRepr repr vect_tc prepr_tc arr_tc
   = do
@@ -443,7 +489,7 @@ buildToArrPRepr repr vect_tc prepr_tc arr_tc
                      , sum_arr_tycon    = tycon
                      , sum_arr_data_con = data_con })
       = do
-          exprs <- zipWithM (to_prod len_var) repr_vars prods
+          exprs <- zipWithM to_prod repr_vars prods
 
           return . wrapFamInstBody tycon tys
                  . mkConApp data_con
@@ -451,16 +497,28 @@ buildToArrPRepr repr vect_tc prepr_tc arr_tc
       where
         tys = map reprType prods
 
-    to_repr [len_var] [repr_vars] prod = to_prod len_var repr_vars prod
+    to_repr [len_var]
+            [repr_vars]
+            (ProdRepr { prod_components   = tys
+                      , prod_arr_tycon    = tycon
+                      , prod_arr_data_con = data_con })
+       = return . wrapFamInstBody tycon tys
+                . mkConApp data_con
+                $ map Type tys ++ map Var (len_var : repr_vars)
 
-    to_prod len_var
-            repr_vars
+    to_prod repr_vars@(r : _)
             (ProdRepr { prod_components   = tys
                       , prod_arr_tycon    = tycon
                       , prod_arr_data_con = data_con })
-      = return . wrapFamInstBody tycon tys
-               . mkConApp data_con
-               $ map Type tys ++ map Var (len_var : repr_vars)
+      = do
+          len <- lengthPA (Var r)
+          return . wrapFamInstBody tycon tys
+                 . mkConApp data_con
+                 $ map Type tys ++ len : map Var repr_vars
+
+    to_prod [var] (IdRepr ty)   = return (Var var)
+    to_prod [var] (VoidRepr {}) = return (Var var)
+
 
 buildFromArrPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
 buildFromArrPRepr repr vect_tc prepr_tc arr_tc
@@ -539,7 +597,26 @@ buildFromArrPRepr repr vect_tc prepr_tc arr_tc
           return $ Case scrut (mkWildId scrut_ty) res_ty
                    [(DataAlt data_con, shape_vars ++ repr_vars, body)]
 
+    from_prod (IdRepr ty)
+              expr
+              shape_vars
+              [repr_var]
+              res_ty
+              body
+      = return $ Let (NonRec repr_var expr) body
+
+    from_prod (VoidRepr {})
+              expr
+              shape_vars
+              [repr_var]
+              res_ty
+              body
+      = return $ Let (NonRec repr_var expr) body
+
 buildPRDictRepr :: Repr -> VM CoreExpr
+buildPRDictRepr (VoidRepr { void_tycon = tycon })
+  = prDFunOfTyCon tycon
+buildPRDictRepr (IdRepr ty) = mkPR ty
 buildPRDictRepr (ProdRepr {
                    prod_components = tys
                  , prod_tycon      = tycon
@@ -626,40 +703,10 @@ mkPADFun :: TyCon -> VM Var
 mkPADFun vect_tc
   = newExportedVar (mkPADFunOcc $ getOccName vect_tc) =<< paDFunType vect_tc
 
-data Shape = Shape {
-               shapeReprTys    :: [Type]
-             , shapeStrictness :: [StrictnessMark]
-             , shapeLength     :: [CoreExpr] -> VM CoreExpr
-             , shapeReplicate  :: CoreExpr -> CoreExpr -> VM [CoreExpr]
-             }
-
-tyConShape :: TyCon -> VM Shape
-tyConShape vect_tc
-  | isProductTyCon vect_tc
-  = return $ Shape {
-                shapeReprTys    = [intPrimTy]
-              , shapeStrictness = [NotMarkedStrict]
-              , shapeLength     = \[len] -> return len
-              , shapeReplicate  = \len _ -> return [len]
-              }
-
-  | otherwise
-  = do
-      repr_ty <- mkPArrayType intTy   -- FIXME: we want to unbox this
-      return $ Shape {
-                 shapeReprTys    = [repr_ty]
-               , shapeStrictness = [MarkedStrict]
-               , shapeLength     = \[sel] -> lengthPA sel
-               , shapeReplicate  = \len n -> do
-                                               e <- replicatePA len n
-                                               return [e]
-               }
-
 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> Var
                    -> VM [(Var, CoreExpr)]
 buildTyConBindings orig_tc vect_tc prepr_tc arr_tc dfun
   = do
-      shape <- tyConShape vect_tc
       repr  <- mkRepr vect_tc
       vectDataConWorkers repr orig_tc vect_tc arr_tc
       dict <- buildPADict repr vect_tc prepr_tc arr_tc dfun
@@ -676,6 +723,7 @@ vectDataConWorkers :: Repr -> TyCon -> TyCon -> TyCon
                    -> VM ()
 vectDataConWorkers repr orig_tc vect_tc arr_tc
   = do
+      arr_tys <- arrReprElemTys repr
       bs <- sequence
           . zipWith3 def_worker  (tyConDataCons orig_tc) rep_tys
           $ zipWith4 mk_data_con (tyConDataCons vect_tc)
@@ -691,7 +739,6 @@ vectDataConWorkers repr orig_tc vect_tc arr_tc
     res_ty   = mkTyConApp vect_tc var_tys
 
     rep_tys  = map dataConRepArgTys $ tyConDataCons vect_tc
-    arr_tys  = arrReprElemTys repr
 
     [arr_dc] = tyConDataCons arr_tc
 
@@ -700,7 +747,6 @@ vectDataConWorkers repr orig_tc vect_tc arr_tc
                    (lift_data_con tys (concat pre)
                                       (concat post)
                                       (mkDataConTag con))
-                   
 
     vect_data_con con = return $ mkConApp con ty_args
     lift_data_con tys pre_tys post_tys tag
@@ -708,10 +754,10 @@ vectDataConWorkers repr orig_tc vect_tc arr_tc
           len  <- builtin liftingContext
           args <- mapM (newLocalVar FSLIT("xs"))
                   =<< mapM mkPArrayType tys
-      
+
           shape <- replicateShape repr (Var len) tag
           repr  <- mk_arr_repr (Var len) (map Var args)
-      
+
           pre   <- mapM emptyPA pre_tys
           post  <- mapM emptyPA post_tys
 
@@ -741,48 +787,6 @@ vectDataConWorkers repr orig_tc vect_tc arr_tc
       where
         orig_worker = dataConWorkId data_con
 
-vectDataConWorker :: Shape -> TyCon -> TyCon -> DataCon
-                  -> DataCon -> DataCon -> [[Type]] -> [[Type]]
-                  -> VM ()
-vectDataConWorker shape vect_tc arr_tc arr_dc orig_dc vect_dc pre (dc_tys : post)
-  = do
-      clo <- closedV
-           . inBind orig_worker
-           . polyAbstract tvs $ \abstract ->
-             liftM (abstract . vectorised)
-           $ buildClosures tvs [] dc_tys res_ty (liftM2 (,) mk_vect mk_lift)
-
-      worker <- cloneId mkVectOcc orig_worker (exprType clo)
-      hoistBinding worker clo
-      defGlobalVar orig_worker worker
-      return ()
-  where
-    tvs     = tyConTyVars vect_tc
-    arg_tys = mkTyVarTys tvs
-    res_ty  = mkTyConApp vect_tc arg_tys
-
-    orig_worker = dataConWorkId orig_dc
-
-    mk_vect = return . mkConApp vect_dc $ map Type arg_tys
-    mk_lift = do
-                len     <- newLocalVar FSLIT("n") intPrimTy
-                arr_tys <- mapM mkPArrayType dc_tys
-                args    <- mapM (newLocalVar FSLIT("xs")) arr_tys
-                shapes  <- shapeReplicate shape
-                                          (Var len)
-                                          (mkDataConTag vect_dc)
-
-                empty_pre  <- mapM emptyPA (concat pre)
-                empty_post <- mapM emptyPA (concat post)
-
-                return . mkLams (len : args)
-                       . wrapFamInstBody arr_tc arg_tys
-                       . mkConApp arr_dc
-                       $ map Type arg_tys ++ shapes
-                                          ++ empty_pre
-                                          ++ map Var args
-                                          ++ empty_post
-
 buildPADict :: Repr -> TyCon -> TyCon -> TyCon -> Var -> VM CoreExpr
 buildPADict repr vect_tc prepr_tc arr_tc dfun
   = polyAbstract tvs $ \abstract ->