Use n-ary sums and products for NDP's generic representation
authorRoman Leshchinskiy <rl@cse.unsw.edu.au>
Thu, 23 Aug 2007 06:09:45 +0000 (06:09 +0000)
committerRoman Leshchinskiy <rl@cse.unsw.edu.au>
Thu, 23 Aug 2007 06:09:45 +0000 (06:09 +0000)
Originally, we wanted to only use binary ones, at least initially. But this
would a lot of fiddling with selectors when converting to/from generic
array representations. This is both inefficient and hard to implement.
Instead, we will limit the arity of our sums/product representation to, say,
16 (it's 3 at the moment) and initially refuse to vectorise programs for which
this is not sufficient. This allows us to implement everything in the library.
Later, we can implement the necessary splitting.

compiler/prelude/PrelNames.lhs
compiler/vectorise/VectBuiltIn.hs
compiler/vectorise/VectMonad.hs
compiler/vectorise/VectUtils.hs

index 8de554d..9839290 100644 (file)
@@ -218,7 +218,7 @@ genericTyConNames = [crossTyConName, plusTyConName, genUnitTyConName]
 
 ndpNames :: [Name]
 ndpNames = [ parrayTyConName, paTyConName, preprTyConName, prTyConName
-           , ndpCrossTyConName, ndpPlusTyConName, embedTyConName
+           , embedTyConName
            , closureTyConName
            , mkClosureName, applyClosureName
            , mkClosurePName, applyClosurePName
@@ -276,6 +276,7 @@ rANDOM              = mkBaseModule FSLIT("System.Random")
 gLA_EXTS       = mkBaseModule FSLIT("GHC.Exts")
 
 nDP_PARRAY      = mkNDPModule FSLIT("Data.Array.Parallel.Lifted.PArray")
+nDP_REPR        = mkNDPModule FSLIT("Data.Array.Parallel.Lifted.Repr")
 nDP_UTILS       = mkNDPModule FSLIT("Data.Array.Parallel.Lifted.Utils")
 nDP_CLOSURE     = mkNDPModule FSLIT("Data.Array.Parallel.Lifted.Closure")
 nDP_INSTANCES   = mkNDPModule FSLIT("Data.Array.Parallel.Lifted.Instances")
@@ -697,9 +698,7 @@ parrayTyConName     = tcQual   nDP_PARRAY FSLIT("PArray") parrayTyConKey
 paTyConName         = tcQual   nDP_PARRAY FSLIT("PA")     paTyConKey
 preprTyConName      = tcQual   nDP_PARRAY FSLIT("PRepr")  preprTyConKey
 prTyConName         = clsQual  nDP_PARRAY FSLIT("PR")     prTyConKey
-ndpCrossTyConName   = tcQual   nDP_PARRAY FSLIT(":*:")    ndpCrossTyConKey
-ndpPlusTyConName    = tcQual   nDP_PARRAY FSLIT(":+:")    ndpPlusTyConKey
-embedTyConName      = tcQual   nDP_PARRAY FSLIT("Embed")  embedTyConKey
+embedTyConName      = tcQual   nDP_REPR   FSLIT("Embed")  embedTyConKey
 lengthPAName        = varQual  nDP_PARRAY FSLIT("lengthPA")    lengthPAIdKey
 replicatePAName     = varQual  nDP_PARRAY FSLIT("replicatePA") replicatePAIdKey
 emptyPAName         = varQual  nDP_PARRAY FSLIT("emptyPA") emptyPAIdKey
@@ -895,9 +894,7 @@ closureTyConKey                         = mkPreludeTyConUnique 136
 paTyConKey                              = mkPreludeTyConUnique 137
 preprTyConKey                           = mkPreludeTyConUnique 138
 embedTyConKey                           = mkPreludeTyConUnique 139
-ndpCrossTyConKey                        = mkPreludeTyConUnique 140
-ndpPlusTyConKey                         = mkPreludeTyConUnique 141
-prTyConKey                              = mkPreludeTyConUnique 142
+prTyConKey                              = mkPreludeTyConUnique 140
 
 
 ---------------- Template Haskell -------------------
index e6c65ac..0afef5b 100644 (file)
@@ -1,5 +1,5 @@
 module VectBuiltIn (
-  Builtins(..),
+  Builtins(..), sumTyCon, prodTyCon,
   initBuiltins, initBuiltinTyCons, initBuiltinPAs, initBuiltinPRs
 ) where
 
@@ -14,7 +14,7 @@ import TyCon           ( TyCon, tyConName, tyConDataCons )
 import Var             ( Var )
 import Id              ( mkSysLocal )
 import Name            ( Name )
-import OccName         ( mkVarOccFS )
+import OccName         ( mkVarOccFS, mkOccNameFS, tcName )
 
 import TypeRep         ( funTyCon )
 import TysPrim         ( intPrimTy )
@@ -23,9 +23,17 @@ import PrelNames
 import BasicTypes      ( Boxity(..) )
 
 import FastString
+import Outputable
 
+import Data.Array
 import Control.Monad   ( liftM, zipWithM )
 
+mAX_NDP_PROD :: Int
+mAX_NDP_PROD = 3
+
+mAX_NDP_SUM :: Int
+mAX_NDP_SUM = 3
+
 data Builtins = Builtins {
                   parrayTyCon      :: TyCon
                 , paTyCon          :: TyCon
@@ -35,11 +43,7 @@ data Builtins = Builtins {
                 , prDataCon        :: DataCon
                 , embedTyCon       :: TyCon
                 , embedDataCon     :: DataCon
-                , crossTyCon       :: TyCon
-                , crossDataCon     :: DataCon
-                , plusTyCon        :: TyCon
-                , leftDataCon      :: DataCon
-                , rightDataCon     :: DataCon
+                , sumTyCons        :: Array Int TyCon
                 , closureTyCon     :: TyCon
                 , mkClosureVar     :: Var
                 , applyClosureVar  :: Var
@@ -54,6 +58,17 @@ data Builtins = Builtins {
                 , liftingContext   :: Var
                 }
 
+sumTyCon :: Int -> Builtins -> TyCon
+sumTyCon n bi
+  | n >= 2 && n <= mAX_NDP_SUM = sumTyCons bi ! n
+  | otherwise = pprPanic "sumTyCon" (ppr n)
+
+prodTyCon :: Int -> Builtins -> TyCon
+prodTyCon n bi
+  | n >= 2 && n <= mAX_NDP_PROD = tupleTyCon Boxed n
+  | otherwise = pprPanic "prodTyCon" (ppr n)
+
+
 initBuiltins :: DsM Builtins
 initBuiltins
   = do
@@ -65,12 +80,13 @@ initBuiltins
       let [prDataCon] = tyConDataCons prTyCon
       embedTyCon   <- dsLookupTyCon embedTyConName
       let [embedDataCon] = tyConDataCons embedTyCon
-      crossTyCon   <- dsLookupTyCon ndpCrossTyConName
-      let [crossDataCon] = tyConDataCons crossTyCon
-      plusTyCon    <- dsLookupTyCon ndpPlusTyConName
-      let [leftDataCon, rightDataCon] = tyConDataCons plusTyCon
       closureTyCon <- dsLookupTyCon closureTyConName
 
+      sum_tcs <- mapM (lookupExternalTyCon nDP_REPR)
+                      [mkFastString ("Sum" ++ show i) | i <- [2..mAX_NDP_SUM]]
+
+      let sumTyCons = listArray (2, mAX_NDP_SUM) sum_tcs
+
       mkClosureVar     <- dsLookupGlobalId mkClosureName
       applyClosureVar  <- dsLookupGlobalId applyClosureName
       mkClosurePVar    <- dsLookupGlobalId mkClosurePName
@@ -94,11 +110,7 @@ initBuiltins
                , prDataCon        = prDataCon
                , embedTyCon       = embedTyCon
                , embedDataCon     = embedDataCon
-               , crossTyCon       = crossTyCon
-               , crossDataCon     = crossDataCon
-               , plusTyCon        = plusTyCon
-               , leftDataCon      = leftDataCon
-               , rightDataCon     = rightDataCon
+               , sumTyCons        = sumTyCons
                , closureTyCon     = closureTyCon
                , mkClosureVar     = mkClosureVar
                , applyClosureVar  = applyClosureVar
@@ -137,7 +149,7 @@ initBuiltinPAs = initBuiltinDicts builtinPAs
 builtinPAs :: [(Name, Module, FastString)]
 builtinPAs = [
                mk closureTyConName  nDP_CLOSURE   FSLIT("dPA_Clo")
-             , mk unitTyConName     nDP_PARRAY    FSLIT("dPA_Unit")
+             , mk unitTyConName     nDP_INSTANCES FSLIT("dPA_Unit")
 
              , mk intTyConName      nDP_INSTANCES FSLIT("dPA_Int")
              ]
@@ -150,25 +162,37 @@ builtinPAs = [
                   nDP_INSTANCES
                   (mkFastString $ "dPA_" ++ show n)
 
-initBuiltinPRs = initBuiltinDicts builtinPRs
+initBuiltinPRs = initBuiltinDicts . builtinPRs
 
-builtinPRs :: [(Name, Module, FastString)]
-builtinPRs = [
-               mk (tyConName unitTyCon) nDP_PARRAY    FSLIT("dPR_Unit")
-             , mk ndpCrossTyConName     nDP_PARRAY    FSLIT("dPR_Cross")
-             , mk ndpPlusTyConName      nDP_PARRAY    FSLIT("dPR_Plus")
-             , mk embedTyConName        nDP_PARRAY    FSLIT("dPR_Embed")
-             , mk closureTyConName      nDP_CLOSURE   FSLIT("dPR_Clo")
+builtinPRs :: Builtins -> [(Name, Module, FastString)]
+builtinPRs bi =
+  [
+    mk (tyConName unitTyCon) nDP_REPR      FSLIT("dPR_Unit")
+  , mk embedTyConName        nDP_REPR      FSLIT("dPR_Embed")
+  , mk closureTyConName      nDP_CLOSURE   FSLIT("dPR_Clo")
 
-               -- temporary
-             , mk intTyConName          nDP_INSTANCES FSLIT("dPR_Int")
-             ]
+    -- temporary
+  , mk intTyConName          nDP_INSTANCES FSLIT("dPR_Int")
+  ]
+
+  ++ map mk_sum  [2..mAX_NDP_SUM]
+  ++ map mk_prod [2..mAX_NDP_PROD]
   where
     mk name mod fs = (name, mod, fs)
 
+    mk_sum n = (tyConName $ sumTyCon n bi, nDP_REPR,
+                mkFastString ("dPR_Sum" ++ show n))
+
+    mk_prod n = (tyConName $ prodTyCon n bi, nDP_REPR,
+                 mkFastString ("dPR_" ++ show n))
+
 lookupExternalVar :: Module -> FastString -> DsM Var
 lookupExternalVar mod fs
   = dsLookupGlobalId =<< lookupOrig mod (mkVarOccFS fs)
 
+lookupExternalTyCon :: Module -> FastString -> DsM TyCon
+lookupExternalTyCon mod fs
+  = dsLookupTyCon =<< lookupOrig mod (mkOccNameFS tcName fs)
+
 unitTyConName = tyConName unitTyCon
 
index 6bc2f4d..320d192 100644 (file)
@@ -7,8 +7,8 @@ module VectMonad (
   cloneName, cloneId,
   newExportedVar, newLocalVar, newDummyVar, newTyVar,
   
-  Builtins(..),
-  builtin,
+  Builtins(..), sumTyCon, prodTyCon,
+  builtin, builtins,
 
   GlobalEnv(..),
   setFamInstEnv,
@@ -240,6 +240,9 @@ liftDs p = VM $ \bi genv lenv -> do { x <- p; return (Yes genv lenv x) }
 builtin :: (Builtins -> a) -> VM a
 builtin f = VM $ \bi genv lenv -> return (Yes genv lenv (f bi))
 
+builtins :: (a -> Builtins -> b) -> VM (a -> b)
+builtins f = VM $ \bi genv lenv -> return (Yes genv lenv (`f` bi))
+
 readGEnv :: (GlobalEnv -> a) -> VM a
 readGEnv f = VM $ \bi genv lenv -> return (Yes genv lenv (f genv))
 
@@ -454,7 +457,7 @@ initV hsc_env guts info p
         builtins       <- initBuiltins
         builtin_tycons <- initBuiltinTyCons
         builtin_pas    <- initBuiltinPAs
-        builtin_prs    <- initBuiltinPRs
+        builtin_prs    <- initBuiltinPRs builtins
 
         eps <- ioToIOEnv $ hscEPS hsc_env
         let famInstEnvs = (eps_fam_inst_env eps, mg_fam_inst_env guts)
index d1c1dab..0f101bd 100644 (file)
@@ -39,6 +39,7 @@ import BasicTypes         ( Boxity(..) )
 import Outputable
 import FastString
 
+import Data.List             ( zipWith4 )
 import Control.Monad         ( liftM, liftM2, zipWithM_ )
 
 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
@@ -82,11 +83,13 @@ splitBinTy s name ty
 
   | otherwise = pprPanic s (ppr ty)
 
-splitCrossTy :: Type -> (Type, Type)
-splitCrossTy = splitBinTy "splitCrossTy" ndpCrossTyConName
+splitFixedTyConApp :: TyCon -> Type -> [Type]
+splitFixedTyConApp tc ty
+  | Just (tc', tys) <- splitTyConApp_maybe ty
+  , tc == tc'
+  = tys
 
-splitPlusTy :: Type -> (Type, Type)
-splitPlusTy = splitBinTy "splitSumTy" ndpPlusTyConName
+  | otherwise = pprPanic "splitFixedTyConApp" (ppr tc <+> ppr ty)
 
 splitEmbedTy :: Type -> Type
 splitEmbedTy = splitUnTy "splitEmbedTy" embedTyConName
@@ -123,25 +126,24 @@ mkBuiltinTyConApps1 get_tc dft tys
     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
 
 mkPRepr :: [[Type]] -> VM Type
-mkPRepr [] = return unitTy
 mkPRepr tys
   = do
-      embed <- builtin embedTyCon
-      cross <- builtin crossTyCon
-      plus  <- builtin plusTyCon
+      embed_tc <- builtin embedTyCon
+      sum_tcs  <- builtins sumTyCon
+      prod_tcs <- builtins prodTyCon
 
-      let mk_embed ty      = mkTyConApp embed [ty]
-          mk_cross ty1 ty2 = mkTyConApp cross [ty1, ty2]
-          mk_plus  ty1 ty2 = mkTyConApp plus  [ty1, ty2]
+      let mk_sum []   = unitTy
+          mk_sum [ty] = ty
+          mk_sum tys  = mkTyConApp (sum_tcs $ length tys) tys
 
-          mk_tup   []      = unitTy
-          mk_tup   tys     = foldr1 mk_cross tys
+          mk_prod []   = unitTy
+          mk_prod [ty] = ty
+          mk_prod tys  = mkTyConApp (prod_tcs $ length tys) tys
 
-          mk_sum   []      = unitTy
-          mk_sum   tys     = foldr1 mk_plus  tys
+          mk_embed ty = mkTyConApp embed_tc [ty]
 
       return . mk_sum
-             . map (mk_tup . map mk_embed)
+             . map (mk_prod . map mk_embed)
              $ tys
 
 mkToPRepr :: [[CoreExpr]] -> VM ([CoreExpr], Type)
@@ -149,79 +151,73 @@ mkToPRepr ess
   = do
       embed_tc <- builtin embedTyCon
       embed_dc <- builtin embedDataCon
-      cross_tc <- builtin crossTyCon
-      cross_dc <- builtin crossDataCon
-      plus_tc  <- builtin plusTyCon
-      left_dc  <- builtin leftDataCon
-      right_dc <- builtin rightDataCon
-
-      let mk_embed expr
-            = (mkConApp   embed_dc [Type ty, expr],
-               mkTyConApp embed_tc [ty])
-            where ty = exprType expr
-
-          mk_cross (expr1, ty1) (expr2, ty2)
-            = (mkConApp   cross_dc [Type ty1, Type ty2, expr1, expr2],
-               mkTyConApp cross_tc [ty1, ty2])
-
-          mk_tup [] = (Var unitDataConId, unitTy)
-          mk_tup es = foldr1 mk_cross es
+      sum_tcs  <- builtins sumTyCon
+      prod_tcs <- builtins prodTyCon
 
-          mk_sum []           = ([Var unitDataConId], unitTy)
+      let mk_sum [] = ([Var unitDataConId], unitTy)
           mk_sum [(expr, ty)] = ([expr], ty)
-          mk_sum ((expr, lty) : es)
-            = let (alts, rty) = mk_sum es
-              in
-              (mkConApp left_dc [Type lty, Type rty, expr]
-                 : [mkConApp right_dc [Type lty, Type rty, alt] | alt <- alts],
-               mkTyConApp plus_tc [lty, rty])
+          mk_sum es = (zipWith mk_alt (tyConDataCons sum_tc) exprs,
+                       mkTyConApp sum_tc tys)
+            where
+              (exprs, tys)   = unzip es
+              sum_tc         = sum_tcs (length es)
+              mk_alt dc expr = mkConApp dc (map Type tys ++ [expr])
+
+          mk_prod [] = (Var unitDataConId, unitTy)
+          mk_prod [(expr, ty)] = (expr, ty)
+          mk_prod es = (mkConApp prod_dc (map Type tys ++ exprs),
+                        mkTyConApp prod_tc tys)
+            where
+              (exprs, tys) = unzip es
+              prod_tc      = prod_tcs (length es)
+              [prod_dc]    = tyConDataCons prod_tc
+
+          mk_embed expr = (mkConApp embed_dc [Type ty, expr],
+                           mkTyConApp embed_tc [ty])
+            where ty = exprType expr
 
-      return . mk_sum $ map (mk_tup . map mk_embed) ess
+      return . mk_sum $ map (mk_prod . map mk_embed) ess
 
 mkFromPRepr :: CoreExpr -> Type -> [([Var], CoreExpr)] -> VM CoreExpr
 mkFromPRepr scrut res_ty alts
   = do
       embed_dc <- builtin embedDataCon
-      cross_dc <- builtin crossDataCon
-      left_dc  <- builtin leftDataCon
-      right_dc <- builtin rightDataCon
-      pa_tc    <- builtin paTyCon
+      sum_tcs  <- builtins sumTyCon
+      prod_tcs <- builtins prodTyCon
 
-      let un_embed expr ty var res
-            = Case expr (mkWildId ty) res_ty
-                   [(DataAlt embed_dc, [var], res)]
-
-          un_cross expr ty var1 var2 res
-            = Case expr (mkWildId ty) res_ty
-                [(DataAlt cross_dc, [var1, var2], res)]
-
-          un_tup expr ty []    res = return res
-          un_tup expr ty [var] res = return $ un_embed expr ty var res
-          un_tup expr ty (var : vars) res
+      let un_sum expr ty [(vars, res)] = un_prod expr ty vars res
+          un_sum expr ty bs
             = do
-                lv <- newLocalVar FSLIT("x") lty
-                rv <- newLocalVar FSLIT("y") rty
-                liftM (un_cross expr ty lv rv
-                      . un_embed (Var lv) lty var)
-                      (un_tup (Var rv) rty vars res)
+                ps     <- mapM (newLocalVar FSLIT("p")) tys
+                bodies <- sequence
+                        $ zipWith4 un_prod (map Var ps) tys vars rs
+                return . Case expr (mkWildId ty) res_ty
+                       $ zipWith3 mk_alt sum_dcs ps bodies
             where
-              (lty, rty) = splitCrossTy ty
+              (vars, rs) = unzip bs
+              tys        = splitFixedTyConApp sum_tc ty
+              sum_tc     = sum_tcs $ length bs
+              sum_dcs    = tyConDataCons sum_tc
 
-          un_plus expr ty var1 var2 res1 res2
-            = Case expr (mkWildId ty) res_ty
-                [(DataAlt left_dc,  [var1], res1),
-                 (DataAlt right_dc, [var2], res2)]
+              mk_alt dc p body = (DataAlt dc, [p], body)
 
-          un_sum expr ty [(vars, res)] = un_tup expr ty vars res
-          un_sum expr ty ((vars, res) : alts)
+          un_prod expr ty []    r = return r
+          un_prod expr ty [var] r = return $ un_embed expr ty var r
+          un_prod expr ty vars  r
             = do
-                lv <- newLocalVar FSLIT("l") lty
-                rv <- newLocalVar FSLIT("r") rty
-                liftM2 (un_plus expr ty lv rv)
-                         (un_tup (Var lv) lty vars res)
-                         (un_sum (Var rv) rty alts)
+                xs <- mapM (newLocalVar FSLIT("x")) tys
+                let body = foldr (\(e,t,v) r -> un_embed e t v r) r
+                         $ zip3 (map Var xs) tys vars
+                return $ Case expr (mkWildId ty) res_ty
+                         [(DataAlt prod_dc, xs, body)]
             where
-              (lty, rty) = splitPlusTy ty
+              tys       = splitFixedTyConApp prod_tc ty
+              prod_tc   = prod_tcs $ length vars
+              [prod_dc] = tyConDataCons prod_tc
+
+          un_embed expr ty var r
+            = Case expr (mkWildId ty) res_ty
+                [(DataAlt embed_dc, [var], r)]
 
       un_sum scrut (exprType scrut) alts