Cleaned up Expr and Vectorise
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Exp.hs
index 1c2ee4c..9cd34e3 100644 (file)
@@ -3,10 +3,8 @@
 module Vectorise.Exp
        (vectPolyExpr)
 where
-import VectUtils
-import VectType
-import Vectorise.Utils.Closure
-import Vectorise.Utils.Hoisting
+import Vectorise.Utils
+import Vectorise.Type.Type
 import Vectorise.Var
 import Vectorise.Vect
 import Vectorise.Env
@@ -24,7 +22,7 @@ import Var
 import VarEnv
 import VarSet
 import Id
-import BasicTypes
+import BasicTypes( isLoopBreaker )
 import Literal
 import TysWiredIn
 import TysPrim
@@ -37,21 +35,22 @@ import Data.List
 -- | Vectorise a polymorphic expression.
 vectPolyExpr 
        :: Bool                 -- ^ When vectorising the RHS of a binding, whether that
-                               --   binding is a loop breaker.
+                                   --   binding is a loop breaker.
+       -> [Var]                        
        -> CoreExprWithFVs
-       -> VM (Inline, VExpr)
+       -> VM (Inline, Bool, VExpr)
 
-vectPolyExpr loop_breaker (_, AnnNote note expr)
- = do (inline, expr') <- vectPolyExpr loop_breaker expr
-      return (inline, vNote note expr')
+vectPolyExpr loop_breaker recFns (_, AnnNote note expr)
+ = do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr
+      return (inline, isScalarFn, vNote note expr')
 
-vectPolyExpr loop_breaker expr
+vectPolyExpr loop_breaker recFns expr
  = do
       arity <- polyArity tvs
       polyAbstract tvs $ \args ->
         do
-          (inline, mono') <- vectFnExpr False loop_breaker mono
-          return (addInlineArity inline arity,
+          (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono
+          return (addInlineArity inline arity, isScalarFn, 
                   mapVect (mkLams $ tvs ++ args) mono')
   where
     (tvs, mono) = collectAnnTypeBinders expr
@@ -113,12 +112,13 @@ vectExpr (_, AnnCase scrut bndr ty alts)
   | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
   , isAlgTyCon tycon
   = vectAlgCase tycon ty_args scrut bndr ty alts
+  | otherwise = cantVectorise "Can't vectorise expression" (ppr scrut_ty) 
   where
     scrut_ty = exprType (deAnnotate scrut)
 
 vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
   = do
-      vrhs <- localV . inBind bndr . liftM snd $ vectPolyExpr False rhs
+      vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs
       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
       return $ vLet (vNonRec vbndr vrhs) vbody
 
@@ -134,11 +134,11 @@ vectExpr (_, AnnLet (AnnRec bs) body)
 
     vect_rhs bndr rhs = localV
                       . inBind bndr
-                      . liftM snd
-                      $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) rhs
+                      . liftM (\(_,_,z)->z)
+                      $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) [] rhs
 
 vectExpr e@(_, AnnLam bndr _)
-  | isId bndr = liftM snd $ vectFnExpr True False e
+  | isId bndr = liftM (\(_,_,z) ->z) $ vectFnExpr True False [] e
 {-
 onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
                 `orElseV` vectLam True fvs bs body
@@ -146,40 +146,43 @@ onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
     (bs,body) = collectAnnValBinders e
 -}
 
-vectExpr e = cantVectorise "Can't vectorise expression" (ppr $ deAnnotate e)
+vectExpr e = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnnotate e)
 
 
 -- | Vectorise an expression with an outer lambda abstraction.
 vectFnExpr 
        :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
        -> Bool                 -- ^ Whether the binding is a loop breaker.
+       -> [Var]
        -> CoreExprWithFVs      -- ^ Expression to vectorise. Must have an outer `AnnLam`.
-       -> VM (Inline, VExpr)
+       -> VM (Inline, Bool, VExpr)
 
-vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _)
-  | isId bndr = onlyIfV (isEmptyVarSet fvs)
-                        (mark DontInline . vectScalarLam bs $ deAnnotate body)
-                `orElseV` mark inlineMe (vectLam inline loop_breaker fvs bs body)
+vectFnExpr inline loop_breaker recFns e@(fvs, AnnLam bndr _)
+  | isId bndr = onlyIfV True -- (isEmptyVarSet fvs)  -- we check for free variables later. TODO: clean up
+                        (mark DontInline True . vectScalarLam bs recFns $ deAnnotate body)
+                `orElseV` mark inlineMe False (vectLam inline loop_breaker fvs bs body)
   where
     (bs,body) = collectAnnValBinders e
 
-vectFnExpr _ _ e = mark DontInline $ vectExpr e
+vectFnExpr _ _ _  e = mark DontInline False $ vectExpr e
 
-mark :: Inline -> VM a -> VM (Inline, a)
-mark b p = do { x <- p; return (b,x) }
+mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
+mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) }
 
 
 -- | Vectorise a function where are the args have scalar type,
 --   that is Int, Float, Double etc.
 vectScalarLam 
-       :: [Var]        -- ^ Bound variables of function.
+       :: [Var]        -- ^ Bound variables of function
+       -> [Var]
        -> CoreExpr     -- ^ Function body.
        -> VM VExpr
        
-vectScalarLam args body
- = do scalars <- globalScalars
-      onlyIfV (all is_scalar_ty arg_tys
-               && is_scalar_ty res_ty
+vectScalarLam args recFns body
+ = do scalars' <- globalScalars
+      let scalars = unionVarSet (mkVarSet recFns) scalars'
+      onlyIfV (all is_prim_ty arg_tys
+               && is_prim_ty res_ty
                && is_scalar (extendVarSetList scalars args) body
                && uses scalars body)
         $ do
@@ -194,18 +197,63 @@ vectScalarLam args body
     arg_tys = map idType args
     res_ty  = exprType body
 
-    is_scalar_ty ty 
+    is_prim_ty ty 
         | Just (tycon, [])   <- splitTyConApp_maybe ty
         =    tycon == intTyCon
           || tycon == floatTyCon
           || tycon == doubleTyCon
 
         | otherwise = False
-
-    is_scalar vs (Var v)     = v `elemVarSet` vs
-    is_scalar _ e@(Lit _)    = is_scalar_ty $ exprType e
-    is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2
-    is_scalar _ _            = False
+    
+    cantbe_parr_expr expr = not $ maybe_parr_ty $ exprType expr
+         
+    maybe_parr_ty ty = maybe_parr_ty' [] ty
+      
+    maybe_parr_ty' _           ty | Nothing <- splitTyConApp_maybe ty = False   -- TODO: is this really what we want to do with polym. types?
+    maybe_parr_ty' alreadySeen ty
+       | isPArrTyCon tycon     = True
+       | isPrimTyCon tycon     = False
+       | isAbstractTyCon tycon = True
+       | isFunTyCon tycon || isProductTyCon tycon || isTupleTyCon tycon  = any (maybe_parr_ty' alreadySeen) args     
+       | isDataTyCon tycon = any (maybe_parr_ty' alreadySeen) args || 
+                             hasParrDataCon alreadySeen tycon
+       | otherwise = True
+       where
+         Just (tycon, args) = splitTyConApp_maybe ty 
+         
+         
+         hasParrDataCon alreadySeen tycon
+           | tycon `elem` alreadySeen = False  
+           | otherwise                =  
+               any (maybe_parr_ty' $ tycon : alreadySeen) $ concat $  map dataConOrigArgTys $ tyConDataCons tycon 
+         
+    -- checks to make sure expression can't contain a non-scalar subexpression. Might err on the side of caution whenever
+    -- an external (non data constructor) variable is used, or anonymous data constructor      
+    is_scalar vs e@(Var v) 
+      | Just _ <- isDataConId_maybe v = cantbe_parr_expr e
+      | otherwise                     = cantbe_parr_expr e &&  (v `elemVarSet` vs)
+    is_scalar _ e@(Lit _)    = cantbe_parr_expr e  
+
+    is_scalar vs e@(App e1 e2) = cantbe_parr_expr e &&
+                               is_scalar vs e1 && is_scalar vs e2    
+    is_scalar vs e@(Let (NonRec b letExpr) body) 
+                             = cantbe_parr_expr e &&
+                               is_scalar vs letExpr && is_scalar (extendVarSet vs b) body
+    is_scalar vs  e@(Let (Rec bnds) body) 
+                             =  let vs' = extendVarSetList vs (map fst bnds)
+                                in cantbe_parr_expr e &&  
+                                   all (is_scalar vs') (map snd bnds) && is_scalar vs' body
+    is_scalar vs e@(Case eC eId ty alts)  
+                             = let vs' = extendVarSet vs eId
+                                  in cantbe_parr_expr e && 
+                                  is_prim_ty ty &&
+                                  is_scalar vs' eC   &&
+                                  (all (is_scalar_alt vs') alts)
+                                    
+    is_scalar _ _            =  False
+
+    is_scalar_alt vs (_, bs, e) 
+                             = is_scalar (extendVarSetList vs bs) e
 
     -- A scalar function has to actually compute something. Without the check,
     -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to
@@ -213,8 +261,14 @@ vectScalarLam args body
     -- (\n# x -> x) which is what we want.
     uses funs (Var v)     = v `elemVarSet` funs 
     uses funs (App e1 e2) = uses funs e1 || uses funs e2
+    uses funs (Let (NonRec _b letExpr) body) 
+                          = uses funs letExpr || uses funs  body
+    uses funs (Case e _eId _ty alts) 
+                          = uses funs e || any (uses_alt funs) alts
     uses _ _              = False
 
+    uses_alt funs (_, _bs, e)   
+                          = uses funs e 
 
 -- | Vectorise a lambda abstraction.
 vectLam 
@@ -260,7 +314,7 @@ vectLam inline loop_breaker fvs bs body
 
 vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
 vectTyAppExpr (_, AnnVar v) tys = vectPolyVar v tys
-vectTyAppExpr e tys = cantVectorise "Can't vectorise expression"
+vectTyAppExpr e tys = cantVectorise "Can't vectorise expression (vectTyExpr)"
                         (ppr $ deAnnotate e `mkTyApps` tys)