This BIG PATCH contains most of the work for the New Coercion Representation
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Exp.hs
index 079e826..4676e18 100644 (file)
@@ -1,15 +1,23 @@
 
 -- | Vectorisation of expressions.
-module Vectorise.Exp
-       (vectPolyExpr)
-where
-import Vectorise.Utils
+module Vectorise.Exp (
+
+  -- Vectorise a polymorphic expression
+  vectPolyExpr, 
+  
+  -- Vectorise a scalar expression of functional type
+  vectScalarFun
+) where
+
+#include "HsVersions.h"
+
 import Vectorise.Type.Type
 import Vectorise.Var
 import Vectorise.Vect
 import Vectorise.Env
 import Vectorise.Monad
 import Vectorise.Builtins
+import Vectorise.Utils
 
 import CoreSyn
 import CoreUtils
@@ -33,17 +41,15 @@ import Data.List
 
 
 -- | Vectorise a polymorphic expression.
-vectPolyExpr 
-       :: Bool                 -- ^ When vectorising the RHS of a binding, whether that
-                                   --   binding is a loop breaker.
-       -> [Var]                        
-       -> CoreExprWithFVs
-       -> VM (Inline, Bool, VExpr)
-
+--
+vectPolyExpr :: Bool           -- ^ When vectorising the RHS of a binding, whether that
+                                             --   binding is a loop breaker.
+                  -> [Var]                     
+                  -> CoreExprWithFVs
+                  -> VM (Inline, Bool, VExpr)
 vectPolyExpr loop_breaker recFns (_, AnnNote note expr)
  = do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr
       return (inline, isScalarFn, vNote note expr')
-
 vectPolyExpr loop_breaker recFns expr
  = do
       arity <- polyArity tvs
@@ -148,152 +154,144 @@ onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
 
 vectExpr e = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnnotate e)
 
-
 -- | Vectorise an expression with an outer lambda abstraction.
-vectFnExpr 
-       :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
-       -> Bool                 -- ^ Whether the binding is a loop breaker.
-       -> [Var]
-       -> CoreExprWithFVs      -- ^ Expression to vectorise. Must have an outer `AnnLam`.
-       -> VM (Inline, Bool, VExpr)
-
-vectFnExpr inline loop_breaker recFns e@(fvs, AnnLam bndr _)
-  | isId bndr = -- pprTrace "vectFnExpr -- id" (ppr fvs )$
-                 onlyIfV True -- (isEmptyVarSet fvs)  -- we check for free variables later. TODO: clean up
-                        (mark DontInline True . vectScalarLam bs recFns $ deAnnotate body)
-                `orElseV` mark inlineMe False (vectLam inline loop_breaker fvs bs body)
-  where
-    (bs,body) = collectAnnValBinders e
-
-vectFnExpr _ _ _  e = pprTrace "vectFnExpr -- otherwise" (ppr "a" )$ mark DontInline False $ vectExpr e
+--
+vectFnExpr :: Bool             -- ^ If we process the RHS of a binding, whether that binding should
+                               --   be inlined
+           -> Bool             -- ^ Whether the binding is a loop breaker
+           -> [Var]            -- ^ Names of function in same recursive binding group
+           -> CoreExprWithFVs  -- ^ Expression to vectorise; must have an outer `AnnLam`
+           -> VM (Inline, Bool, VExpr)
+vectFnExpr inline loop_breaker recFns expr@(_fvs, AnnLam bndr _)
+  | isId bndr = mark DontInline True (vectScalarFun False recFns (deAnnotate expr))
+                `orElseV` 
+                mark inlineMe False (vectLam inline loop_breaker expr)
+vectFnExpr _ _ _  e = mark DontInline False $ vectExpr e
 
 mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
 mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) }
 
-
--- | Vectorise a function where are the args have scalar type,
---   that is Int, Float, Double etc.
-vectScalarLam 
-       :: [Var]        -- ^ Bound variables of function
-       -> [Var]
-       -> CoreExpr     -- ^ Function body.
-       -> VM VExpr
-       
-vectScalarLam args recFns body
- = do scalars' <- globalScalars
-      let scalars = unionVarSet (mkVarSet recFns) scalars'
-{-      pprTrace "vectScalarLam uses" (ppr $ uses scalars body) $
-        pprTrace "vectScalarLam is prim res" (ppr $ is_prim_ty res_ty) $
-        pprTrace "vectScalarLam is scalar body" (ppr $ is_scalar (extendVarSetList scalars args) body) $
-        pprTrace "vectScalarLam arg tys" (ppr $ arg_tys) $ -}
-      onlyIfV (all is_prim_ty arg_tys
-               && is_prim_ty res_ty
-               && is_scalar (extendVarSetList scalars args) body
-               && uses scalars body)
-        $ do
-            fn_var  <- hoistExpr (fsLit "fn") (mkLams args body) DontInline
-            zipf    <- zipScalars arg_tys res_ty
-            clo     <- scalarClosure arg_tys res_ty (Var fn_var)
-                                                (zipf `App` Var fn_var)
-            clo_var <- hoistExpr (fsLit "clo") clo DontInline
-            lclo    <- liftPD (Var clo_var)
-            {- pprTrace "  lam is scalar" (ppr "") $ -}
-            return (Var clo_var, lclo)
+-- |Vectorise an expression of functional type, where all arguments and the result are of scalar
+-- type (i.e., 'Int', 'Float', 'Double' etc.) and which does not contain any subcomputations that
+-- involve parallel arrays.  Such functionals do not requires the full blown vectorisation
+-- transformation; instead, they can be lifted by application of a member of the zipWith family
+-- (i.e., 'map', 'zipWith', zipWith3', etc.)
+--
+vectScalarFun :: Bool       -- ^ Was the function marked as scalar by the user?
+              -> [Var]      -- ^ Functions names in same recursive binding group
+              -> CoreExpr   -- ^ Expression to be vectorised
+              -> VM VExpr
+vectScalarFun forceScalar recFns expr
+ = do { gscalars <- globalScalars
+      ; let scalars = gscalars `extendVarSetList` recFns
+            (arg_tys, res_ty) = splitFunTys (exprType expr)
+      ; MASSERT( not $ null arg_tys )
+      ; onlyIfV (forceScalar                    -- user asserts the functions is scalar
+                 ||
+                 all is_prim_ty arg_tys         -- check whether the function is scalar
+                  && is_prim_ty res_ty
+                  && is_scalar scalars expr
+                  && uses scalars expr)
+        $ mkScalarFun arg_tys res_ty expr
+      }
   where
-    arg_tys = map idType args
-    res_ty  = exprType body
-
+    -- FIXME: This is woefully insufficient!!!  We need a scalar pragma for types!!!
     is_prim_ty ty 
         | Just (tycon, [])   <- splitTyConApp_maybe ty
         =    tycon == intTyCon
           || tycon == floatTyCon
           || tycon == doubleTyCon
-
         | otherwise = False
-    
-    cantbe_parr_expr expr = not $ maybe_parr_ty $ exprType expr
-         
-    maybe_parr_ty ty = maybe_parr_ty' [] ty
-      
-    maybe_parr_ty' _           ty | Nothing <- splitTyConApp_maybe ty = False   -- TODO: is this really what we want to do with polym. types?
-    maybe_parr_ty' alreadySeen ty
-       | isPArrTyCon tycon     = True
-       | isPrimTyCon tycon     = False
-       | isAbstractTyCon tycon = True
-       | isFunTyCon tycon || isProductTyCon tycon || isTupleTyCon tycon  = any (maybe_parr_ty' alreadySeen) args     
-       | isDataTyCon tycon = -- pprTrace "isDataTyCon" (ppr tycon) $ 
-                             any (maybe_parr_ty' alreadySeen) args || 
-                             hasParrDataCon alreadySeen tycon
-       | otherwise = True
-       where
-         Just (tycon, args) = splitTyConApp_maybe ty 
-         
-         
-         hasParrDataCon alreadySeen tycon
-           | tycon `elem` alreadySeen = False  
-           | otherwise                =  
-               any (maybe_parr_ty' $ tycon : alreadySeen) $ concat $  map dataConOrigArgTys $ tyConDataCons tycon 
-         
-    -- checks to make sure expression can't contain a non-scalar subexpression. Might err on the side of caution whenever
-    -- an external (non data constructor) variable is used, or anonymous data constructor      
-    is_scalar vs e@(Var v) 
-      | Just _ <- isDataConId_maybe v = cantbe_parr_expr e
-      | otherwise                     = cantbe_parr_expr e &&  (v `elemVarSet` vs)
-    is_scalar _ e@(Lit _)    = -- pprTrace "is_scalar  Lit" (ppr e) $ 
-                               cantbe_parr_expr e  
-
-    is_scalar vs e@(App e1 e2) = -- pprTrace "is_scalar  App" (ppr e) $  
-                               cantbe_parr_expr e &&
-                               is_scalar vs e1 && is_scalar vs e2    
-    is_scalar vs e@(Let (NonRec b letExpr) body) 
-                             = -- pprTrace "is_scalar  Let" (ppr e) $  
-                               cantbe_parr_expr e &&
-                               is_scalar vs letExpr && is_scalar (extendVarSet vs b) body
-    is_scalar vs e@(Let (Rec bnds) body) 
-                             =  let vs' = extendVarSetList vs (map fst bnds)
-                                in -- pprTrace "is_scalar  Rec" (ppr e) $  
-                                   cantbe_parr_expr e &&  
-                                   all (is_scalar vs') (map snd bnds) && is_scalar vs' body
-    is_scalar vs e@(Case eC eId ty alts)  
-                             = let vs' = extendVarSet vs eId
-                                  in -- pprTrace "is_scalar  Case" (ppr e) $ 
-                                     cantbe_parr_expr e && 
-                                  is_prim_ty ty &&
-                                  is_scalar vs' eC   &&
-                                  (all (is_scalar_alt vs') alts)
-                                    
-    is_scalar _ e            =  -- pprTrace "is_scalar  other" (ppr e) $  
-                                False
-
-    is_scalar_alt vs (_, bs, e) 
-                             = is_scalar (extendVarSetList vs bs) e
 
+    -- Checks whether an expression contain a non-scalar subexpression. 
+    --
+    -- Precodition: The variables in the first argument are scalar.
+    --
+    -- In case of a recursive binding group, we /assume/ that all bindings are scalar (by adding
+    -- them to the list of scalar variables) and then check them.  If one of them turns out not to
+    -- be scalar, the entire group is regarded as not being scalar.
+    --
+    -- FIXME: Currently, doesn't regard external (non-data constructor) variable and anonymous
+    --        data constructor as scalar.  Should be changed once scalar types are passed
+    --        through VectInfo.
+    --
+    is_scalar :: VarSet -> CoreExpr -> Bool
+    is_scalar scalars  (Var v)         = v `elemVarSet` scalars
+    is_scalar _scalars (Lit _)         = True
+    is_scalar scalars  e@(App e1 e2) 
+      | maybe_parr_ty  (exprType e)    = False
+      | otherwise                      = is_scalar scalars e1 && is_scalar scalars e2
+    is_scalar scalars  (Lam var body)  
+      | maybe_parr_ty  (varType var)   = False
+      | otherwise                      = is_scalar (scalars `extendVarSet` var) body
+    is_scalar scalars  (Let bind body) = bindsAreScalar && is_scalar scalars' body
+      where
+        (bindsAreScalar, scalars') = is_scalar_bind scalars bind
+    is_scalar scalars  (Case e var ty alts)
+      | is_prim_ty ty                  = is_scalar scalars' e && all (is_scalar_alt scalars') alts
+      | otherwise                      = False
+      where
+        scalars' = scalars `extendVarSet` var
+    is_scalar scalars  (Cast e _coe)   = is_scalar scalars e
+    is_scalar scalars  (Note _ e   )   = is_scalar scalars e
+    is_scalar _scalars (Type {})       = True
+    is_scalar _scalars (Coercion {})   = True
+
+    -- Result: (<is this binding group scalar>, scalars ++ variables bound in this group)
+    is_scalar_bind scalars (NonRec var e) = (is_scalar scalars e, scalars `extendVarSet` var)
+    is_scalar_bind scalars (Rec bnds)     = (all (is_scalar scalars') es, scalars')
+      where
+        (vars, es) = unzip bnds
+        scalars'   = scalars `extendVarSetList` vars
+
+    is_scalar_alt scalars (_, vars, e) = is_scalar (scalars `extendVarSetList ` vars) e
+
+    -- Checks whether the type might be a parallel array type.  In particular, if the outermost
+    -- constructor is a type family, we conservatively assume that it may be a parallel array type.
+    maybe_parr_ty :: Type -> Bool
+    maybe_parr_ty ty 
+      | Just ty'        <- coreView ty            = maybe_parr_ty ty'
+      | Just (tyCon, _) <- splitTyConApp_maybe ty = isPArrTyCon tyCon || isSynFamilyTyCon tyCon 
+    maybe_parr_ty _                               = False
+
+    -- FIXME: I'm not convinced that this reasoning is (always) sound.  If the identify functions
+    --        is called by some other function that is otherwise scalar, it would be very bad
+    --        that just this call to the identity makes it not be scalar.
     -- A scalar function has to actually compute something. Without the check,
     -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to
     -- (map (\x -> x)) which is very bad. Normal lifting transforms it to
     -- (\n# x -> x) which is what we want.
-    uses funs (Var v)     = v `elemVarSet` funs 
-    uses funs (App e1 e2) = uses funs e1 || uses funs e2
+    uses funs (Var v)       = v `elemVarSet` funs 
+    uses funs (App e1 e2)   = uses funs e1 || uses funs e2
+    uses funs (Lam b body)  = uses (funs `extendVarSet` b) body
     uses funs (Let (NonRec _b letExpr) body) 
-                          = uses funs letExpr || uses funs  body
+                            = uses funs letExpr || uses funs  body
     uses funs (Case e _eId _ty alts) 
-                          = uses funs e || any (uses_alt funs) alts
-    uses _ _              = False
+                            = uses funs e || any (uses_alt funs) alts
+    uses _ _                = False
+
+    uses_alt funs (_, _bs, e) = uses funs e 
 
-    uses_alt funs (_, _bs, e)   
-                          = uses funs e 
+mkScalarFun :: [Type] -> Type -> CoreExpr -> VM VExpr
+mkScalarFun arg_tys res_ty expr
+  = do { fn_var  <- hoistExpr (fsLit "fn") expr DontInline
+       ; zipf    <- zipScalars arg_tys res_ty
+       ; clo     <- scalarClosure arg_tys res_ty (Var fn_var) (zipf `App` Var fn_var)
+       ; clo_var <- hoistExpr (fsLit "clo") clo DontInline
+       ; lclo    <- liftPD (Var clo_var)
+       ; return (Var clo_var, lclo)
+       }
 
 -- | Vectorise a lambda abstraction.
-vectLam 
-       :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
-       -> Bool                 -- ^ Whether the binding is a loop breaker.
-       -> VarSet               -- ^ The free variables in the body.
-       -> [Var]                -- ^ Binding variables.
-       -> CoreExprWithFVs      -- ^ Body of abstraction.
-       -> VM VExpr
-
-vectLam inline loop_breaker fvs bs body
- = do tyvars    <- localTyVars
+--
+vectLam :: Bool             -- ^ When the RHS of a binding, whether that binding should be inlined.
+        -> Bool             -- ^ Whether the binding is a loop breaker.
+        -> CoreExprWithFVs  -- ^ Body of abstraction.
+        -> VM VExpr
+vectLam inline loop_breaker expr@(fvs, AnnLam _ _)
+ = do let (bs, body) = collectAnnValBinders expr
+      tyvars    <- localTyVars
       (vs, vvs) <- readLEnv $ \env ->
                    unzip [(var, vv) | var <- varSetElems fvs
                                     , Just vv <- [lookupVarEnv (local_vars env) var]]
@@ -323,6 +321,7 @@ vectLam inline loop_breaker fvs bs body
                          (LitAlt (mkMachInt 0), [], empty)])
 
       | otherwise = return (ve, le)
+vectLam _ _ _ = panic "vectLam"
  
 
 vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr