Handling of recursive scalar functions in isScalarLam
[ghc-hetmet.git] / compiler / vectorise / Vectorise.hs
1 {-# OPTIONS -fno-warn-missing-signatures #-}
2
3 module Vectorise( vectorise )
4 where
5
6 import Vectorise.Type.Env
7 import Vectorise.Type.Type
8 import Vectorise.Convert
9 import Vectorise.Utils.Hoisting
10 import Vectorise.Exp
11 import Vectorise.Vect
12 import Vectorise.Env
13 import Vectorise.Monad
14
15 import HscTypes hiding      ( MonadThings(..) )
16 import Module               ( PackageId )
17 import CoreSyn
18 import CoreUnfold           ( mkInlineUnfolding )
19 import CoreFVs
20 import CoreMonad            ( CoreM, getHscEnv )
21 import Var
22 import Id
23 import OccName
24 import BasicTypes           ( isLoopBreaker )
25 import Outputable
26 import Util                 ( zipLazy )
27 import MonadUtils
28
29 import Control.Monad
30
31 debug           = False
32 dtrace s x      = if debug then pprTrace "Vectorise" s x else x
33
34 -- | Vectorise a single module.
35 --   Takes the package containing the DPH backend we're using. Eg either dph-par or dph-seq.
36 vectorise :: PackageId -> ModGuts -> CoreM ModGuts
37 vectorise backend guts 
38  = do hsc_env <- getHscEnv
39       liftIO $ vectoriseIO backend hsc_env guts
40
41
42 -- | Vectorise a single monad, given its HscEnv (code gen environment).
43 vectoriseIO :: PackageId -> HscEnv -> ModGuts -> IO ModGuts
44 vectoriseIO backend hsc_env guts
45  = do -- Get information about currently loaded external packages.
46       eps <- hscEPS hsc_env
47
48       -- Combine vectorisation info from the current module, and external ones.
49       let info = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps
50
51       -- Run the main VM computation.
52       Just (info', guts') <- initV backend hsc_env guts info (vectModule guts)
53       return (guts' { mg_vect_info = info' })
54
55
56 -- | Vectorise a single module, in the VM monad.
57 vectModule :: ModGuts -> VM ModGuts
58 vectModule guts
59  = do -- Vectorise the type environment.
60       -- This may add new TyCons and DataCons.
61       -- TODO: What new binds do we get back here?
62       (types', fam_insts, tc_binds) <- vectTypeEnv (mg_types guts)
63
64       (_, fam_inst_env) <- readGEnv global_fam_inst_env
65
66       -- dicts   <- mapM buildPADict pa_insts
67       -- workers <- mapM vectDataConWorkers pa_insts
68
69       -- Vectorise all the top level bindings.
70       binds'  <- mapM vectTopBind (mg_binds guts)
71
72       return $ guts { mg_types        = types'
73                     , mg_binds        = Rec tc_binds : binds'
74                     , mg_fam_inst_env = fam_inst_env
75                     , mg_fam_insts    = mg_fam_insts guts ++ fam_insts
76                     }
77
78
79 -- | Try to vectorise a top-level binding.
80 --   If it doesn't vectorise then return it unharmed.
81 --
82 --   For example, for the binding 
83 --
84 --   @  
85 --      foo :: Int -> Int
86 --      foo = \x -> x + x
87 --   @
88 --  
89 --   we get
90 --   @
91 --      foo  :: Int -> Int
92 --      foo  = \x -> vfoo $: x                  
93 -- 
94 --      v_foo :: Closure void vfoo lfoo
95 --      v_foo = closure vfoo lfoo void        
96 -- 
97 --      vfoo :: Void -> Int -> Int
98 --      vfoo = ...
99 --
100 --      lfoo :: PData Void -> PData Int -> PData Int
101 --      lfoo = ...
102 --   @ 
103 --
104 --   @vfoo@ is the "vectorised", or scalar, version that does the same as the original
105 --   function foo, but takes an explicit environment.
106 -- 
107 --   @lfoo@ is the "lifted" version that works on arrays.
108 --
109 --   @v_foo@ combines both of these into a `Closure` that also contains the
110 --   environment.
111 --
112 --   The original binding @foo@ is rewritten to call the vectorised version
113 --   present in the closure.
114 --
115 vectTopBind :: CoreBind -> VM CoreBind
116 vectTopBind b@(NonRec var expr)
117  = do
118       (inline, _, expr')        <- vectTopRhs [] var expr
119       var'              <- vectTopBinder var inline expr'
120
121       -- Vectorising the body may create other top-level bindings.
122       hs        <- takeHoisted
123
124       -- To get the same functionality as the original body we project
125       -- out its vectorised version from the closure.
126       cexpr     <- tryConvert var var' expr
127
128       return . Rec $ (var, cexpr) : (var', expr') : hs
129   `orElseV`
130     return b
131
132 vectTopBind b@(Rec bs)
133  = do
134       -- pprTrace "in Rec" (ppr vars) $ return ()
135       (vars', _, exprs') 
136         <- fixV $ \ ~(_, inlines, rhss) ->
137             do vars' <- sequence [vectTopBinder var inline rhs
138                                       | (var, ~(inline, rhs)) <- zipLazy vars (zip inlines rhss)]
139                (inlines', areScalars', exprs') 
140                      <- mapAndUnzip3M (uncurry $ vectTopRhs vars) bs
141                if  (and areScalars') || (length bs <= 1)
142                   then do
143                     -- pprTrace "in Rec - all scalars??" (ppr areScalars') $ return ()
144                     return (vars', inlines', exprs')
145                   else do
146                     -- pprTrace "in Rec - not all scalars" (ppr areScalars') $ return ()
147                     mapM deleteGlobalScalar vars
148                     (inlines'', _, exprs'')  <- mapAndUnzip3M (uncurry $ vectTopRhs []) bs
149                     return (vars', inlines'', exprs'')
150                       
151       hs     <- takeHoisted
152       cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
153       return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
154   `orElseV`
155     return b
156   where
157     (vars, exprs) = unzip bs
158     mapAndUnzip3M f xs = do
159        ys <- mapM f xs
160        return $ unzip3 ys
161
162 -- | Make the vectorised version of this top level binder, and add the mapping
163 --   between it and the original to the state. For some binder @foo@ the vectorised
164 --   version is @$v_foo@
165 --
166 --   NOTE: vectTopBinder *MUST* be lazy in inline and expr because of how it is
167 --   used inside of fixV in vectTopBind
168 vectTopBinder 
169         :: Var          -- ^ Name of the binding.
170         -> Inline       -- ^ Whether it should be inlined, used to annotate it.
171         -> CoreExpr     -- ^ RHS of the binding, used to set the `Unfolding` of the returned `Var`.
172         -> VM Var       -- ^ Name of the vectorised binding.
173
174 vectTopBinder var inline expr
175  = do
176       -- Vectorise the type attached to the var.
177       vty  <- vectType (idType var)
178
179       -- Make the vectorised version of binding's name, and set the unfolding used for inlining.
180       var' <- liftM (`setIdUnfoldingLazily` unfolding) 
181            $  cloneId mkVectOcc var vty
182
183       -- Add the mapping between the plain and vectorised name to the state.
184       defGlobalVar var var'
185
186       return var'
187   where
188     unfolding = case inline of
189                   Inline arity -> mkInlineUnfolding (Just arity) expr
190                   DontInline   -> noUnfolding
191
192
193 -- | Vectorise the RHS of a top-level binding, in an empty local environment.
194 vectTopRhs 
195         :: [Var]    -- ^ Names of all functions in the rec block
196         -> Var          -- ^ Name of the binding.
197         -> CoreExpr     -- ^ Body of the binding.
198         -> VM (Inline, Bool, CoreExpr)
199
200 vectTopRhs recFs var expr
201  = dtrace (vcat [text "vectTopRhs", ppr expr])
202  $ closedV
203  $ do (inline, isScalar, vexpr) <- inBind var
204                       -- $ pprTrace "vectTopRhs" (ppr var)
205                       $ vectPolyExpr  (isLoopBreaker $ idOccInfo var) recFs
206                                       (freeVars expr)
207       if isScalar 
208          then addGlobalScalar var
209          else deleteGlobalScalar var
210       return (inline, isScalar, vectorised vexpr)
211
212
213 -- | Project out the vectorised version of a binding from some closure,
214 --      or return the original body if that doesn't work.       
215 tryConvert 
216         :: Var          -- ^ Name of the original binding (eg @foo@)
217         -> Var          -- ^ Name of vectorised version of binding (eg @$vfoo@)
218         -> CoreExpr     -- ^ The original body of the binding.
219         -> VM CoreExpr
220
221 tryConvert var vect_var rhs
222   = fromVect (idType var) (Var vect_var) `orElseV` return rhs
223