Use n-ary sums and products for NDP's generic representation
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
index d1c1dab..0f101bd 100644 (file)
@@ -39,6 +39,7 @@ import BasicTypes         ( Boxity(..) )
 import Outputable
 import FastString
 
+import Data.List             ( zipWith4 )
 import Control.Monad         ( liftM, liftM2, zipWithM_ )
 
 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
@@ -82,11 +83,13 @@ splitBinTy s name ty
 
   | otherwise = pprPanic s (ppr ty)
 
-splitCrossTy :: Type -> (Type, Type)
-splitCrossTy = splitBinTy "splitCrossTy" ndpCrossTyConName
+splitFixedTyConApp :: TyCon -> Type -> [Type]
+splitFixedTyConApp tc ty
+  | Just (tc', tys) <- splitTyConApp_maybe ty
+  , tc == tc'
+  = tys
 
-splitPlusTy :: Type -> (Type, Type)
-splitPlusTy = splitBinTy "splitSumTy" ndpPlusTyConName
+  | otherwise = pprPanic "splitFixedTyConApp" (ppr tc <+> ppr ty)
 
 splitEmbedTy :: Type -> Type
 splitEmbedTy = splitUnTy "splitEmbedTy" embedTyConName
@@ -123,25 +126,24 @@ mkBuiltinTyConApps1 get_tc dft tys
     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
 
 mkPRepr :: [[Type]] -> VM Type
-mkPRepr [] = return unitTy
 mkPRepr tys
   = do
-      embed <- builtin embedTyCon
-      cross <- builtin crossTyCon
-      plus  <- builtin plusTyCon
+      embed_tc <- builtin embedTyCon
+      sum_tcs  <- builtins sumTyCon
+      prod_tcs <- builtins prodTyCon
 
-      let mk_embed ty      = mkTyConApp embed [ty]
-          mk_cross ty1 ty2 = mkTyConApp cross [ty1, ty2]
-          mk_plus  ty1 ty2 = mkTyConApp plus  [ty1, ty2]
+      let mk_sum []   = unitTy
+          mk_sum [ty] = ty
+          mk_sum tys  = mkTyConApp (sum_tcs $ length tys) tys
 
-          mk_tup   []      = unitTy
-          mk_tup   tys     = foldr1 mk_cross tys
+          mk_prod []   = unitTy
+          mk_prod [ty] = ty
+          mk_prod tys  = mkTyConApp (prod_tcs $ length tys) tys
 
-          mk_sum   []      = unitTy
-          mk_sum   tys     = foldr1 mk_plus  tys
+          mk_embed ty = mkTyConApp embed_tc [ty]
 
       return . mk_sum
-             . map (mk_tup . map mk_embed)
+             . map (mk_prod . map mk_embed)
              $ tys
 
 mkToPRepr :: [[CoreExpr]] -> VM ([CoreExpr], Type)
@@ -149,79 +151,73 @@ mkToPRepr ess
   = do
       embed_tc <- builtin embedTyCon
       embed_dc <- builtin embedDataCon
-      cross_tc <- builtin crossTyCon
-      cross_dc <- builtin crossDataCon
-      plus_tc  <- builtin plusTyCon
-      left_dc  <- builtin leftDataCon
-      right_dc <- builtin rightDataCon
-
-      let mk_embed expr
-            = (mkConApp   embed_dc [Type ty, expr],
-               mkTyConApp embed_tc [ty])
-            where ty = exprType expr
-
-          mk_cross (expr1, ty1) (expr2, ty2)
-            = (mkConApp   cross_dc [Type ty1, Type ty2, expr1, expr2],
-               mkTyConApp cross_tc [ty1, ty2])
-
-          mk_tup [] = (Var unitDataConId, unitTy)
-          mk_tup es = foldr1 mk_cross es
+      sum_tcs  <- builtins sumTyCon
+      prod_tcs <- builtins prodTyCon
 
-          mk_sum []           = ([Var unitDataConId], unitTy)
+      let mk_sum [] = ([Var unitDataConId], unitTy)
           mk_sum [(expr, ty)] = ([expr], ty)
-          mk_sum ((expr, lty) : es)
-            = let (alts, rty) = mk_sum es
-              in
-              (mkConApp left_dc [Type lty, Type rty, expr]
-                 : [mkConApp right_dc [Type lty, Type rty, alt] | alt <- alts],
-               mkTyConApp plus_tc [lty, rty])
+          mk_sum es = (zipWith mk_alt (tyConDataCons sum_tc) exprs,
+                       mkTyConApp sum_tc tys)
+            where
+              (exprs, tys)   = unzip es
+              sum_tc         = sum_tcs (length es)
+              mk_alt dc expr = mkConApp dc (map Type tys ++ [expr])
+
+          mk_prod [] = (Var unitDataConId, unitTy)
+          mk_prod [(expr, ty)] = (expr, ty)
+          mk_prod es = (mkConApp prod_dc (map Type tys ++ exprs),
+                        mkTyConApp prod_tc tys)
+            where
+              (exprs, tys) = unzip es
+              prod_tc      = prod_tcs (length es)
+              [prod_dc]    = tyConDataCons prod_tc
+
+          mk_embed expr = (mkConApp embed_dc [Type ty, expr],
+                           mkTyConApp embed_tc [ty])
+            where ty = exprType expr
 
-      return . mk_sum $ map (mk_tup . map mk_embed) ess
+      return . mk_sum $ map (mk_prod . map mk_embed) ess
 
 mkFromPRepr :: CoreExpr -> Type -> [([Var], CoreExpr)] -> VM CoreExpr
 mkFromPRepr scrut res_ty alts
   = do
       embed_dc <- builtin embedDataCon
-      cross_dc <- builtin crossDataCon
-      left_dc  <- builtin leftDataCon
-      right_dc <- builtin rightDataCon
-      pa_tc    <- builtin paTyCon
+      sum_tcs  <- builtins sumTyCon
+      prod_tcs <- builtins prodTyCon
 
-      let un_embed expr ty var res
-            = Case expr (mkWildId ty) res_ty
-                   [(DataAlt embed_dc, [var], res)]
-
-          un_cross expr ty var1 var2 res
-            = Case expr (mkWildId ty) res_ty
-                [(DataAlt cross_dc, [var1, var2], res)]
-
-          un_tup expr ty []    res = return res
-          un_tup expr ty [var] res = return $ un_embed expr ty var res
-          un_tup expr ty (var : vars) res
+      let un_sum expr ty [(vars, res)] = un_prod expr ty vars res
+          un_sum expr ty bs
             = do
-                lv <- newLocalVar FSLIT("x") lty
-                rv <- newLocalVar FSLIT("y") rty
-                liftM (un_cross expr ty lv rv
-                      . un_embed (Var lv) lty var)
-                      (un_tup (Var rv) rty vars res)
+                ps     <- mapM (newLocalVar FSLIT("p")) tys
+                bodies <- sequence
+                        $ zipWith4 un_prod (map Var ps) tys vars rs
+                return . Case expr (mkWildId ty) res_ty
+                       $ zipWith3 mk_alt sum_dcs ps bodies
             where
-              (lty, rty) = splitCrossTy ty
+              (vars, rs) = unzip bs
+              tys        = splitFixedTyConApp sum_tc ty
+              sum_tc     = sum_tcs $ length bs
+              sum_dcs    = tyConDataCons sum_tc
 
-          un_plus expr ty var1 var2 res1 res2
-            = Case expr (mkWildId ty) res_ty
-                [(DataAlt left_dc,  [var1], res1),
-                 (DataAlt right_dc, [var2], res2)]
+              mk_alt dc p body = (DataAlt dc, [p], body)
 
-          un_sum expr ty [(vars, res)] = un_tup expr ty vars res
-          un_sum expr ty ((vars, res) : alts)
+          un_prod expr ty []    r = return r
+          un_prod expr ty [var] r = return $ un_embed expr ty var r
+          un_prod expr ty vars  r
             = do
-                lv <- newLocalVar FSLIT("l") lty
-                rv <- newLocalVar FSLIT("r") rty
-                liftM2 (un_plus expr ty lv rv)
-                         (un_tup (Var lv) lty vars res)
-                         (un_sum (Var rv) rty alts)
+                xs <- mapM (newLocalVar FSLIT("x")) tys
+                let body = foldr (\(e,t,v) r -> un_embed e t v r) r
+                         $ zip3 (map Var xs) tys vars
+                return $ Case expr (mkWildId ty) res_ty
+                         [(DataAlt prod_dc, xs, body)]
             where
-              (lty, rty) = splitPlusTy ty
+              tys       = splitFixedTyConApp prod_tc ty
+              prod_tc   = prod_tcs $ length vars
+              [prod_dc] = tyConDataCons prod_tc
+
+          un_embed expr ty var r
+            = Case expr (mkWildId ty) res_ty
+                [(DataAlt embed_dc, [var], r)]
 
       un_sum scrut (exprType scrut) alts