This BIG PATCH contains most of the work for the New Coercion Representation
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Exp.hs
index d35c947..4676e18 100644 (file)
@@ -1,15 +1,23 @@
 
 -- | Vectorisation of expressions.
-module Vectorise.Exp
-       (vectPolyExpr)
-where
-import VectUtils
-import VectVar
-import VectType
+module Vectorise.Exp (
+
+  -- Vectorise a polymorphic expression
+  vectPolyExpr, 
+  
+  -- Vectorise a scalar expression of functional type
+  vectScalarFun
+) where
+
+#include "HsVersions.h"
+
+import Vectorise.Type.Type
+import Vectorise.Var
 import Vectorise.Vect
 import Vectorise.Env
 import Vectorise.Monad
 import Vectorise.Builtins
+import Vectorise.Utils
 
 import CoreSyn
 import CoreUtils
@@ -22,7 +30,7 @@ import Var
 import VarEnv
 import VarSet
 import Id
-import BasicTypes
+import BasicTypes( isLoopBreaker )
 import Literal
 import TysWiredIn
 import TysPrim
@@ -33,23 +41,22 @@ import Data.List
 
 
 -- | Vectorise a polymorphic expression.
-vectPolyExpr 
-       :: Bool                 -- ^ When vectorising the RHS of a binding, whether that
-                               --   binding is a loop breaker.
-       -> CoreExprWithFVs
-       -> VM (Inline, VExpr)
-
-vectPolyExpr loop_breaker (_, AnnNote note expr)
- = do (inline, expr') <- vectPolyExpr loop_breaker expr
-      return (inline, vNote note expr')
-
-vectPolyExpr loop_breaker expr
+--
+vectPolyExpr :: Bool           -- ^ When vectorising the RHS of a binding, whether that
+                                             --   binding is a loop breaker.
+                  -> [Var]                     
+                  -> CoreExprWithFVs
+                  -> VM (Inline, Bool, VExpr)
+vectPolyExpr loop_breaker recFns (_, AnnNote note expr)
+ = do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr
+      return (inline, isScalarFn, vNote note expr')
+vectPolyExpr loop_breaker recFns expr
  = do
       arity <- polyArity tvs
       polyAbstract tvs $ \args ->
         do
-          (inline, mono') <- vectFnExpr False loop_breaker mono
-          return (addInlineArity inline arity,
+          (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono
+          return (addInlineArity inline arity, isScalarFn, 
                   mapVect (mkLams $ tvs ++ args) mono')
   where
     (tvs, mono) = collectAnnTypeBinders expr
@@ -111,12 +118,13 @@ vectExpr (_, AnnCase scrut bndr ty alts)
   | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
   , isAlgTyCon tycon
   = vectAlgCase tycon ty_args scrut bndr ty alts
+  | otherwise = cantVectorise "Can't vectorise expression" (ppr scrut_ty) 
   where
     scrut_ty = exprType (deAnnotate scrut)
 
 vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
   = do
-      vrhs <- localV . inBind bndr . liftM snd $ vectPolyExpr False rhs
+      vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs
       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
       return $ vLet (vNonRec vbndr vrhs) vbody
 
@@ -132,11 +140,11 @@ vectExpr (_, AnnLet (AnnRec bs) body)
 
     vect_rhs bndr rhs = localV
                       . inBind bndr
-                      . liftM snd
-                      $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) rhs
+                      . liftM (\(_,_,z)->z)
+                      $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) [] rhs
 
 vectExpr e@(_, AnnLam bndr _)
-  | isId bndr = liftM snd $ vectFnExpr True False e
+  | isId bndr = liftM (\(_,_,z) ->z) $ vectFnExpr True False [] e
 {-
 onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
                 `orElseV` vectLam True fvs bs body
@@ -144,87 +152,146 @@ onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
     (bs,body) = collectAnnValBinders e
 -}
 
-vectExpr e = cantVectorise "Can't vectorise expression" (ppr $ deAnnotate e)
-
+vectExpr e = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnnotate e)
 
 -- | Vectorise an expression with an outer lambda abstraction.
-vectFnExpr 
-       :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
-       -> Bool                 -- ^ Whether the binding is a loop breaker.
-       -> CoreExprWithFVs      -- ^ Expression to vectorise. Must have an outer `AnnLam`.
-       -> VM (Inline, VExpr)
-
-vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _)
-  | isId bndr = onlyIfV (isEmptyVarSet fvs)
-                        (mark DontInline . vectScalarLam bs $ deAnnotate body)
-                `orElseV` mark inlineMe (vectLam inline loop_breaker fvs bs body)
-  where
-    (bs,body) = collectAnnValBinders e
-
-vectFnExpr _ _ e = mark DontInline $ vectExpr e
-
-mark :: Inline -> VM a -> VM (Inline, a)
-mark b p = do { x <- p; return (b,x) }
-
-
--- | Vectorise a function where are the args have scalar type,
---   that is Int, Float, Double etc.
-vectScalarLam 
-       :: [Var]        -- ^ Bound variables of function.
-       -> CoreExpr     -- ^ Function body.
-       -> VM VExpr
-       
-vectScalarLam args body
- = do scalars <- globalScalars
-      onlyIfV (all is_scalar_ty arg_tys
-               && is_scalar_ty res_ty
-               && is_scalar (extendVarSetList scalars args) body
-               && uses scalars body)
-        $ do
-            fn_var  <- hoistExpr (fsLit "fn") (mkLams args body) DontInline
-            zipf    <- zipScalars arg_tys res_ty
-            clo     <- scalarClosure arg_tys res_ty (Var fn_var)
-                                                (zipf `App` Var fn_var)
-            clo_var <- hoistExpr (fsLit "clo") clo DontInline
-            lclo    <- liftPD (Var clo_var)
-            return (Var clo_var, lclo)
+--
+vectFnExpr :: Bool             -- ^ If we process the RHS of a binding, whether that binding should
+                               --   be inlined
+           -> Bool             -- ^ Whether the binding is a loop breaker
+           -> [Var]            -- ^ Names of function in same recursive binding group
+           -> CoreExprWithFVs  -- ^ Expression to vectorise; must have an outer `AnnLam`
+           -> VM (Inline, Bool, VExpr)
+vectFnExpr inline loop_breaker recFns expr@(_fvs, AnnLam bndr _)
+  | isId bndr = mark DontInline True (vectScalarFun False recFns (deAnnotate expr))
+                `orElseV` 
+                mark inlineMe False (vectLam inline loop_breaker expr)
+vectFnExpr _ _ _  e = mark DontInline False $ vectExpr e
+
+mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
+mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) }
+
+-- |Vectorise an expression of functional type, where all arguments and the result are of scalar
+-- type (i.e., 'Int', 'Float', 'Double' etc.) and which does not contain any subcomputations that
+-- involve parallel arrays.  Such functionals do not requires the full blown vectorisation
+-- transformation; instead, they can be lifted by application of a member of the zipWith family
+-- (i.e., 'map', 'zipWith', zipWith3', etc.)
+--
+vectScalarFun :: Bool       -- ^ Was the function marked as scalar by the user?
+              -> [Var]      -- ^ Functions names in same recursive binding group
+              -> CoreExpr   -- ^ Expression to be vectorised
+              -> VM VExpr
+vectScalarFun forceScalar recFns expr
+ = do { gscalars <- globalScalars
+      ; let scalars = gscalars `extendVarSetList` recFns
+            (arg_tys, res_ty) = splitFunTys (exprType expr)
+      ; MASSERT( not $ null arg_tys )
+      ; onlyIfV (forceScalar                    -- user asserts the functions is scalar
+                 ||
+                 all is_prim_ty arg_tys         -- check whether the function is scalar
+                  && is_prim_ty res_ty
+                  && is_scalar scalars expr
+                  && uses scalars expr)
+        $ mkScalarFun arg_tys res_ty expr
+      }
   where
-    arg_tys = map idType args
-    res_ty  = exprType body
-
-    is_scalar_ty ty 
+    -- FIXME: This is woefully insufficient!!!  We need a scalar pragma for types!!!
+    is_prim_ty ty 
         | Just (tycon, [])   <- splitTyConApp_maybe ty
         =    tycon == intTyCon
           || tycon == floatTyCon
           || tycon == doubleTyCon
-
         | otherwise = False
 
-    is_scalar vs (Var v)     = v `elemVarSet` vs
-    is_scalar _ e@(Lit _)    = is_scalar_ty $ exprType e
-    is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2
-    is_scalar _ _            = False
-
+    -- Checks whether an expression contain a non-scalar subexpression. 
+    --
+    -- Precodition: The variables in the first argument are scalar.
+    --
+    -- In case of a recursive binding group, we /assume/ that all bindings are scalar (by adding
+    -- them to the list of scalar variables) and then check them.  If one of them turns out not to
+    -- be scalar, the entire group is regarded as not being scalar.
+    --
+    -- FIXME: Currently, doesn't regard external (non-data constructor) variable and anonymous
+    --        data constructor as scalar.  Should be changed once scalar types are passed
+    --        through VectInfo.
+    --
+    is_scalar :: VarSet -> CoreExpr -> Bool
+    is_scalar scalars  (Var v)         = v `elemVarSet` scalars
+    is_scalar _scalars (Lit _)         = True
+    is_scalar scalars  e@(App e1 e2) 
+      | maybe_parr_ty  (exprType e)    = False
+      | otherwise                      = is_scalar scalars e1 && is_scalar scalars e2
+    is_scalar scalars  (Lam var body)  
+      | maybe_parr_ty  (varType var)   = False
+      | otherwise                      = is_scalar (scalars `extendVarSet` var) body
+    is_scalar scalars  (Let bind body) = bindsAreScalar && is_scalar scalars' body
+      where
+        (bindsAreScalar, scalars') = is_scalar_bind scalars bind
+    is_scalar scalars  (Case e var ty alts)
+      | is_prim_ty ty                  = is_scalar scalars' e && all (is_scalar_alt scalars') alts
+      | otherwise                      = False
+      where
+        scalars' = scalars `extendVarSet` var
+    is_scalar scalars  (Cast e _coe)   = is_scalar scalars e
+    is_scalar scalars  (Note _ e   )   = is_scalar scalars e
+    is_scalar _scalars (Type {})       = True
+    is_scalar _scalars (Coercion {})   = True
+
+    -- Result: (<is this binding group scalar>, scalars ++ variables bound in this group)
+    is_scalar_bind scalars (NonRec var e) = (is_scalar scalars e, scalars `extendVarSet` var)
+    is_scalar_bind scalars (Rec bnds)     = (all (is_scalar scalars') es, scalars')
+      where
+        (vars, es) = unzip bnds
+        scalars'   = scalars `extendVarSetList` vars
+
+    is_scalar_alt scalars (_, vars, e) = is_scalar (scalars `extendVarSetList ` vars) e
+
+    -- Checks whether the type might be a parallel array type.  In particular, if the outermost
+    -- constructor is a type family, we conservatively assume that it may be a parallel array type.
+    maybe_parr_ty :: Type -> Bool
+    maybe_parr_ty ty 
+      | Just ty'        <- coreView ty            = maybe_parr_ty ty'
+      | Just (tyCon, _) <- splitTyConApp_maybe ty = isPArrTyCon tyCon || isSynFamilyTyCon tyCon 
+    maybe_parr_ty _                               = False
+
+    -- FIXME: I'm not convinced that this reasoning is (always) sound.  If the identify functions
+    --        is called by some other function that is otherwise scalar, it would be very bad
+    --        that just this call to the identity makes it not be scalar.
     -- A scalar function has to actually compute something. Without the check,
     -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to
     -- (map (\x -> x)) which is very bad. Normal lifting transforms it to
     -- (\n# x -> x) which is what we want.
-    uses funs (Var v)     = v `elemVarSet` funs 
-    uses funs (App e1 e2) = uses funs e1 || uses funs e2
-    uses _ _              = False
-
+    uses funs (Var v)       = v `elemVarSet` funs 
+    uses funs (App e1 e2)   = uses funs e1 || uses funs e2
+    uses funs (Lam b body)  = uses (funs `extendVarSet` b) body
+    uses funs (Let (NonRec _b letExpr) body) 
+                            = uses funs letExpr || uses funs  body
+    uses funs (Case e _eId _ty alts) 
+                            = uses funs e || any (uses_alt funs) alts
+    uses _ _                = False
+
+    uses_alt funs (_, _bs, e) = uses funs e 
+
+mkScalarFun :: [Type] -> Type -> CoreExpr -> VM VExpr
+mkScalarFun arg_tys res_ty expr
+  = do { fn_var  <- hoistExpr (fsLit "fn") expr DontInline
+       ; zipf    <- zipScalars arg_tys res_ty
+       ; clo     <- scalarClosure arg_tys res_ty (Var fn_var) (zipf `App` Var fn_var)
+       ; clo_var <- hoistExpr (fsLit "clo") clo DontInline
+       ; lclo    <- liftPD (Var clo_var)
+       ; return (Var clo_var, lclo)
+       }
 
 -- | Vectorise a lambda abstraction.
-vectLam 
-       :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
-       -> Bool                 -- ^ Whether the binding is a loop breaker.
-       -> VarSet               -- ^ The free variables in the body.
-       -> [Var]                -- ^ Binding variables.
-       -> CoreExprWithFVs      -- ^ Body of abstraction.
-       -> VM VExpr
-
-vectLam inline loop_breaker fvs bs body
- = do tyvars    <- localTyVars
+--
+vectLam :: Bool             -- ^ When the RHS of a binding, whether that binding should be inlined.
+        -> Bool             -- ^ Whether the binding is a loop breaker.
+        -> CoreExprWithFVs  -- ^ Body of abstraction.
+        -> VM VExpr
+vectLam inline loop_breaker expr@(fvs, AnnLam _ _)
+ = do let (bs, body) = collectAnnValBinders expr
+      tyvars    <- localTyVars
       (vs, vvs) <- readLEnv $ \env ->
                    unzip [(var, vv) | var <- varSetElems fvs
                                     , Just vv <- [lookupVarEnv (local_vars env) var]]
@@ -254,11 +321,12 @@ vectLam inline loop_breaker fvs bs body
                          (LitAlt (mkMachInt 0), [], empty)])
 
       | otherwise = return (ve, le)
+vectLam _ _ _ = panic "vectLam"
  
 
 vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
 vectTyAppExpr (_, AnnVar v) tys = vectPolyVar v tys
-vectTyAppExpr e tys = cantVectorise "Can't vectorise expression"
+vectTyAppExpr e tys = cantVectorise "Can't vectorise expression (vectTyExpr)"
                         (ppr $ deAnnotate e `mkTyApps` tys)