Break up vectoriser builtins module
[ghc-hetmet.git] / compiler / vectorise / Vectorise.hs
index ea69c4f..aad5144 100644 (file)
@@ -1,11 +1,14 @@
+{-# OPTIONS -fno-warn-missing-signatures #-}
 
 module Vectorise( vectorise )
 where
 
 import VectMonad
 import VectUtils
+import VectVar
 import VectType
-import VectCore
+import Vectorise.Vect
+import Vectorise.Env
 
 import HscTypes hiding      ( MonadThings(..) )
 
@@ -27,7 +30,7 @@ import Id
 import OccName
 import BasicTypes           ( isLoopBreaker )
 
-import Literal              ( Literal, mkMachInt )
+import Literal
 import TysWiredIn
 import TysPrim              ( intPrimTy )
 
@@ -37,56 +40,121 @@ import Util                 ( zipLazy )
 import Control.Monad
 import Data.List            ( sortBy, unzip4 )
 
+
+debug          = False
+dtrace s x     = if debug then pprTrace "Vectorise" s x else x
+
+-- | Vectorise a single module.
+--   Takes the package containing the DPH backend we're using. Eg either dph-par or dph-seq.
 vectorise :: PackageId -> ModGuts -> CoreM ModGuts
-vectorise backend guts = do
-    hsc_env <- getHscEnv
-    liftIO $ vectoriseIO backend hsc_env guts
+vectorise backend guts 
+ = do hsc_env <- getHscEnv
+      liftIO $ vectoriseIO backend hsc_env guts
+
 
+-- | Vectorise a single monad, given its HscEnv (code gen environment).
 vectoriseIO :: PackageId -> HscEnv -> ModGuts -> IO ModGuts
 vectoriseIO backend hsc_env guts
-  = do
+ = do -- Get information about currently loaded external packages.
       eps <- hscEPS hsc_env
+
+      -- Combine vectorisation info from the current module, and external ones.
       let info = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps
+
+      -- Run the main VM computation.
       Just (info', guts') <- initV backend hsc_env guts info (vectModule guts)
       return (guts' { mg_vect_info = info' })
 
+
+-- | Vectorise a single module, in the VM monad.
 vectModule :: ModGuts -> VM ModGuts
 vectModule guts
-  = do
+ = do -- Vectorise the type environment.
+      -- This may add new TyCons and DataCons.
+      -- TODO: What new binds do we get back here?
       (types', fam_insts, tc_binds) <- vectTypeEnv (mg_types guts)
 
+      -- TODO: What is this?
       let fam_inst_env' = extendFamInstEnvList (mg_fam_inst_env guts) fam_insts
       updGEnv (setFamInstEnv fam_inst_env')
 
       -- dicts   <- mapM buildPADict pa_insts
       -- workers <- mapM vectDataConWorkers pa_insts
+
+      -- Vectorise all the top level bindings.
       binds'  <- mapM vectTopBind (mg_binds guts)
+
       return $ guts { mg_types        = types'
                     , mg_binds        = Rec tc_binds : binds'
                     , mg_fam_inst_env = fam_inst_env'
                     , mg_fam_insts    = mg_fam_insts guts ++ fam_insts
                     }
 
+
+-- | Try to vectorise a top-level binding.
+--   If it doesn't vectorise then return it unharmed.
+--
+--   For example, for the binding 
+--
+--   @  
+--      foo :: Int -> Int
+--      foo = \x -> x + x
+--   @
+--  
+--   we get
+--   @
+--      foo  :: Int -> Int
+--      foo  = \x -> vfoo $: x                  
+-- 
+--      v_foo :: Closure void vfoo lfoo
+--      v_foo = closure vfoo lfoo void        
+-- 
+--      vfoo :: Void -> Int -> Int
+--      vfoo = ...
+--
+--      lfoo :: PData Void -> PData Int -> PData Int
+--      lfoo = ...
+--   @ 
+--
+--   @vfoo@ is the "vectorised", or scalar, version that does the same as the original
+--   function foo, but takes an explicit environment.
+-- 
+--   @lfoo@ is the "lifted" version that works on arrays.
+--
+--   @v_foo@ combines both of these into a `Closure` that also contains the
+--   environment.
+--
+--   The original binding @foo@ is rewritten to call the vectorised version
+--   present in the closure.
+--
 vectTopBind :: CoreBind -> VM CoreBind
 vectTopBind b@(NonRec var expr)
-  = do
-      (inline, expr') <- vectTopRhs var expr
-      var' <- vectTopBinder var inline expr'
-      hs    <- takeHoisted
-      cexpr <- tryConvert var var' expr
+ = do
+      (inline, expr')  <- vectTopRhs var expr
+      var'             <- vectTopBinder var inline expr'
+
+      -- Vectorising the body may create other top-level bindings.
+      hs       <- takeHoisted
+
+      -- To get the same functionality as the original body we project
+      -- out its vectorised version from the closure.
+      cexpr    <- tryConvert var var' expr
+
       return . Rec $ (var, cexpr) : (var', expr') : hs
   `orElseV`
     return b
 
 vectTopBind b@(Rec bs)
-  = do
-      (vars', _, exprs') <- fixV $ \ ~(_, inlines, rhss) ->
-        do
-          vars' <- sequence [vectTopBinder var inline rhs
-                               | (var, ~(inline, rhs))
-                                 <- zipLazy vars (zip inlines rhss)]
-          (inlines', exprs') <- mapAndUnzipM (uncurry vectTopRhs) bs
-          return (vars', inlines', exprs')
+ = do
+      (vars', _, exprs') 
+       <- fixV $ \ ~(_, inlines, rhss) ->
+            do vars' <- sequence [vectTopBinder var inline rhs
+                                      | (var, ~(inline, rhs)) <- zipLazy vars (zip inlines rhss)]
+               (inlines', exprs') 
+                     <- mapAndUnzipM (uncurry vectTopRhs) bs
+
+               return (vars', inlines', exprs')
+
       hs     <- takeHoisted
       cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
       return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
@@ -95,121 +163,83 @@ vectTopBind b@(Rec bs)
   where
     (vars, exprs) = unzip bs
 
--- NOTE: vectTopBinder *MUST* be lazy in inline and expr because of how it is
--- used inside of fixV in vectTopBind
-vectTopBinder :: Var -> Inline -> CoreExpr -> VM Var
+
+-- | Make the vectorised version of this top level binder, and add the mapping
+--   between it and the original to the state. For some binder @foo@ the vectorised
+--   version is @$v_foo@
+--
+--   NOTE: vectTopBinder *MUST* be lazy in inline and expr because of how it is
+--   used inside of fixV in vectTopBind
+vectTopBinder 
+       :: Var          -- ^ Name of the binding.
+       -> Inline       -- ^ Whether it should be inlined, used to annotate it.
+       -> CoreExpr     -- ^ RHS of the binding, used to set the `Unfolding` of the returned `Var`.
+       -> VM Var       -- ^ Name of the vectorised binding.
+
 vectTopBinder var inline expr
-  = do
+ = do
+      -- Vectorise the type attached to the var.
       vty  <- vectType (idType var)
-      var' <- liftM (`setIdUnfolding` unfolding) $ cloneId mkVectOcc var vty
+
+      -- Make the vectorised version of binding's name, and set the unfolding used for inlining.
+      var' <- liftM (`setIdUnfolding` unfolding) 
+           $  cloneId mkVectOcc var vty
+
+      -- Add the mapping between the plain and vectorised name to the state.
       defGlobalVar var var'
+
       return var'
   where
     unfolding = case inline of
                   Inline arity -> mkInlineRule expr (Just arity)
                   DontInline   -> noUnfolding
 
-vectTopRhs :: Var -> CoreExpr -> VM (Inline, CoreExpr)
+
+-- | Vectorise the RHS of a top-level binding, in an empty local environment.
+vectTopRhs 
+       :: Var          -- ^ Name of the binding.
+       -> CoreExpr     -- ^ Body of the binding.
+       -> VM (Inline, CoreExpr)
+
 vectTopRhs var expr
-  = closedV
-  $ do
-      (inline, vexpr) <- inBind var
-                       $ vectPolyExpr (isLoopBreaker $ idOccInfo var)
+ = dtrace (vcat [text "vectTopRhs", ppr expr])
+ $ closedV
+ $ do (inline, vexpr) <- inBind var
+                      $ vectPolyExpr (isLoopBreaker $ idOccInfo var)
                                       (freeVars expr)
       return (inline, vectorised vexpr)
 
-tryConvert :: Var -> Var -> CoreExpr -> VM CoreExpr
-tryConvert var vect_var rhs
-  = fromVect (idType var) (Var vect_var) `orElseV` return rhs
 
--- ----------------------------------------------------------------------------
--- Bindings
+-- | Project out the vectorised version of a binding from some closure,
+--     or return the original body if that doesn't work.       
+tryConvert 
+       :: Var          -- ^ Name of the original binding (eg @foo@)
+       -> Var          -- ^ Name of vectorised version of binding (eg @$vfoo@)
+       -> CoreExpr     -- ^ The original body of the binding.
+       -> VM CoreExpr
 
-vectBndr :: Var -> VM VVar
-vectBndr v
-  = do
-      (vty, lty) <- vectAndLiftType (idType v)
-      let vv = v `Id.setIdType` vty
-          lv = v `Id.setIdType` lty
-      updLEnv (mapTo vv lv)
-      return (vv, lv)
-  where
-    mapTo vv lv env = env { local_vars = extendVarEnv (local_vars env) v (vv, lv) }
+tryConvert var vect_var rhs
+  = fromVect (idType var) (Var vect_var) `orElseV` return rhs
 
-vectBndrNew :: Var -> FastString -> VM VVar
-vectBndrNew v fs
-  = do
-      vty <- vectType (idType v)
-      vv  <- newLocalVVar fs vty
-      updLEnv (upd vv)
-      return vv
-  where
-    upd vv env = env { local_vars = extendVarEnv (local_vars env) v vv }
-
-vectBndrIn :: Var -> VM a -> VM (VVar, a)
-vectBndrIn v p
-  = localV
-  $ do
-      vv <- vectBndr v
-      x <- p
-      return (vv, x)
-
-vectBndrNewIn :: Var -> FastString -> VM a -> VM (VVar, a)
-vectBndrNewIn v fs p
-  = localV
-  $ do
-      vv <- vectBndrNew v fs
-      x  <- p
-      return (vv, x)
-
-vectBndrsIn :: [Var] -> VM a -> VM ([VVar], a)
-vectBndrsIn vs p
-  = localV
-  $ do
-      vvs <- mapM vectBndr vs
-      x <- p
-      return (vvs, x)
 
 -- ----------------------------------------------------------------------------
 -- Expressions
 
-vectVar :: Var -> VM VExpr
-vectVar v
-  = do
-      r <- lookupVar v
-      case r of
-        Local (vv,lv) -> return (Var vv, Var lv)
-        Global vv     -> do
-                           let vexpr = Var vv
-                           lexpr <- liftPD vexpr
-                           return (vexpr, lexpr)
-
-vectPolyVar :: Var -> [Type] -> VM VExpr
-vectPolyVar v tys
-  = do
-      vtys <- mapM vectType tys
-      r <- lookupVar v
-      case r of
-        Local (vv, lv) -> liftM2 (,) (polyApply (Var vv) vtys)
-                                     (polyApply (Var lv) vtys)
-        Global poly    -> do
-                            vexpr <- polyApply (Var poly) vtys
-                            lexpr <- liftPD vexpr
-                            return (vexpr, lexpr)
-
-vectLiteral :: Literal -> VM VExpr
-vectLiteral lit
-  = do
-      lexpr <- liftPD (Lit lit)
-      return (Lit lit, lexpr)
 
-vectPolyExpr :: Bool -> CoreExprWithFVs -> VM (Inline, VExpr)
+-- | Vectorise a polymorphic expression
+vectPolyExpr 
+       :: Bool                 -- ^ When vectorising the RHS of a binding, whether that
+                               --   binding is a loop breaker.
+       -> CoreExprWithFVs
+       -> VM (Inline, VExpr)
+
 vectPolyExpr loop_breaker (_, AnnNote note expr)
-  = do
-      (inline, expr') <- vectPolyExpr loop_breaker expr
+ = do (inline, expr') <- vectPolyExpr loop_breaker expr
       return (inline, vNote note expr')
+
 vectPolyExpr loop_breaker expr
-  = do
+ = dtrace (vcat [text "vectPolyExpr", ppr (deAnnotate expr)])
+ $ do
       arity <- polyArity tvs
       polyAbstract tvs $ \args ->
         do
@@ -219,13 +249,17 @@ vectPolyExpr loop_breaker expr
   where
     (tvs, mono) = collectAnnTypeBinders expr
 
+
+-- | Vectorise a core expression.
 vectExpr :: CoreExprWithFVs -> VM VExpr
 vectExpr (_, AnnType ty)
   = liftM vType (vectType ty)
 
-vectExpr (_, AnnVar v) = vectVar v
+vectExpr (_, AnnVar v) 
+  = vectVar v
 
-vectExpr (_, AnnLit lit) = vectLiteral lit
+vectExpr (_, AnnLit lit) 
+  = vectLiteral lit
 
 vectExpr (_, AnnNote note expr)
   = liftM (vNote note) (vectExpr expr)
@@ -247,12 +281,27 @@ vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
     is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon]
 
 
+-- TODO: Avoid using closure application for dictionaries.
+-- vectExpr (_, AnnApp fn arg)
+--  | if is application of dictionary 
+--    just use regular app instead of closure app.
+
+-- for lifted version. 
+--      do liftPD (sub a dNumber)
+--      lift the result of the selection, not sub and dNumber seprately. 
+
 vectExpr (_, AnnApp fn arg)
-  = do
+ = dtrace (text "AnnApp" <+> ppr (deAnnotate fn) <+> ppr (deAnnotate arg))
+ $ do
       arg_ty' <- vectType arg_ty
       res_ty' <- vectType res_ty
+
+      dtrace (text "vectorising fn " <> ppr (deAnnotate fn))  $ return ()
       fn'     <- vectExpr fn
+      dtrace (text "fn' = "       <> ppr fn') $ return ()
+
       arg'    <- vectExpr arg
+
       mkClosureApp arg_ty' res_ty' fn' arg'
   where
     (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn
@@ -296,44 +345,58 @@ onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
 
 vectExpr e = cantVectorise "Can't vectorise expression" (ppr $ deAnnotate e)
 
-vectFnExpr :: Bool -> Bool -> CoreExprWithFVs -> VM (Inline, VExpr)
+
+-- | Vectorise an expression with an outer lambda abstraction.
+vectFnExpr 
+       :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
+       -> Bool                 -- ^ Whether the binding is a loop breaker.
+       -> CoreExprWithFVs      -- ^ Expression to vectorise. Must have an outer `AnnLam`.
+       -> VM (Inline, VExpr)
+
 vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _)
   | isId bndr = onlyIfV (isEmptyVarSet fvs)
                         (mark DontInline . vectScalarLam bs $ deAnnotate body)
                 `orElseV` mark inlineMe (vectLam inline loop_breaker fvs bs body)
   where
     (bs,body) = collectAnnValBinders e
+
 vectFnExpr _ _ e = mark DontInline $ vectExpr e
 
 mark :: Inline -> VM a -> VM (Inline, a)
 mark b p = do { x <- p; return (b,x) }
 
-vectScalarLam :: [Var] -> CoreExpr -> VM VExpr
+
+-- | Vectorise a function where are the args have scalar type, that is Int, Float or Double.
+vectScalarLam 
+       :: [Var]        -- ^ Bound variables of function.
+       -> CoreExpr     -- ^ Function body.
+       -> VM VExpr
 vectScalarLam args body
-  = do
-      scalars <- globalScalars
+ = dtrace (vcat [text "vectScalarLam ", ppr args, ppr body])
+ $ do scalars <- globalScalars
       onlyIfV (all is_scalar_ty arg_tys
                && is_scalar_ty res_ty
                && is_scalar (extendVarSetList scalars args) body
                && uses scalars body)
         $ do
-            fn_var <- hoistExpr (fsLit "fn") (mkLams args body) DontInline
-            zipf <- zipScalars arg_tys res_ty
-            clo <- scalarClosure arg_tys res_ty (Var fn_var)
+            fn_var  <- hoistExpr (fsLit "fn") (mkLams args body) DontInline
+            zipf    <- zipScalars arg_tys res_ty
+            clo     <- scalarClosure arg_tys res_ty (Var fn_var)
                                                 (zipf `App` Var fn_var)
             clo_var <- hoistExpr (fsLit "clo") clo DontInline
-            lclo <- liftPD (Var clo_var)
+            lclo    <- liftPD (Var clo_var)
             return (Var clo_var, lclo)
   where
     arg_tys = map idType args
     res_ty  = exprType body
 
-    is_scalar_ty ty | Just (tycon, []) <- splitTyConApp_maybe ty
-                    = tycon == intTyCon
-                      || tycon == floatTyCon
-                      || tycon == doubleTyCon
+    is_scalar_ty ty 
+        | Just (tycon, [])   <- splitTyConApp_maybe ty
+        =    tycon == intTyCon
+          || tycon == floatTyCon
+          || tycon == doubleTyCon
 
-                    | otherwise = False
+        | otherwise = False
 
     is_scalar vs (Var v)     = v `elemVarSet` vs
     is_scalar _ e@(Lit _)    = is_scalar_ty $ exprType e
@@ -348,23 +411,42 @@ vectScalarLam args body
     uses funs (App e1 e2) = uses funs e1 || uses funs e2
     uses _ _              = False
 
-vectLam :: Bool -> Bool -> VarSet -> [Var] -> CoreExprWithFVs -> VM VExpr
+
+vectLam 
+       :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
+       -> Bool                 -- ^ Whether the binding is a loop breaker.
+       -> VarSet               -- ^ The free variables in the body.
+       -> [Var]                -- 
+       -> CoreExprWithFVs
+       -> VM VExpr
+
 vectLam inline loop_breaker fvs bs body
-  = do
-      tyvars <- localTyVars
+ = dtrace (vcat [ text "vectLam "
+               , text "free vars    = " <> ppr fvs
+               , text "binding vars = " <> ppr bs
+               , text "body         = " <> ppr (deAnnotate body)])
+
+ $ do tyvars    <- localTyVars
       (vs, vvs) <- readLEnv $ \env ->
                    unzip [(var, vv) | var <- varSetElems fvs
                                     , Just vv <- [lookupVarEnv (local_vars env) var]]
 
-      arg_tys <- mapM (vectType . idType) bs
-      res_ty  <- vectType (exprType $ deAnnotate body)
+      arg_tys   <- mapM (vectType . idType) bs
+
+      dtrace (text "arg_tys = " <> ppr arg_tys) $ return ()
+
+      res_ty    <- vectType (exprType $ deAnnotate body)
+
+      dtrace (text "res_ty = " <> ppr res_ty) $ return ()
 
       buildClosures tyvars vvs arg_tys res_ty
         . hoistPolyVExpr tyvars (maybe_inline (length vs + length bs))
         $ do
-            lc <- builtin liftingContext
-            (vbndrs, vbody) <- vectBndrsIn (vs ++ bs)
-                                           (vectExpr body)
+            lc              <- builtin liftingContext
+            (vbndrs, vbody) <- vectBndrsIn (vs ++ bs) (vectExpr body)
+
+            dtrace (text "vbody = " <> ppr vbody) $ return ()
+
             vbody' <- break_loop lc res_ty vbody
             return $ vLams lc vbndrs vbody'
   where