added handling of data constructors to vectLam
authorkeller@cse.unsw.edu.au <unknown>
Tue, 1 Feb 2011 04:28:07 +0000 (04:28 +0000)
committerkeller@cse.unsw.edu.au <unknown>
Tue, 1 Feb 2011 04:28:07 +0000 (04:28 +0000)
compiler/vectorise/Vectorise/Exp.hs

index 28ff4d8..c3793dc 100644 (file)
@@ -192,34 +192,65 @@ vectScalarLam args body
     arg_tys = map idType args
     res_ty  = exprType body
 
-    is_scalar_ty ty 
+    is_prim_ty ty 
         | Just (tycon, [])   <- splitTyConApp_maybe ty
         =    tycon == intTyCon
           || tycon == floatTyCon
           || tycon == doubleTyCon
-          || tycon == boolTyCon
 
         | otherwise = False
-
-    is_scalar vs (Var v)     = v `elemVarSet` vs
-    is_scalar _ e@(Lit _)    = is_scalar_ty $ exprType e
     
-    is_scalar _ (App (Var v) (Lit _)) 
-       | Just con <- isDataConId_maybe v = con `elem` [intDataCon, floatDataCon, doubleDataCon]
-
-    is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2    
-    is_scalar vs (Let (NonRec b letExpr) body) 
-                             = is_scalar vs letExpr && is_scalar (extendVarSet vs b) body
-    is_scalar vs (Let (Rec bnds) body) 
+    cantbe_parr_expr expr = not $ maybe_parr_ty $ exprType expr
+         
+    maybe_parr_ty ty = maybe_parr_ty' [] ty    
+    maybe_parr_ty' alreadySeen ty
+       | isPArrTyCon tycon     = True
+       | isPrimTyCon tycon     = False
+       | isAbstractTyCon tycon = True
+       | isFunTyCon tycon || isProductTyCon tycon || isTupleTyCon tycon  = any (maybe_parr_ty' alreadySeen) args     
+       | isDataTyCon tycon = pprTrace "isDataTyCon" (ppr tycon) $ 
+                             any (maybe_parr_ty' alreadySeen) args || 
+                             hasParrDataCon alreadySeen tycon
+       | otherwise = True
+       where
+         Just (tycon, args) = splitTyConApp_maybe ty 
+         
+         
+         hasParrDataCon alreadySeen tycon
+           | tycon `elem` alreadySeen = False  
+           | otherwise                =  
+               any (maybe_parr_ty' $ tycon : alreadySeen) $ concat $  map dataConOrigArgTys $ tyConDataCons tycon 
+         
+    -- checks to make sure expression can't contain a non-scalar subexpression. Might err on the side of caution whenever
+    -- an external (non data constructor) variable is used, or anonymous data constructor      
+    is_scalar vs e@(Var v) 
+      | Just _ <- isDataConId_maybe v = cantbe_parr_expr e
+      | otherwise                     = cantbe_parr_expr e &&  (v `elemVarSet` vs)
+    is_scalar _ e@(Lit _)    = -- pprTrace "is_scalar  Lit" (ppr e) $ 
+                               cantbe_parr_expr e  
+
+    is_scalar vs e@(App e1 e2) = -- pprTrace "is_scalar  App" (ppr e) $  
+                               cantbe_parr_expr e &&
+                               is_scalar vs e1 && is_scalar vs e2    
+    is_scalar vs e@(Let (NonRec b letExpr) body) 
+                             = -- pprTrace "is_scalar  Let" (ppr e) $  
+                               cantbe_parr_expr e &&
+                               is_scalar vs letExpr && is_scalar (extendVarSet vs b) body
+    is_scalar vs e@(Let (Rec bnds) body) 
                              =  let vs' = extendVarSetList vs (map fst bnds)
-                                in all (is_scalar vs') (map snd bnds) && is_scalar vs' body
-    is_scalar vs (Case e eId ty alts)  
+                                in -- pprTrace "is_scalar  Rec" (ppr e) $  
+                                   cantbe_parr_expr e &&  
+                                   all (is_scalar vs') (map snd bnds) && is_scalar vs' body
+    is_scalar vs e@(Case eC eId ty alts)  
                              = let vs' = extendVarSet vs eId
-                                  in is_scalar_ty ty &&
-                                  is_scalar vs' e   &&
+                                  in -- pprTrace "is_scalar  Case" (ppr e) $ 
+                                     cantbe_parr_expr e 
+                                  is_prim_ty ty &&
+                                  is_scalar vs' eC   &&
                                   (all (is_scalar_alt vs') alts)
                                     
-    is_scalar _ _            = False
+    is_scalar _ e            =  -- pprTrace "is_scalar  other" (ppr e) $  
+                                False
 
     is_scalar_alt vs (_, bs, e) 
                              = is_scalar (extendVarSetList vs bs) e