Handling of lets, letrec and case when checking if a lambda expr needs to be vectorised
authorkeller@cse.unsw.edu.au <unknown>
Mon, 15 Nov 2010 05:12:25 +0000 (05:12 +0000)
committerkeller@cse.unsw.edu.au <unknown>
Mon, 15 Nov 2010 05:12:25 +0000 (05:12 +0000)
compiler/vectorise/Vectorise/Exp.hs

index 42efe37..d00b040 100644 (file)
@@ -197,13 +197,28 @@ vectScalarLam args body
         =    tycon == intTyCon
           || tycon == floatTyCon
           || tycon == doubleTyCon
+          || tycon == boolTyCon
 
         | otherwise = False
 
     is_scalar vs (Var v)     = v `elemVarSet` vs
     is_scalar _ e@(Lit _)    = is_scalar_ty $ exprType e
     is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2
-    is_scalar _ _            = False
+    is_scalar vs (Let (NonRec b letExpr) body) 
+                             = is_scalar vs letExpr && is_scalar (extendVarSet vs b) body
+    is_scalar vs (Let (Rec bnds) body) 
+                             =  let vs' = extendVarSetList vs (map fst bnds)
+                                in all (is_scalar vs') (map snd bnds) && is_scalar vs' body
+    is_scalar vs (Case e eId ty alts)  
+                             = let vs' = extendVarSet vs eId
+                                  in is_scalar_ty ty &&
+                                  is_scalar vs' e   &&
+                                  (all (is_scalar_alt vs') alts)
+
+    is_scalar _ e            = False
+
+    is_scalar_alt vs (_, bs, e) 
+                             = is_scalar (extendVarSetList vs bs) e
 
     -- A scalar function has to actually compute something. Without the check,
     -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to
@@ -211,8 +226,14 @@ vectScalarLam args body
     -- (\n# x -> x) which is what we want.
     uses funs (Var v)     = v `elemVarSet` funs 
     uses funs (App e1 e2) = uses funs e1 || uses funs e2
+    uses funs (Let (NonRec b letExpr) body) 
+                          = uses funs letExpr || uses funs  body
+    uses funs (Case e eId ty alts) 
+                          = uses funs e || any (uses_alt funs) alts
     uses _ _              = False
 
+    uses_alt funs (_, bs, e)   
+                          = uses funs e 
 
 -- | Vectorise a lambda abstraction.
 vectLam