Handling of recursive scalar functions in isScalarLam
authorkeller@cse.unsw.edu.au <unknown>
Mon, 14 Feb 2011 00:29:45 +0000 (00:29 +0000)
committerkeller@cse.unsw.edu.au <unknown>
Mon, 14 Feb 2011 00:29:45 +0000 (00:29 +0000)
compiler/vectorise/Vectorise.hs
compiler/vectorise/Vectorise/Exp.hs
compiler/vectorise/Vectorise/Monad.hs

index 8c9579e..999e8ef 100644 (file)
@@ -115,7 +115,7 @@ vectModule guts
 vectTopBind :: CoreBind -> VM CoreBind
 vectTopBind b@(NonRec var expr)
  = do
-      (inline, expr')  <- vectTopRhs var expr
+      (inline, _, expr')       <- vectTopRhs [] var expr
       var'             <- vectTopBinder var inline expr'
 
       -- Vectorising the body may create other top-level bindings.
@@ -131,15 +131,23 @@ vectTopBind b@(NonRec var expr)
 
 vectTopBind b@(Rec bs)
  = do
+      -- pprTrace "in Rec" (ppr vars) $ return ()
       (vars', _, exprs') 
        <- fixV $ \ ~(_, inlines, rhss) ->
             do vars' <- sequence [vectTopBinder var inline rhs
                                       | (var, ~(inline, rhs)) <- zipLazy vars (zip inlines rhss)]
-               (inlines', exprs') 
-                     <- mapAndUnzipM (uncurry vectTopRhs) bs
-
-               return (vars', inlines', exprs')
-
+               (inlines', areScalars', exprs') 
+                     <- mapAndUnzip3M (uncurry $ vectTopRhs vars) bs
+               if  (and areScalars') || (length bs <= 1)
+                  then do
+                    -- pprTrace "in Rec - all scalars??" (ppr areScalars') $ return ()
+                    return (vars', inlines', exprs')
+                  else do
+                    -- pprTrace "in Rec - not all scalars" (ppr areScalars') $ return ()
+                    mapM deleteGlobalScalar vars
+                    (inlines'', _, exprs'')  <- mapAndUnzip3M (uncurry $ vectTopRhs []) bs
+                    return (vars', inlines'', exprs'')
+                      
       hs     <- takeHoisted
       cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
       return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
@@ -147,7 +155,9 @@ vectTopBind b@(Rec bs)
     return b
   where
     (vars, exprs) = unzip bs
-
+    mapAndUnzip3M f xs = do
+       ys <- mapM f xs
+       return $ unzip3 ys
 
 -- | Make the vectorised version of this top level binder, and add the mapping
 --   between it and the original to the state. For some binder @foo@ the vectorised
@@ -182,21 +192,22 @@ vectTopBinder var inline expr
 
 -- | Vectorise the RHS of a top-level binding, in an empty local environment.
 vectTopRhs 
-       :: Var          -- ^ Name of the binding.
+       :: [Var]    -- ^ Names of all functions in the rec block
+       -> Var          -- ^ Name of the binding.
        -> CoreExpr     -- ^ Body of the binding.
-       -> VM (Inline, CoreExpr)
+       -> VM (Inline, Bool, CoreExpr)
 
-vectTopRhs var expr
+vectTopRhs recFs var expr
  = dtrace (vcat [text "vectTopRhs", ppr expr])
  $ closedV
  $ do (inline, isScalar, vexpr) <- inBind var
-                      $ pprTrace "vectTopRhs" (ppr var)
-                      $ vectPolyExpr (isLoopBreaker $ idOccInfo var)
+                      -- $ pprTrace "vectTopRhs" (ppr var)
+                      $ vectPolyExpr  (isLoopBreaker $ idOccInfo var) recFs
                                       (freeVars expr)
       if isScalar 
          then addGlobalScalar var
-         else return ()
-      return (inline, vectorised vexpr)
+         else deleteGlobalScalar var
+      return (inline, isScalar, vectorised vexpr)
 
 
 -- | Project out the vectorised version of a binding from some closure,
index b94224a..091a760 100644 (file)
@@ -35,20 +35,21 @@ import Data.List
 -- | Vectorise a polymorphic expression.
 vectPolyExpr 
        :: Bool                 -- ^ When vectorising the RHS of a binding, whether that
-                               --   binding is a loop breaker.
+                                   --   binding is a loop breaker.
+       -> [Var]                        
        -> CoreExprWithFVs
        -> VM (Inline, Bool, VExpr)
 
-vectPolyExpr loop_breaker (_, AnnNote note expr)
- = do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker expr
+vectPolyExpr loop_breaker recFns (_, AnnNote note expr)
+ = do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr
       return (inline, isScalarFn, vNote note expr')
 
-vectPolyExpr loop_breaker expr
+vectPolyExpr loop_breaker recFns expr
  = do
       arity <- polyArity tvs
       polyAbstract tvs $ \args ->
         do
-          (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker mono
+          (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono
           return (addInlineArity inline arity, isScalarFn, 
                   mapVect (mkLams $ tvs ++ args) mono')
   where
@@ -117,7 +118,7 @@ vectExpr (_, AnnCase scrut bndr ty alts)
 
 vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
   = do
-      vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False rhs
+      vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs
       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
       return $ vLet (vNonRec vbndr vrhs) vbody
 
@@ -134,10 +135,10 @@ vectExpr (_, AnnLet (AnnRec bs) body)
     vect_rhs bndr rhs = localV
                       . inBind bndr
                       . liftM (\(_,_,z)->z)
-                      $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) rhs
+                      $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) [] rhs
 
 vectExpr e@(_, AnnLam bndr _)
-  | isId bndr = liftM (\(_,_,z) ->z) $ vectFnExpr True False e
+  | isId bndr = liftM (\(_,_,z) ->z) $ vectFnExpr True False [] e
 {-
 onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
                 `orElseV` vectLam True fvs bs body
@@ -152,18 +153,19 @@ vectExpr e = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnno
 vectFnExpr 
        :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
        -> Bool                 -- ^ Whether the binding is a loop breaker.
+       -> [Var]
        -> CoreExprWithFVs      -- ^ Expression to vectorise. Must have an outer `AnnLam`.
        -> VM (Inline, Bool, VExpr)
 
-vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _)
-  | isId bndr = pprTrace "vectFnExpr -- id" (ppr fvs )$
+vectFnExpr inline loop_breaker recFns e@(fvs, AnnLam bndr _)
+  | isId bndr = -- pprTrace "vectFnExpr -- id" (ppr fvs )$
                  onlyIfV True -- (isEmptyVarSet fvs)  -- we check for free variables later. TODO: clean up
-                        (mark DontInline True . vectScalarLam bs $ deAnnotate body)
+                        (mark DontInline True . vectScalarLam bs recFns $ deAnnotate body)
                 `orElseV` mark inlineMe False (vectLam inline loop_breaker fvs bs body)
   where
     (bs,body) = collectAnnValBinders e
 
-vectFnExpr _ _ e = pprTrace "vectFnExpr -- otherwise" (ppr "a" )$ mark DontInline False $ vectExpr e
+vectFnExpr _ _ _  e = pprTrace "vectFnExpr -- otherwise" (ppr "a" )$ mark DontInline False $ vectExpr e
 
 mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
 mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) }
@@ -172,13 +174,18 @@ mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) }
 -- | Vectorise a function where are the args have scalar type,
 --   that is Int, Float, Double etc.
 vectScalarLam 
-       :: [Var]        -- ^ Bound variables of function.
+       :: [Var]        -- ^ Bound variables of function
+       -> [Var]
        -> CoreExpr     -- ^ Function body.
        -> VM VExpr
        
-vectScalarLam args body
- = do scalars <- globalScalars
-      pprTrace "vectScalarLam" (ppr $ is_scalar (extendVarSetList scalars args) body) $
+vectScalarLam args recFns body
+ = do scalars' <- globalScalars
+      let scalars = unionVarSet (mkVarSet recFns) scalars'
+{-      pprTrace "vectScalarLam uses" (ppr $ uses scalars body) $
+        pprTrace "vectScalarLam is prim res" (ppr $ is_prim_ty res_ty) $
+        pprTrace "vectScalarLam is scalar body" (ppr $ is_scalar (extendVarSetList scalars args) body) $
+        pprTrace "vectScalarLam arg tys" (ppr $ arg_tys) $ -}
         onlyIfV (all is_prim_ty arg_tys
                && is_prim_ty res_ty
                && is_scalar (extendVarSetList scalars args) body
@@ -190,7 +197,7 @@ vectScalarLam args body
                                                 (zipf `App` Var fn_var)
             clo_var <- hoistExpr (fsLit "clo") clo DontInline
             lclo    <- liftPD (Var clo_var)
-            pprTrace "  lam is scalar" (ppr "") $
+            {- pprTrace "  lam is scalar" (ppr "") $ -}
               return (Var clo_var, lclo)
   where
     arg_tys = map idType args
@@ -214,7 +221,7 @@ vectScalarLam args body
        | isPrimTyCon tycon     = False
        | isAbstractTyCon tycon = True
        | isFunTyCon tycon || isProductTyCon tycon || isTupleTyCon tycon  = any (maybe_parr_ty' alreadySeen) args     
-       | isDataTyCon tycon = pprTrace "isDataTyCon" (ppr tycon) $ 
+       | isDataTyCon tycon = -- pprTrace "isDataTyCon" (ppr tycon) $ 
                              any (maybe_parr_ty' alreadySeen) args || 
                              hasParrDataCon alreadySeen tycon
        | otherwise = True
index 77b9b7f..2597430 100644 (file)
@@ -17,7 +17,8 @@ module Vectorise.Monad (
        maybeCantVectoriseVarM,
        dumpVar,
        addGlobalScalar, 
-
+    deleteGlobalScalar,
+    
        -- * Primitives
        lookupPrimPArray,
        lookupPrimMethod
@@ -146,6 +147,11 @@ addGlobalScalar :: Var -> VM ()
 addGlobalScalar var 
   = updGEnv $ \env -> pprTrace "addGLobalScalar" (ppr var) env{global_scalars = extendVarSet (global_scalars env) var}
      
+deleteGlobalScalar :: Var -> VM ()
+deleteGlobalScalar var 
+  = updGEnv $ \env -> pprTrace "deleteGLobalScalar" (ppr var) env{global_scalars = delVarSet (global_scalars env) var}
+     
+     
 -- Primitives -----------------------------------------------------------------
 lookupPrimPArray :: TyCon -> VM (Maybe TyCon)
 lookupPrimPArray = liftBuiltinDs . primPArray