vectScalarLam handles int, float, and double now
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Exp.hs
index 42efe37..4e07086 100644 (file)
@@ -197,13 +197,32 @@ vectScalarLam args body
         =    tycon == intTyCon
           || tycon == floatTyCon
           || tycon == doubleTyCon
+          || tycon == boolTyCon
 
         | otherwise = False
 
     is_scalar vs (Var v)     = v `elemVarSet` vs
     is_scalar _ e@(Lit _)    = is_scalar_ty $ exprType e
-    is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2
-    is_scalar _ _            = False
+    
+    is_scalar _ (App (Var v) (Lit lit)) 
+       | Just con <- isDataConId_maybe v = con `elem` [intDataCon, floatDataCon, doubleDataCon]
+
+    is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2    
+    is_scalar vs (Let (NonRec b letExpr) body) 
+                             = is_scalar vs letExpr && is_scalar (extendVarSet vs b) body
+    is_scalar vs (Let (Rec bnds) body) 
+                             =  let vs' = extendVarSetList vs (map fst bnds)
+                                in all (is_scalar vs') (map snd bnds) && is_scalar vs' body
+    is_scalar vs (Case e eId ty alts)  
+                             = let vs' = extendVarSet vs eId
+                                  in is_scalar_ty ty &&
+                                  is_scalar vs' e   &&
+                                  (all (is_scalar_alt vs') alts)
+                                    
+    is_scalar _ e            = False
+
+    is_scalar_alt vs (_, bs, e) 
+                             = is_scalar (extendVarSetList vs bs) e
 
     -- A scalar function has to actually compute something. Without the check,
     -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to
@@ -211,8 +230,14 @@ vectScalarLam args body
     -- (\n# x -> x) which is what we want.
     uses funs (Var v)     = v `elemVarSet` funs 
     uses funs (App e1 e2) = uses funs e1 || uses funs e2
+    uses funs (Let (NonRec b letExpr) body) 
+                          = uses funs letExpr || uses funs  body
+    uses funs (Case e eId ty alts) 
+                          = uses funs e || any (uses_alt funs) alts
     uses _ _              = False
 
+    uses_alt funs (_, bs, e)   
+                          = uses funs e 
 
 -- | Vectorise a lambda abstraction.
 vectLam