Completed the implementation of VECTORISE SCALAR
authorManuel M T Chakravarty <chak@cse.unsw.edu.au>
Sat, 5 Mar 2011 12:36:25 +0000 (12:36 +0000)
committerManuel M T Chakravarty <chak@cse.unsw.edu.au>
Sat, 5 Mar 2011 12:36:25 +0000 (12:36 +0000)
- The pragma {-# VECTORISE SCALAR foo #-} marks 'foo' as a
  scalar function for for vectorisation and generates a
  vectorised version by applying 'scalar_map' and friends.
- The set of scalar functions is not yet emitted into
  interface files.  This will be added in a subsequent
  patch via 'VectInfo'.

compiler/vectorise/Vectorise.hs
compiler/vectorise/Vectorise/Env.hs
compiler/vectorise/Vectorise/Exp.hs

index 72cca6e..ca6766a 100644 (file)
@@ -1,4 +1,4 @@
-{-# OPTIONS -fno-warn-missing-signatures #-}
+{-# OPTIONS -fno-warn-missing-signatures -fno-warn-unused-do-bind #-}
 
 module Vectorise ( vectorise )
 where
@@ -121,44 +121,53 @@ vectModule guts@(ModGuts { mg_types     = types
 --
 vectTopBind :: CoreBind -> VM CoreBind
 vectTopBind b@(NonRec var expr)
- = do
-      (inline, _, expr')       <- vectTopRhs [] var expr
-      var' <- vectTopBinder var inline expr'
+ = do {   -- Vectorise the right-hand side, create an appropriate top-level binding and add it to
+          -- the vectorisation map.
+      ; (inline, isScalar, expr') <- vectTopRhs [] var expr
+      ; var' <- vectTopBinder var inline expr'
+      ; when isScalar $ 
+          addGlobalScalar var
 
-      -- Vectorising the body may create other top-level bindings.
-      hs <- takeHoisted
-
-      -- To get the same functionality as the original body we project
-      -- out its vectorised version from the closure.
-      cexpr <- tryConvert var var' expr
-
-      return . Rec $ (var, cexpr) : (var', expr') : hs
+          -- We replace the original top-level binding by a value projected from the vectorised
+          -- closure and add any newly created hoisted top-level bindings.
+      ; cexpr <- tryConvert var var' expr
+      ; hs <- takeHoisted
+      ; return . Rec $ (var, cexpr) : (var', expr') : hs
+      }
   `orElseV`
     return b
-
 vectTopBind b@(Rec bs)
- = do
-      (vars', _, exprs') 
-        <- fixV $ \ ~(_, inlines, rhss) ->
-            do vars' <- sequence [vectTopBinder var inline rhs
-                                      | (var, ~(inline, rhs)) <- zipLazy vars (zip inlines rhss)]
-               (inlines', areScalars', exprs') 
-                     <- mapAndUnzip3M (uncurry $ vectTopRhs vars) bs
-               if  (and areScalars') || (length bs <= 1)
-                  then do
-                    return (vars', inlines', exprs')
-                  else do
-                    _ <- mapM deleteGlobalScalar vars
-                    (inlines'', _, exprs'')  <- mapAndUnzip3M (uncurry $ vectTopRhs []) bs
-                    return (vars', inlines'', exprs'')
+ = let (vars, exprs) = unzip bs
+   in
+   do { (vars', _, exprs', hs) <- fixV $ 
+          \ ~(_, inlines, rhss, _) ->
+            do {   -- Vectorise the right-hand sides, create an appropriate top-level bindings and
+                   --  add them to the vectorisation map.
+               ; vars' <- sequence [vectTopBinder var inline rhs
+                                   | (var, ~(inline, rhs)) <- zipLazy vars (zip inlines rhss)]
+               ; (inlines, areScalars, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs vars) bs
+               ; hs <- takeHoisted
+               ; if and areScalars
+                 then      -- (1) Entire recursive group is scalar
+                           --      => add all variables to the global set of scalars
+                      do { mapM addGlobalScalar vars
+                         ; return (vars', inlines, exprs', hs)
+                         }
+                 else      -- (2) At least one binding is not scalar
+                           --     => vectorise again with empty set of local scalars
+                      do { (inlines, _, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs []) bs
+                         ; hs <- takeHoisted
+                         ; return (vars', inlines, exprs', hs)
+                         }
+               }
                       
-      hs     <- takeHoisted
-      cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
-      return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
+          -- Replace the original top-level bindings by a values projected from the vectorised
+          -- closures and add any newly created hoisted top-level bindings to the group.
+      ; cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
+      ; return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
+      }
   `orElseV`
-    return b
-  where
-    (vars, exprs) = unzip bs
+    return b    
     
 -- | Make the vectorised version of this top level binder, and add the mapping
 --   between it and the original to the state. For some binder @foo@ the vectorised
@@ -233,22 +242,16 @@ vectTopRhs recFs var expr
   where
     rhs _globalScalar (Just (_, expr'))               -- Case (1)
       = return (inlineMe, False, expr')
-    rhs True          _vectDecl                       -- Case (2)
-      = return (inlineMe, True, scalarRHS)
-                          -- FIXME: that True is not enough to register scalarness
-    rhs False         _vectDecl                       -- Case (3)
+    rhs True          Nothing                         -- Case (2)
+      = do { expr' <- vectScalarFun True recFs expr
+           ; return (inlineMe, True, vectorised expr')
+           }
+    rhs False         Nothing                         -- Case (3)
       = do { let fvs = freeVars expr
            ; (inline, isScalar, vexpr) <- inBind var $
                                             vectPolyExpr (isLoopBreaker $ idOccInfo var) recFs fvs
-           ; if isScalar 
-             then addGlobalScalar var
-             else deleteGlobalScalar var
            ; return (inline, isScalar, vectorised vexpr)
            }
-      
-    -- For scalar right-hand sides, we know that the original binding will remain unaltered
-    -- (hence, we can refer to it without risk of cycles) - cf, 'tryConvert'.
-    scalarRHS = panic "Vectorise.scalarRHS: not implemented yet"
 
 -- | Project out the vectorised version of a binding from some closure,
 --   or return the original body if that doesn't work or the binding is scalar. 
index 9a1fd44..5014fd6 100644 (file)
@@ -75,7 +75,8 @@ emptyLocalEnv = LocalEnv {
 --      These are things the exist at top-level.
 data GlobalEnv 
         = GlobalEnv {
-        -- | Mapping from global variables to their vectorised versions.
+        -- | Mapping from global variables to their vectorised versions — aka the /vectorisation
+        --   map/.
           global_vars           :: VarEnv Var
 
         -- | Mapping from global variables that have a vectorisation declaration to the right-hand
index 569057e..dbdf6e1 100644 (file)
@@ -1,15 +1,23 @@
 
 -- | Vectorisation of expressions.
-module Vectorise.Exp
-       (vectPolyExpr)
-where
-import Vectorise.Utils
+module Vectorise.Exp (
+
+  -- Vectorise a polymorphic expression
+  vectPolyExpr, 
+  
+  -- Vectorise a scalar expression of functional type
+  vectScalarFun
+) where
+
+#include "HsVersions.h"
+
 import Vectorise.Type.Type
 import Vectorise.Var
 import Vectorise.Vect
 import Vectorise.Env
 import Vectorise.Monad
 import Vectorise.Builtins
+import Vectorise.Utils
 
 import CoreSyn
 import CoreUtils
@@ -148,134 +156,141 @@ vectExpr e = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnno
 
 -- | Vectorise an expression with an outer lambda abstraction.
 --
-vectFnExpr :: Bool             -- ^ When the RHS of a binding, whether that binding should be inlined.
-           -> Bool             -- ^ Whether the binding is a loop breaker.
-           -> [Var]
-           -> CoreExprWithFVs  -- ^ Expression to vectorise. Must have an outer `AnnLam`.
+vectFnExpr :: Bool             -- ^ If we process the RHS of a binding, whether that binding should
+                               --   be inlined
+           -> Bool             -- ^ Whether the binding is a loop breaker
+           -> [Var]            -- ^ Names of function in same recursive binding group
+           -> CoreExprWithFVs  -- ^ Expression to vectorise; must have an outer `AnnLam`
            -> VM (Inline, Bool, VExpr)
-vectFnExpr inline loop_breaker recFns e@(fvs, AnnLam bndr _)
-  | isId bndr = onlyIfV True -- (isEmptyVarSet fvs)  -- we check for free variables later. TODO: clean up
-                        (mark DontInline True . vectScalarLam bs recFns $ deAnnotate body)
-                `orElseV` mark inlineMe False (vectLam inline loop_breaker fvs bs body)
-  where
-    (bs,body) = collectAnnValBinders e
+vectFnExpr inline loop_breaker recFns expr@(_fvs, AnnLam bndr _)
+  | isId bndr = mark DontInline True (vectScalarFun False recFns (deAnnotate expr))
+                `orElseV` 
+                mark inlineMe False (vectLam inline loop_breaker expr)
 vectFnExpr _ _ _  e = mark DontInline False $ vectExpr e
 
 mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
 mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) }
 
-
--- | Vectorise a function where are the args have scalar type,
---   that is Int, Float, Double etc.
-vectScalarLam 
-       :: [Var]        -- ^ Bound variables of function
-       -> [Var]
-       -> CoreExpr     -- ^ Function body.
-       -> VM VExpr
-       
-vectScalarLam args recFns body
- = do scalars' <- globalScalars
-      let scalars = unionVarSet (mkVarSet recFns) scalars'
-      onlyIfV (all is_prim_ty arg_tys
-               && is_prim_ty res_ty
-               && is_scalar (extendVarSetList scalars args) body
-               && uses scalars body)
-        $ do
-            fn_var  <- hoistExpr (fsLit "fn") (mkLams args body) DontInline
-            zipf    <- zipScalars arg_tys res_ty
-            clo     <- scalarClosure arg_tys res_ty (Var fn_var)
-                                                (zipf `App` Var fn_var)
-            clo_var <- hoistExpr (fsLit "clo") clo DontInline
-            lclo    <- liftPD (Var clo_var)
-            return (Var clo_var, lclo)
+-- |Vectorise an expression of functional type, where all arguments and the result are of scalar
+-- type (i.e., 'Int', 'Float', 'Double' etc.) and which does not contain any subcomputations that
+-- involve parallel arrays.  Such functionals do not requires the full blown vectorisation
+-- transformation; instead, they can be lifted by application of a member of the zipWith family
+-- (i.e., 'map', 'zipWith', zipWith3', etc.)
+--
+vectScalarFun :: Bool       -- ^ Was the function marked as scalar by the user?
+              -> [Var]      -- ^ Functions names in same recursive binding group
+              -> CoreExpr   -- ^ Expression to be vectorised
+              -> VM VExpr
+vectScalarFun forceScalar recFns expr
+ = do { gscalars <- globalScalars
+      ; let scalars = gscalars `extendVarSetList` recFns
+            (arg_tys, res_ty) = splitFunTys (exprType expr)
+      ; MASSERT( not $ null arg_tys )
+      ; onlyIfV (forceScalar                    -- user asserts the functions is scalar
+                 ||
+                 all is_prim_ty arg_tys         -- check whether the function is scalar
+                  && is_prim_ty res_ty
+                  && is_scalar scalars expr
+                  && uses scalars expr)
+        $ mkScalarFun arg_tys res_ty expr
+      }
   where
-    arg_tys = map idType args
-    res_ty  = exprType body
-
+    -- FIXME: This is woefully insufficient!!!  We need a scalar pragma for types!!!
     is_prim_ty ty 
         | Just (tycon, [])   <- splitTyConApp_maybe ty
         =    tycon == intTyCon
           || tycon == floatTyCon
           || tycon == doubleTyCon
-
         | otherwise = False
-    
-    cantbe_parr_expr expr = not $ maybe_parr_ty $ exprType expr
-         
-    maybe_parr_ty ty = maybe_parr_ty' [] ty
-      
-    maybe_parr_ty' _           ty | Nothing <- splitTyConApp_maybe ty = False   -- TODO: is this really what we want to do with polym. types?
-    maybe_parr_ty' alreadySeen ty
-       | isPArrTyCon tycon     = True
-       | isPrimTyCon tycon     = False
-       | isAbstractTyCon tycon = True
-       | isFunTyCon tycon || isProductTyCon tycon || isTupleTyCon tycon  = any (maybe_parr_ty' alreadySeen) args     
-       | isDataTyCon tycon = any (maybe_parr_ty' alreadySeen) args || 
-                             hasParrDataCon alreadySeen tycon
-       | otherwise = True
-       where
-         Just (tycon, args) = splitTyConApp_maybe ty 
-         
-         
-         hasParrDataCon alreadySeen tycon
-           | tycon `elem` alreadySeen = False  
-           | otherwise                =  
-               any (maybe_parr_ty' $ tycon : alreadySeen) $ concat $  map dataConOrigArgTys $ tyConDataCons tycon 
-         
-    -- checks to make sure expression can't contain a non-scalar subexpression. Might err on the side of caution whenever
-    -- an external (non data constructor) variable is used, or anonymous data constructor      
-    is_scalar vs e@(Var v) 
-      | Just _ <- isDataConId_maybe v = cantbe_parr_expr e
-      | otherwise                     = cantbe_parr_expr e &&  (v `elemVarSet` vs)
-    is_scalar _ e@(Lit _)    = cantbe_parr_expr e  
-
-    is_scalar vs e@(App e1 e2) = cantbe_parr_expr e &&
-                               is_scalar vs e1 && is_scalar vs e2    
-    is_scalar vs e@(Let (NonRec b letExpr) body) 
-                             = cantbe_parr_expr e &&
-                               is_scalar vs letExpr && is_scalar (extendVarSet vs b) body
-    is_scalar vs  e@(Let (Rec bnds) body) 
-                             =  let vs' = extendVarSetList vs (map fst bnds)
-                                in cantbe_parr_expr e &&  
-                                   all (is_scalar vs') (map snd bnds) && is_scalar vs' body
-    is_scalar vs e@(Case eC eId ty alts)  
-                             = let vs' = extendVarSet vs eId
-                                  in cantbe_parr_expr e && 
-                                  is_prim_ty ty &&
-                                  is_scalar vs' eC   &&
-                                  (all (is_scalar_alt vs') alts)
-                                    
-    is_scalar _ _            =  False
-
-    is_scalar_alt vs (_, bs, e) 
-                             = is_scalar (extendVarSetList vs bs) e
 
+    -- Checks whether an expression contain a non-scalar subexpression. 
+    --
+    -- Precodition: The variables in the first argument are scalar.
+    --
+    -- In case of a recursive binding group, we /assume/ that all bindings are scalar (by adding
+    -- them to the list of scalar variables) and then check them.  If one of them turns out not to
+    -- be scalar, the entire group is regarded as not being scalar.
+    --
+    -- FIXME: Currently, doesn't regard external (non-data constructor) variable and anonymous
+    --        data constructor as scalar.  Should be changed once scalar types are passed
+    --        through VectInfo.
+    --
+    is_scalar :: VarSet -> CoreExpr -> Bool
+    is_scalar scalars  (Var v)         = v `elemVarSet` scalars
+    is_scalar _scalars (Lit _)         = True
+    is_scalar scalars  e@(App e1 e2) 
+      | maybe_parr_ty  (exprType e)    = False
+      | otherwise                      = is_scalar scalars e1 && is_scalar scalars e2
+    is_scalar scalars  (Lam var body)  
+      | maybe_parr_ty  (varType var)   = False
+      | otherwise                      = is_scalar (scalars `extendVarSet` var) body
+    is_scalar scalars  (Let bind body) = bindsAreScalar && is_scalar scalars' body
+      where
+        (bindsAreScalar, scalars') = is_scalar_bind scalars bind
+    is_scalar scalars  (Case e var ty alts)
+      | is_prim_ty ty                  = is_scalar scalars' e && all (is_scalar_alt scalars') alts
+      | otherwise                      = False
+      where
+        scalars' = scalars `extendVarSet` var
+    is_scalar scalars  (Cast e _coe)   = is_scalar scalars e
+    is_scalar scalars  (Note _ e   )   = is_scalar scalars e
+    is_scalar _scalars (Type _)        = True
+
+    -- Result: (<is this binding group scalar>, scalars ++ variables bound in this group)
+    is_scalar_bind scalars (NonRec var e) = (is_scalar scalars e, scalars `extendVarSet` var)
+    is_scalar_bind scalars (Rec bnds)     = (all (is_scalar scalars') es, scalars')
+      where
+        (vars, es) = unzip bnds
+        scalars'   = scalars `extendVarSetList` vars
+
+    is_scalar_alt scalars (_, vars, e) = is_scalar (scalars `extendVarSetList ` vars) e
+
+    -- Checks whether the type might be a parallel array type.  In particular, if the outermost
+    -- constructor is a type family, we conservatively assume that it may be a parallel array type.
+    maybe_parr_ty :: Type -> Bool
+    maybe_parr_ty ty 
+      | Just ty'        <- coreView ty            = maybe_parr_ty ty'
+      | Just (tyCon, _) <- splitTyConApp_maybe ty = isPArrTyCon tyCon || isSynFamilyTyCon tyCon 
+    maybe_parr_ty _                               = False
+
+    -- FIXME: I'm not convinced that this reasoning is (always) sound.  If the identify functions
+    --        is called by some other function that is otherwise scalar, it would be very bad
+    --        that just this call to the identity makes it not be scalar.
     -- A scalar function has to actually compute something. Without the check,
     -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to
     -- (map (\x -> x)) which is very bad. Normal lifting transforms it to
     -- (\n# x -> x) which is what we want.
-    uses funs (Var v)     = v `elemVarSet` funs 
-    uses funs (App e1 e2) = uses funs e1 || uses funs e2
+    uses funs (Var v)       = v `elemVarSet` funs 
+    uses funs (App e1 e2)   = uses funs e1 || uses funs e2
+    uses funs (Lam b body)  = uses (funs `extendVarSet` b) body
     uses funs (Let (NonRec _b letExpr) body) 
-                          = uses funs letExpr || uses funs  body
+                            = uses funs letExpr || uses funs  body
     uses funs (Case e _eId _ty alts) 
-                          = uses funs e || any (uses_alt funs) alts
-    uses _ _              = False
+                            = uses funs e || any (uses_alt funs) alts
+    uses _ _                = False
 
-    uses_alt funs (_, _bs, e)   
-                          = uses funs e 
+    uses_alt funs (_, _bs, e) = uses funs e 
+
+mkScalarFun :: [Type] -> Type -> CoreExpr -> VM VExpr
+mkScalarFun arg_tys res_ty expr
+  = do { fn_var  <- hoistExpr (fsLit "fn") expr DontInline
+       ; zipf    <- zipScalars arg_tys res_ty
+       ; clo     <- scalarClosure arg_tys res_ty (Var fn_var) (zipf `App` Var fn_var)
+       ; clo_var <- hoistExpr (fsLit "clo") clo DontInline
+       ; lclo    <- liftPD (Var clo_var)
+       ; return (Var clo_var, lclo)
+       }
 
 -- | Vectorise a lambda abstraction.
-vectLam 
-       :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
-       -> Bool                 -- ^ Whether the binding is a loop breaker.
-       -> VarSet               -- ^ The free variables in the body.
-       -> [Var]                -- ^ Binding variables.
-       -> CoreExprWithFVs      -- ^ Body of abstraction.
-       -> VM VExpr
-
-vectLam inline loop_breaker fvs bs body
- = do tyvars    <- localTyVars
+--
+vectLam :: Bool             -- ^ When the RHS of a binding, whether that binding should be inlined.
+        -> Bool             -- ^ Whether the binding is a loop breaker.
+        -> CoreExprWithFVs  -- ^ Body of abstraction.
+        -> VM VExpr
+vectLam inline loop_breaker expr@(fvs, AnnLam _ _)
+ = do let (bs, body) = collectAnnValBinders expr
+      tyvars    <- localTyVars
       (vs, vvs) <- readLEnv $ \env ->
                    unzip [(var, vv) | var <- varSetElems fvs
                                     , Just vv <- [lookupVarEnv (local_vars env) var]]
@@ -305,6 +320,7 @@ vectLam inline loop_breaker fvs bs body
                          (LitAlt (mkMachInt 0), [], empty)])
 
       | otherwise = return (ve, le)
+vectLam _ _ _ = panic "vectLam"
  
 
 vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr