update for changes in hetmet Makefile
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Exp.hs
index b94224a..569057e 100644 (file)
@@ -33,22 +33,21 @@ import Data.List
 
 
 -- | Vectorise a polymorphic expression.
-vectPolyExpr 
-       :: Bool                 -- ^ When vectorising the RHS of a binding, whether that
-                               --   binding is a loop breaker.
-       -> CoreExprWithFVs
-       -> VM (Inline, Bool, VExpr)
-
-vectPolyExpr loop_breaker (_, AnnNote note expr)
- = do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker expr
+--
+vectPolyExpr :: Bool           -- ^ When vectorising the RHS of a binding, whether that
+                                             --   binding is a loop breaker.
+                  -> [Var]                     
+                  -> CoreExprWithFVs
+                  -> VM (Inline, Bool, VExpr)
+vectPolyExpr loop_breaker recFns (_, AnnNote note expr)
+ = do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr
       return (inline, isScalarFn, vNote note expr')
-
-vectPolyExpr loop_breaker expr
+vectPolyExpr loop_breaker recFns expr
  = do
       arity <- polyArity tvs
       polyAbstract tvs $ \args ->
         do
-          (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker mono
+          (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono
           return (addInlineArity inline arity, isScalarFn, 
                   mapVect (mkLams $ tvs ++ args) mono')
   where
@@ -117,7 +116,7 @@ vectExpr (_, AnnCase scrut bndr ty alts)
 
 vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
   = do
-      vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False rhs
+      vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs
       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
       return $ vLet (vNonRec vbndr vrhs) vbody
 
@@ -134,10 +133,10 @@ vectExpr (_, AnnLet (AnnRec bs) body)
     vect_rhs bndr rhs = localV
                       . inBind bndr
                       . liftM (\(_,_,z)->z)
-                      $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) rhs
+                      $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) [] rhs
 
 vectExpr e@(_, AnnLam bndr _)
-  | isId bndr = liftM (\(_,_,z) ->z) $ vectFnExpr True False e
+  | isId bndr = liftM (\(_,_,z) ->z) $ vectFnExpr True False [] e
 {-
 onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
                 `orElseV` vectLam True fvs bs body
@@ -147,23 +146,20 @@ onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
 
 vectExpr e = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnnotate e)
 
-
 -- | Vectorise an expression with an outer lambda abstraction.
-vectFnExpr 
-       :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
-       -> Bool                 -- ^ Whether the binding is a loop breaker.
-       -> CoreExprWithFVs      -- ^ Expression to vectorise. Must have an outer `AnnLam`.
-       -> VM (Inline, Bool, VExpr)
-
-vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _)
-  | isId bndr = pprTrace "vectFnExpr -- id" (ppr fvs )$
-                 onlyIfV True -- (isEmptyVarSet fvs)  -- we check for free variables later. TODO: clean up
-                        (mark DontInline True . vectScalarLam bs $ deAnnotate body)
+--
+vectFnExpr :: Bool             -- ^ When the RHS of a binding, whether that binding should be inlined.
+           -> Bool             -- ^ Whether the binding is a loop breaker.
+           -> [Var]
+           -> CoreExprWithFVs  -- ^ Expression to vectorise. Must have an outer `AnnLam`.
+           -> VM (Inline, Bool, VExpr)
+vectFnExpr inline loop_breaker recFns e@(fvs, AnnLam bndr _)
+  | isId bndr = onlyIfV True -- (isEmptyVarSet fvs)  -- we check for free variables later. TODO: clean up
+                        (mark DontInline True . vectScalarLam bs recFns $ deAnnotate body)
                 `orElseV` mark inlineMe False (vectLam inline loop_breaker fvs bs body)
   where
     (bs,body) = collectAnnValBinders e
-
-vectFnExpr _ _ e = pprTrace "vectFnExpr -- otherwise" (ppr "a" )$ mark DontInline False $ vectExpr e
+vectFnExpr _ _ _  e = mark DontInline False $ vectExpr e
 
 mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
 mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) }
@@ -172,14 +168,15 @@ mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) }
 -- | Vectorise a function where are the args have scalar type,
 --   that is Int, Float, Double etc.
 vectScalarLam 
-       :: [Var]        -- ^ Bound variables of function.
+       :: [Var]        -- ^ Bound variables of function
+       -> [Var]
        -> CoreExpr     -- ^ Function body.
        -> VM VExpr
        
-vectScalarLam args body
- = do scalars <- globalScalars
-      pprTrace "vectScalarLam" (ppr $ is_scalar (extendVarSetList scalars args) body) $
-        onlyIfV (all is_prim_ty arg_tys
+vectScalarLam args recFns body
+ = do scalars' <- globalScalars
+      let scalars = unionVarSet (mkVarSet recFns) scalars'
+      onlyIfV (all is_prim_ty arg_tys
                && is_prim_ty res_ty
                && is_scalar (extendVarSetList scalars args) body
                && uses scalars body)
@@ -190,8 +187,7 @@ vectScalarLam args body
                                                 (zipf `App` Var fn_var)
             clo_var <- hoistExpr (fsLit "clo") clo DontInline
             lclo    <- liftPD (Var clo_var)
-            pprTrace "  lam is scalar" (ppr "") $
-              return (Var clo_var, lclo)
+            return (Var clo_var, lclo)
   where
     arg_tys = map idType args
     res_ty  = exprType body
@@ -214,8 +210,7 @@ vectScalarLam args body
        | isPrimTyCon tycon     = False
        | isAbstractTyCon tycon = True
        | isFunTyCon tycon || isProductTyCon tycon || isTupleTyCon tycon  = any (maybe_parr_ty' alreadySeen) args     
-       | isDataTyCon tycon = pprTrace "isDataTyCon" (ppr tycon) $ 
-                             any (maybe_parr_ty' alreadySeen) args || 
+       | isDataTyCon tycon = any (maybe_parr_ty' alreadySeen) args || 
                              hasParrDataCon alreadySeen tycon
        | otherwise = True
        where
@@ -232,31 +227,25 @@ vectScalarLam args body
     is_scalar vs e@(Var v) 
       | Just _ <- isDataConId_maybe v = cantbe_parr_expr e
       | otherwise                     = cantbe_parr_expr e &&  (v `elemVarSet` vs)
-    is_scalar _ e@(Lit _)    = -- pprTrace "is_scalar  Lit" (ppr e) $ 
-                               cantbe_parr_expr e  
+    is_scalar _ e@(Lit _)    = cantbe_parr_expr e  
 
-    is_scalar vs e@(App e1 e2) = -- pprTrace "is_scalar  App" (ppr e) $  
-                               cantbe_parr_expr e &&
+    is_scalar vs e@(App e1 e2) = cantbe_parr_expr e &&
                                is_scalar vs e1 && is_scalar vs e2    
     is_scalar vs e@(Let (NonRec b letExpr) body) 
-                             = -- pprTrace "is_scalar  Let" (ppr e) $  
-                               cantbe_parr_expr e &&
+                             = cantbe_parr_expr e &&
                                is_scalar vs letExpr && is_scalar (extendVarSet vs b) body
-    is_scalar vs e@(Let (Rec bnds) body) 
+    is_scalar vs  e@(Let (Rec bnds) body) 
                              =  let vs' = extendVarSetList vs (map fst bnds)
-                                in -- pprTrace "is_scalar  Rec" (ppr e) $  
-                                   cantbe_parr_expr e &&  
+                                in cantbe_parr_expr e &&  
                                    all (is_scalar vs') (map snd bnds) && is_scalar vs' body
     is_scalar vs e@(Case eC eId ty alts)  
                              = let vs' = extendVarSet vs eId
-                                  in -- pprTrace "is_scalar  Case" (ppr e) $ 
-                                     cantbe_parr_expr e && 
+                                  in cantbe_parr_expr e && 
                                   is_prim_ty ty &&
                                   is_scalar vs' eC   &&
                                   (all (is_scalar_alt vs') alts)
                                     
-    is_scalar _ e            =  -- pprTrace "is_scalar  other" (ppr e) $  
-                                False
+    is_scalar _ _            =  False
 
     is_scalar_alt vs (_, bs, e) 
                              = is_scalar (extendVarSetList vs bs) e