This BIG PATCH contains most of the work for the New Coercion Representation
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Exp.hs
index 569057e..4676e18 100644 (file)
@@ -1,15 +1,23 @@
 
 -- | Vectorisation of expressions.
-module Vectorise.Exp
-       (vectPolyExpr)
-where
-import Vectorise.Utils
+module Vectorise.Exp (
+
+  -- Vectorise a polymorphic expression
+  vectPolyExpr, 
+  
+  -- Vectorise a scalar expression of functional type
+  vectScalarFun
+) where
+
+#include "HsVersions.h"
+
 import Vectorise.Type.Type
 import Vectorise.Var
 import Vectorise.Vect
 import Vectorise.Env
 import Vectorise.Monad
 import Vectorise.Builtins
+import Vectorise.Utils
 
 import CoreSyn
 import CoreUtils
@@ -148,134 +156,142 @@ vectExpr e = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnno
 
 -- | Vectorise an expression with an outer lambda abstraction.
 --
-vectFnExpr :: Bool             -- ^ When the RHS of a binding, whether that binding should be inlined.
-           -> Bool             -- ^ Whether the binding is a loop breaker.
-           -> [Var]
-           -> CoreExprWithFVs  -- ^ Expression to vectorise. Must have an outer `AnnLam`.
+vectFnExpr :: Bool             -- ^ If we process the RHS of a binding, whether that binding should
+                               --   be inlined
+           -> Bool             -- ^ Whether the binding is a loop breaker
+           -> [Var]            -- ^ Names of function in same recursive binding group
+           -> CoreExprWithFVs  -- ^ Expression to vectorise; must have an outer `AnnLam`
            -> VM (Inline, Bool, VExpr)
-vectFnExpr inline loop_breaker recFns e@(fvs, AnnLam bndr _)
-  | isId bndr = onlyIfV True -- (isEmptyVarSet fvs)  -- we check for free variables later. TODO: clean up
-                        (mark DontInline True . vectScalarLam bs recFns $ deAnnotate body)
-                `orElseV` mark inlineMe False (vectLam inline loop_breaker fvs bs body)
-  where
-    (bs,body) = collectAnnValBinders e
+vectFnExpr inline loop_breaker recFns expr@(_fvs, AnnLam bndr _)
+  | isId bndr = mark DontInline True (vectScalarFun False recFns (deAnnotate expr))
+                `orElseV` 
+                mark inlineMe False (vectLam inline loop_breaker expr)
 vectFnExpr _ _ _  e = mark DontInline False $ vectExpr e
 
 mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
 mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) }
 
-
--- | Vectorise a function where are the args have scalar type,
---   that is Int, Float, Double etc.
-vectScalarLam 
-       :: [Var]        -- ^ Bound variables of function
-       -> [Var]
-       -> CoreExpr     -- ^ Function body.
-       -> VM VExpr
-       
-vectScalarLam args recFns body
- = do scalars' <- globalScalars
-      let scalars = unionVarSet (mkVarSet recFns) scalars'
-      onlyIfV (all is_prim_ty arg_tys
-               && is_prim_ty res_ty
-               && is_scalar (extendVarSetList scalars args) body
-               && uses scalars body)
-        $ do
-            fn_var  <- hoistExpr (fsLit "fn") (mkLams args body) DontInline
-            zipf    <- zipScalars arg_tys res_ty
-            clo     <- scalarClosure arg_tys res_ty (Var fn_var)
-                                                (zipf `App` Var fn_var)
-            clo_var <- hoistExpr (fsLit "clo") clo DontInline
-            lclo    <- liftPD (Var clo_var)
-            return (Var clo_var, lclo)
+-- |Vectorise an expression of functional type, where all arguments and the result are of scalar
+-- type (i.e., 'Int', 'Float', 'Double' etc.) and which does not contain any subcomputations that
+-- involve parallel arrays.  Such functionals do not requires the full blown vectorisation
+-- transformation; instead, they can be lifted by application of a member of the zipWith family
+-- (i.e., 'map', 'zipWith', zipWith3', etc.)
+--
+vectScalarFun :: Bool       -- ^ Was the function marked as scalar by the user?
+              -> [Var]      -- ^ Functions names in same recursive binding group
+              -> CoreExpr   -- ^ Expression to be vectorised
+              -> VM VExpr
+vectScalarFun forceScalar recFns expr
+ = do { gscalars <- globalScalars
+      ; let scalars = gscalars `extendVarSetList` recFns
+            (arg_tys, res_ty) = splitFunTys (exprType expr)
+      ; MASSERT( not $ null arg_tys )
+      ; onlyIfV (forceScalar                    -- user asserts the functions is scalar
+                 ||
+                 all is_prim_ty arg_tys         -- check whether the function is scalar
+                  && is_prim_ty res_ty
+                  && is_scalar scalars expr
+                  && uses scalars expr)
+        $ mkScalarFun arg_tys res_ty expr
+      }
   where
-    arg_tys = map idType args
-    res_ty  = exprType body
-
+    -- FIXME: This is woefully insufficient!!!  We need a scalar pragma for types!!!
     is_prim_ty ty 
         | Just (tycon, [])   <- splitTyConApp_maybe ty
         =    tycon == intTyCon
           || tycon == floatTyCon
           || tycon == doubleTyCon
-
         | otherwise = False
-    
-    cantbe_parr_expr expr = not $ maybe_parr_ty $ exprType expr
-         
-    maybe_parr_ty ty = maybe_parr_ty' [] ty
-      
-    maybe_parr_ty' _           ty | Nothing <- splitTyConApp_maybe ty = False   -- TODO: is this really what we want to do with polym. types?
-    maybe_parr_ty' alreadySeen ty
-       | isPArrTyCon tycon     = True
-       | isPrimTyCon tycon     = False
-       | isAbstractTyCon tycon = True
-       | isFunTyCon tycon || isProductTyCon tycon || isTupleTyCon tycon  = any (maybe_parr_ty' alreadySeen) args     
-       | isDataTyCon tycon = any (maybe_parr_ty' alreadySeen) args || 
-                             hasParrDataCon alreadySeen tycon
-       | otherwise = True
-       where
-         Just (tycon, args) = splitTyConApp_maybe ty 
-         
-         
-         hasParrDataCon alreadySeen tycon
-           | tycon `elem` alreadySeen = False  
-           | otherwise                =  
-               any (maybe_parr_ty' $ tycon : alreadySeen) $ concat $  map dataConOrigArgTys $ tyConDataCons tycon 
-         
-    -- checks to make sure expression can't contain a non-scalar subexpression. Might err on the side of caution whenever
-    -- an external (non data constructor) variable is used, or anonymous data constructor      
-    is_scalar vs e@(Var v) 
-      | Just _ <- isDataConId_maybe v = cantbe_parr_expr e
-      | otherwise                     = cantbe_parr_expr e &&  (v `elemVarSet` vs)
-    is_scalar _ e@(Lit _)    = cantbe_parr_expr e  
-
-    is_scalar vs e@(App e1 e2) = cantbe_parr_expr e &&
-                               is_scalar vs e1 && is_scalar vs e2    
-    is_scalar vs e@(Let (NonRec b letExpr) body) 
-                             = cantbe_parr_expr e &&
-                               is_scalar vs letExpr && is_scalar (extendVarSet vs b) body
-    is_scalar vs  e@(Let (Rec bnds) body) 
-                             =  let vs' = extendVarSetList vs (map fst bnds)
-                                in cantbe_parr_expr e &&  
-                                   all (is_scalar vs') (map snd bnds) && is_scalar vs' body
-    is_scalar vs e@(Case eC eId ty alts)  
-                             = let vs' = extendVarSet vs eId
-                                  in cantbe_parr_expr e && 
-                                  is_prim_ty ty &&
-                                  is_scalar vs' eC   &&
-                                  (all (is_scalar_alt vs') alts)
-                                    
-    is_scalar _ _            =  False
-
-    is_scalar_alt vs (_, bs, e) 
-                             = is_scalar (extendVarSetList vs bs) e
 
+    -- Checks whether an expression contain a non-scalar subexpression. 
+    --
+    -- Precodition: The variables in the first argument are scalar.
+    --
+    -- In case of a recursive binding group, we /assume/ that all bindings are scalar (by adding
+    -- them to the list of scalar variables) and then check them.  If one of them turns out not to
+    -- be scalar, the entire group is regarded as not being scalar.
+    --
+    -- FIXME: Currently, doesn't regard external (non-data constructor) variable and anonymous
+    --        data constructor as scalar.  Should be changed once scalar types are passed
+    --        through VectInfo.
+    --
+    is_scalar :: VarSet -> CoreExpr -> Bool
+    is_scalar scalars  (Var v)         = v `elemVarSet` scalars
+    is_scalar _scalars (Lit _)         = True
+    is_scalar scalars  e@(App e1 e2) 
+      | maybe_parr_ty  (exprType e)    = False
+      | otherwise                      = is_scalar scalars e1 && is_scalar scalars e2
+    is_scalar scalars  (Lam var body)  
+      | maybe_parr_ty  (varType var)   = False
+      | otherwise                      = is_scalar (scalars `extendVarSet` var) body
+    is_scalar scalars  (Let bind body) = bindsAreScalar && is_scalar scalars' body
+      where
+        (bindsAreScalar, scalars') = is_scalar_bind scalars bind
+    is_scalar scalars  (Case e var ty alts)
+      | is_prim_ty ty                  = is_scalar scalars' e && all (is_scalar_alt scalars') alts
+      | otherwise                      = False
+      where
+        scalars' = scalars `extendVarSet` var
+    is_scalar scalars  (Cast e _coe)   = is_scalar scalars e
+    is_scalar scalars  (Note _ e   )   = is_scalar scalars e
+    is_scalar _scalars (Type {})       = True
+    is_scalar _scalars (Coercion {})   = True
+
+    -- Result: (<is this binding group scalar>, scalars ++ variables bound in this group)
+    is_scalar_bind scalars (NonRec var e) = (is_scalar scalars e, scalars `extendVarSet` var)
+    is_scalar_bind scalars (Rec bnds)     = (all (is_scalar scalars') es, scalars')
+      where
+        (vars, es) = unzip bnds
+        scalars'   = scalars `extendVarSetList` vars
+
+    is_scalar_alt scalars (_, vars, e) = is_scalar (scalars `extendVarSetList ` vars) e
+
+    -- Checks whether the type might be a parallel array type.  In particular, if the outermost
+    -- constructor is a type family, we conservatively assume that it may be a parallel array type.
+    maybe_parr_ty :: Type -> Bool
+    maybe_parr_ty ty 
+      | Just ty'        <- coreView ty            = maybe_parr_ty ty'
+      | Just (tyCon, _) <- splitTyConApp_maybe ty = isPArrTyCon tyCon || isSynFamilyTyCon tyCon 
+    maybe_parr_ty _                               = False
+
+    -- FIXME: I'm not convinced that this reasoning is (always) sound.  If the identify functions
+    --        is called by some other function that is otherwise scalar, it would be very bad
+    --        that just this call to the identity makes it not be scalar.
     -- A scalar function has to actually compute something. Without the check,
     -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to
     -- (map (\x -> x)) which is very bad. Normal lifting transforms it to
     -- (\n# x -> x) which is what we want.
-    uses funs (Var v)     = v `elemVarSet` funs 
-    uses funs (App e1 e2) = uses funs e1 || uses funs e2
+    uses funs (Var v)       = v `elemVarSet` funs 
+    uses funs (App e1 e2)   = uses funs e1 || uses funs e2
+    uses funs (Lam b body)  = uses (funs `extendVarSet` b) body
     uses funs (Let (NonRec _b letExpr) body) 
-                          = uses funs letExpr || uses funs  body
+                            = uses funs letExpr || uses funs  body
     uses funs (Case e _eId _ty alts) 
-                          = uses funs e || any (uses_alt funs) alts
-    uses _ _              = False
+                            = uses funs e || any (uses_alt funs) alts
+    uses _ _                = False
 
-    uses_alt funs (_, _bs, e)   
-                          = uses funs e 
+    uses_alt funs (_, _bs, e) = uses funs e 
+
+mkScalarFun :: [Type] -> Type -> CoreExpr -> VM VExpr
+mkScalarFun arg_tys res_ty expr
+  = do { fn_var  <- hoistExpr (fsLit "fn") expr DontInline
+       ; zipf    <- zipScalars arg_tys res_ty
+       ; clo     <- scalarClosure arg_tys res_ty (Var fn_var) (zipf `App` Var fn_var)
+       ; clo_var <- hoistExpr (fsLit "clo") clo DontInline
+       ; lclo    <- liftPD (Var clo_var)
+       ; return (Var clo_var, lclo)
+       }
 
 -- | Vectorise a lambda abstraction.
-vectLam 
-       :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
-       -> Bool                 -- ^ Whether the binding is a loop breaker.
-       -> VarSet               -- ^ The free variables in the body.
-       -> [Var]                -- ^ Binding variables.
-       -> CoreExprWithFVs      -- ^ Body of abstraction.
-       -> VM VExpr
-
-vectLam inline loop_breaker fvs bs body
- = do tyvars    <- localTyVars
+--
+vectLam :: Bool             -- ^ When the RHS of a binding, whether that binding should be inlined.
+        -> Bool             -- ^ Whether the binding is a loop breaker.
+        -> CoreExprWithFVs  -- ^ Body of abstraction.
+        -> VM VExpr
+vectLam inline loop_breaker expr@(fvs, AnnLam _ _)
+ = do let (bs, body) = collectAnnValBinders expr
+      tyvars    <- localTyVars
       (vs, vvs) <- readLEnv $ \env ->
                    unzip [(var, vv) | var <- varSetElems fvs
                                     , Just vv <- [lookupVarEnv (local_vars env) var]]
@@ -305,6 +321,7 @@ vectLam inline loop_breaker fvs bs body
                          (LitAlt (mkMachInt 0), [], empty)])
 
       | otherwise = return (ve, le)
+vectLam _ _ _ = panic "vectLam"
  
 
 vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr