Fix vectorisation of sum type constructors
[ghc-hetmet.git] / compiler / vectorise / VectType.hs
index 0fe93eb..ba64d3b 100644 (file)
@@ -288,26 +288,34 @@ reprVars = mapM (mapM (newLocalVar FSLIT("r"))) . reprTys
 arrShapeTys :: Repr -> VM [Type]
 arrShapeTys (SumRepr  {})
   = do
-      uarr <- builtin uarrTyCon
-      return [intPrimTy, mkTyConApp uarr [intTy]]
+      int_arr <- builtin parrayIntPrimTyCon
+      return [intPrimTy, mkTyConApp int_arr [], mkTyConApp int_arr []]
 arrShapeTys repr = return [intPrimTy]
 
 arrShapeVars :: Repr -> VM [Var]
 arrShapeVars repr = mapM (newLocalVar FSLIT("sh")) =<< arrShapeTys repr
 
-arrReprTys :: Repr -> VM [[Type]]
-arrReprTys (SumRepr { sum_components = prods })
-  = mapM arrProdTys prods
-arrReprTys prod
+replicateShape :: Repr -> CoreExpr -> CoreExpr -> VM [CoreExpr]
+replicateShape (ProdRepr {}) len _ = return [len]
+replicateShape (SumRepr {})  len tag
   = do
-      tys <- arrProdTys prod
-      return [tys]
+      rep <- builtin replicatePAIntPrimVar
+      up  <- builtin upToPAIntPrimVar
+      return [len, Var rep `mkApps` [len, tag], Var up `App` len]
 
-arrProdTys (ProdRepr { prod_components = tys })
-  = mapM mkPArrayType (mk_types tys)
-  where
-    mk_types []  = [unitTy]
-    mk_types tys = tys
+arrReprElemTys :: Repr -> [[Type]]
+arrReprElemTys (SumRepr { sum_components = prods })
+  = map arrProdElemTys prods
+arrReprElemTys prod@(ProdRepr {})
+  = [arrProdElemTys prod]
+
+arrProdElemTys (ProdRepr { prod_components = [] })
+  = [unitTy]
+arrProdElemTys (ProdRepr { prod_components = tys })
+  = tys
+
+arrReprTys :: Repr -> VM [[Type]]
+arrReprTys = mapM (mapM mkPArrayType) . arrReprElemTys
 
 arrReprVars :: Repr -> VM [[Var]]
 arrReprVars repr
@@ -623,46 +631,12 @@ mkPADFun :: TyCon -> VM Var
 mkPADFun vect_tc
   = newExportedVar (mkPADFunOcc $ getOccName vect_tc) =<< paDFunType vect_tc
 
-data Shape = Shape {
-               shapeReprTys    :: [Type]
-             , shapeStrictness :: [StrictnessMark]
-             , shapeLength     :: [CoreExpr] -> VM CoreExpr
-             , shapeReplicate  :: CoreExpr -> CoreExpr -> VM [CoreExpr]
-             }
-
-tyConShape :: TyCon -> VM Shape
-tyConShape vect_tc
-  | isProductTyCon vect_tc
-  = return $ Shape {
-                shapeReprTys    = [intPrimTy]
-              , shapeStrictness = [NotMarkedStrict]
-              , shapeLength     = \[len] -> return len
-              , shapeReplicate  = \len _ -> return [len]
-              }
-
-  | otherwise
-  = do
-      repr_ty <- mkPArrayType intTy   -- FIXME: we want to unbox this
-      return $ Shape {
-                 shapeReprTys    = [repr_ty]
-               , shapeStrictness = [MarkedStrict]
-               , shapeLength     = \[sel] -> lengthPA sel
-               , shapeReplicate  = \len n -> do
-                                               e <- replicatePA len n
-                                               return [e]
-               }
-
 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> Var
                    -> VM [(Var, CoreExpr)]
 buildTyConBindings orig_tc vect_tc prepr_tc arr_tc dfun
   = do
-      shape <- tyConShape vect_tc
       repr  <- mkRepr vect_tc
-      sequence_ (zipWith4 (vectDataConWorker shape vect_tc arr_tc arr_dc)
-                          orig_dcs
-                          vect_dcs
-                          (inits repr_tys)
-                          (tails repr_tys))
+      vectDataConWorkers repr orig_tc vect_tc arr_tc
       dict <- buildPADict repr vect_tc prepr_tc arr_tc dfun
       binds <- takeHoisted
       return $ (dfun, dict) : binds
@@ -673,47 +647,73 @@ buildTyConBindings orig_tc vect_tc prepr_tc arr_tc dfun
 
     repr_tys = map dataConRepArgTys vect_dcs
 
-vectDataConWorker :: Shape -> TyCon -> TyCon -> DataCon
-                  -> DataCon -> DataCon -> [[Type]] -> [[Type]]
-                  -> VM ()
-vectDataConWorker shape vect_tc arr_tc arr_dc orig_dc vect_dc pre (dc_tys : post)
+vectDataConWorkers :: Repr -> TyCon -> TyCon -> TyCon
+                   -> VM ()
+vectDataConWorkers repr orig_tc vect_tc arr_tc
   = do
-      clo <- closedV
-           . inBind orig_worker
-           . polyAbstract tvs $ \abstract ->
-             liftM (abstract . vectorised)
-           $ buildClosures tvs [] dc_tys res_ty (liftM2 (,) mk_vect mk_lift)
-
-      worker <- cloneId mkVectOcc orig_worker (exprType clo)
-      hoistBinding worker clo
-      defGlobalVar orig_worker worker
-      return ()
+      bs <- sequence
+          . zipWith3 def_worker  (tyConDataCons orig_tc) rep_tys
+          $ zipWith4 mk_data_con (tyConDataCons vect_tc)
+                                 rep_tys
+                                 (inits arr_tys)
+                                 (tail $ tails arr_tys)
+      mapM_ (uncurry hoistBinding) bs
   where
-    tvs     = tyConTyVars vect_tc
-    arg_tys = mkTyVarTys tvs
-    res_ty  = mkTyConApp vect_tc arg_tys
-
-    orig_worker = dataConWorkId orig_dc
-
-    mk_vect = return . mkConApp vect_dc $ map Type arg_tys
-    mk_lift = do
-                len     <- newLocalVar FSLIT("n") intPrimTy
-                arr_tys <- mapM mkPArrayType dc_tys
-                args    <- mapM (newLocalVar FSLIT("xs")) arr_tys
-                shapes  <- shapeReplicate shape
-                                          (Var len)
-                                          (mkDataConTag vect_dc)
-
-                empty_pre  <- mapM emptyPA (concat pre)
-                empty_post <- mapM emptyPA (concat post)
-
-                return . mkLams (len : args)
-                       . wrapFamInstBody arr_tc arg_tys
-                       . mkConApp arr_dc
-                       $ map Type arg_tys ++ shapes
-                                          ++ empty_pre
-                                          ++ map Var args
-                                          ++ empty_post
+    tyvars   = tyConTyVars vect_tc
+    var_tys  = mkTyVarTys tyvars
+    ty_args  = map Type var_tys
+
+    res_ty   = mkTyConApp vect_tc var_tys
+
+    rep_tys  = map dataConRepArgTys $ tyConDataCons vect_tc
+    arr_tys  = arrReprElemTys repr
+
+    [arr_dc] = tyConDataCons arr_tc
+
+    mk_data_con con tys pre post
+      = liftM2 (,) (vect_data_con con)
+                   (lift_data_con tys (concat pre)
+                                      (concat post)
+                                      (mkDataConTag con))
+
+    vect_data_con con = return $ mkConApp con ty_args
+    lift_data_con tys pre_tys post_tys tag
+      = do
+          len  <- builtin liftingContext
+          args <- mapM (newLocalVar FSLIT("xs"))
+                  =<< mapM mkPArrayType tys
+
+          shape <- replicateShape repr (Var len) tag
+          repr  <- mk_arr_repr (Var len) (map Var args)
+
+          pre   <- mapM emptyPA pre_tys
+          post  <- mapM emptyPA post_tys
+
+          return . mkLams (len : args)
+                 . wrapFamInstBody arr_tc var_tys
+                 . mkConApp arr_dc
+                 $ ty_args ++ shape ++ pre ++ repr ++ post
+
+    mk_arr_repr len []
+      = do
+          units <- replicatePA len (Var unitDataConId)
+          return [units]
+
+    mk_arr_repr len arrs = return arrs
+
+    def_worker data_con arg_tys mk_body
+      = do
+          body <- closedV
+                . inBind orig_worker
+                . polyAbstract tyvars $ \abstract ->
+                  liftM (abstract . vectorised)
+                $ buildClosures tyvars [] arg_tys res_ty mk_body
+
+          vect_worker <- cloneId mkVectOcc orig_worker (exprType body)
+          defGlobalVar orig_worker vect_worker
+          return (vect_worker, body)
+      where
+        orig_worker = dataConWorkId data_con
 
 buildPADict :: Repr -> TyCon -> TyCon -> TyCon -> Var -> VM CoreExpr
 buildPADict repr vect_tc prepr_tc arr_tc dfun