add -fsimpleopt-before-flatten
[ghc-hetmet.git] / compiler / vectorise / Vectorise.hs
1 {-# OPTIONS -fno-warn-missing-signatures -fno-warn-unused-do-bind #-}
2
3 module Vectorise ( vectorise )
4 where
5
6 import Vectorise.Type.Env
7 import Vectorise.Type.Type
8 import Vectorise.Convert
9 import Vectorise.Utils.Hoisting
10 import Vectorise.Exp
11 import Vectorise.Vect
12 import Vectorise.Env
13 import Vectorise.Monad
14
15 import HscTypes hiding      ( MonadThings(..) )
16 import CoreUnfold           ( mkInlineUnfolding )
17 import CoreFVs
18 import PprCore
19 import CoreSyn
20 import CoreMonad            ( CoreM, getHscEnv )
21 import Type
22 import Var
23 import Id
24 import OccName
25 import DynFlags
26 import BasicTypes           ( isLoopBreaker )
27 import Outputable
28 import Util                 ( zipLazy )
29 import MonadUtils
30
31 import Control.Monad
32
33
34 -- | Vectorise a single module.
35 --
36 vectorise :: ModGuts -> CoreM ModGuts
37 vectorise guts
38  = do { hsc_env <- getHscEnv
39       ; liftIO $ vectoriseIO hsc_env guts
40       }
41
42 -- | Vectorise a single monad, given the dynamic compiler flags and HscEnv.
43 --
44 vectoriseIO :: HscEnv -> ModGuts -> IO ModGuts
45 vectoriseIO hsc_env guts
46  = do {   -- Get information about currently loaded external packages.
47       ; eps <- hscEPS hsc_env
48
49           -- Combine vectorisation info from the current module, and external ones.
50       ; let info = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps
51
52           -- Run the main VM computation.
53       ; Just (info', guts') <- initV hsc_env guts info (vectModule guts)
54       ; return (guts' { mg_vect_info = info' })
55       }
56
57 -- | Vectorise a single module, in the VM monad.
58 --
59 vectModule :: ModGuts -> VM ModGuts
60 vectModule guts@(ModGuts { mg_types     = types
61                          , mg_binds     = binds
62                          , mg_fam_insts = fam_insts
63                          })
64  = do { dumpOptVt Opt_D_dump_vt_trace "Before vectorisation" $ 
65           pprCoreBindings binds
66  
67           -- Vectorise the type environment.
68           -- This may add new TyCons and DataCons.
69       ; (types', new_fam_insts, tc_binds) <- vectTypeEnv types
70
71       ; (_, fam_inst_env) <- readGEnv global_fam_inst_env
72
73       -- dicts   <- mapM buildPADict pa_insts
74       -- workers <- mapM vectDataConWorkers pa_insts
75
76           -- Vectorise all the top level bindings.
77       ; binds'  <- mapM vectTopBind binds
78
79       ; return $ guts { mg_types        = types'
80                       , mg_binds        = Rec tc_binds : binds'
81                       , mg_fam_inst_env = fam_inst_env
82                       , mg_fam_insts    = fam_insts ++ new_fam_insts
83                       }
84       }
85
86 -- | Try to vectorise a top-level binding.
87 --   If it doesn't vectorise then return it unharmed.
88 --
89 --   For example, for the binding 
90 --
91 --   @  
92 --      foo :: Int -> Int
93 --      foo = \x -> x + x
94 --   @
95 --  
96 --   we get
97 --   @
98 --      foo  :: Int -> Int
99 --      foo  = \x -> vfoo $: x                  
100 -- 
101 --      v_foo :: Closure void vfoo lfoo
102 --      v_foo = closure vfoo lfoo void        
103 -- 
104 --      vfoo :: Void -> Int -> Int
105 --      vfoo = ...
106 --
107 --      lfoo :: PData Void -> PData Int -> PData Int
108 --      lfoo = ...
109 --   @ 
110 --
111 --   @vfoo@ is the "vectorised", or scalar, version that does the same as the original
112 --   function foo, but takes an explicit environment.
113 -- 
114 --   @lfoo@ is the "lifted" version that works on arrays.
115 --
116 --   @v_foo@ combines both of these into a `Closure` that also contains the
117 --   environment.
118 --
119 --   The original binding @foo@ is rewritten to call the vectorised version
120 --   present in the closure.
121 --
122 vectTopBind :: CoreBind -> VM CoreBind
123 vectTopBind b@(NonRec var expr)
124  = do {   -- Vectorise the right-hand side, create an appropriate top-level binding and add it to
125           -- the vectorisation map.
126       ; (inline, isScalar, expr') <- vectTopRhs [] var expr
127       ; var' <- vectTopBinder var inline expr'
128       ; when isScalar $ 
129           addGlobalScalar var
130
131           -- We replace the original top-level binding by a value projected from the vectorised
132           -- closure and add any newly created hoisted top-level bindings.
133       ; cexpr <- tryConvert var var' expr
134       ; hs <- takeHoisted
135       ; return . Rec $ (var, cexpr) : (var', expr') : hs
136       }
137   `orElseV`
138     return b
139 vectTopBind b@(Rec bs)
140  = let (vars, exprs) = unzip bs
141    in
142    do { (vars', _, exprs', hs) <- fixV $ 
143           \ ~(_, inlines, rhss, _) ->
144             do {   -- Vectorise the right-hand sides, create an appropriate top-level bindings and
145                    --  add them to the vectorisation map.
146                ; vars' <- sequence [vectTopBinder var inline rhs
147                                    | (var, ~(inline, rhs)) <- zipLazy vars (zip inlines rhss)]
148                ; (inlines, areScalars, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs vars) bs
149                ; hs <- takeHoisted
150                ; if and areScalars
151                  then      -- (1) Entire recursive group is scalar
152                            --      => add all variables to the global set of scalars
153                       do { mapM addGlobalScalar vars
154                          ; return (vars', inlines, exprs', hs)
155                          }
156                  else      -- (2) At least one binding is not scalar
157                            --     => vectorise again with empty set of local scalars
158                       do { (inlines, _, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs []) bs
159                          ; hs <- takeHoisted
160                          ; return (vars', inlines, exprs', hs)
161                          }
162                }
163                       
164           -- Replace the original top-level bindings by a values projected from the vectorised
165           -- closures and add any newly created hoisted top-level bindings to the group.
166       ; cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
167       ; return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
168       }
169   `orElseV`
170     return b    
171     
172 -- | Make the vectorised version of this top level binder, and add the mapping
173 --   between it and the original to the state. For some binder @foo@ the vectorised
174 --   version is @$v_foo@
175 --
176 --   NOTE: vectTopBinder *MUST* be lazy in inline and expr because of how it is
177 --   used inside of fixV in vectTopBind
178 --
179 vectTopBinder :: Var      -- ^ Name of the binding.
180               -> Inline   -- ^ Whether it should be inlined, used to annotate it.
181               -> CoreExpr -- ^ RHS of binding, used to set the 'Unfolding' of the returned 'Var'.
182               -> VM Var   -- ^ Name of the vectorised binding.
183 vectTopBinder var inline expr
184  = do {   -- Vectorise the type attached to the var.
185       ; vty  <- vectType (idType var)
186       
187           -- If there is a vectorisation declartion for this binding, make sure that its type
188           --  matches
189       ; vectDecl <- lookupVectDecl var
190       ; case vectDecl of
191           Nothing                 -> return ()
192           Just (vdty, _) 
193             | coreEqType vty vdty -> return ()
194             | otherwise           -> 
195               cantVectorise ("Type mismatch in vectorisation pragma for " ++ show var) $
196                 (text "Expected type" <+> ppr vty)
197                 $$
198                 (text "Inferred type" <+> ppr vdty)
199
200           -- Make the vectorised version of binding's name, and set the unfolding used for inlining
201       ; var' <- liftM (`setIdUnfoldingLazily` unfolding) 
202                 $  cloneId mkVectOcc var vty
203
204           -- Add the mapping between the plain and vectorised name to the state.
205       ; defGlobalVar var var'
206
207       ; return var'
208     }
209   where
210     unfolding = case inline of
211                   Inline arity -> mkInlineUnfolding (Just arity) expr
212                   DontInline   -> noUnfolding
213
214 -- | Vectorise the RHS of a top-level binding, in an empty local environment.
215 --
216 -- We need to distinguish three cases:
217 --
218 -- (1) We have a (non-scalar) vectorisation declaration for the variable (which explicitly provides
219 --     vectorised code implemented by the user)
220 --     => no automatic vectorisation & instead use the user-supplied code
221 -- 
222 -- (2) We have a scalar vectorisation declaration for the variable
223 --     => generate vectorised code that uses a scalar 'map'/'zipWith' to lift the computation
224 -- 
225 -- (3) There is no vectorisation declaration for the variable
226 --     => perform automatic vectorisation of the RHS
227 --
228 vectTopRhs :: [Var]           -- ^ Names of all functions in the rec block
229            -> Var             -- ^ Name of the binding.
230            -> CoreExpr        -- ^ Body of the binding.
231            -> VM ( Inline     -- (1) inline specification for the binding
232                  , Bool       -- (2) whether the right-hand side is a scalar computation
233                  , CoreExpr)  -- (3) the vectorised right-hand side
234 vectTopRhs recFs var expr
235   = closedV
236   $ do { traceVt ("vectTopRhs of " ++ show var) $ ppr expr
237   
238        ; globalScalar <- isGlobalScalar var
239        ; vectDecl     <- lookupVectDecl var
240        ; rhs globalScalar vectDecl
241        }
242   where
243     rhs _globalScalar (Just (_, expr'))               -- Case (1)
244       = return (inlineMe, False, expr')
245     rhs True          Nothing                         -- Case (2)
246       = do { expr' <- vectScalarFun True recFs expr
247            ; return (inlineMe, True, vectorised expr')
248            }
249     rhs False         Nothing                         -- Case (3)
250       = do { let fvs = freeVars expr
251            ; (inline, isScalar, vexpr) <- inBind var $
252                                             vectPolyExpr (isLoopBreaker $ idOccInfo var) recFs fvs
253            ; return (inline, isScalar, vectorised vexpr)
254            }
255
256 -- | Project out the vectorised version of a binding from some closure,
257 --   or return the original body if that doesn't work or the binding is scalar. 
258 --
259 tryConvert :: Var       -- ^ Name of the original binding (eg @foo@)
260            -> Var       -- ^ Name of vectorised version of binding (eg @$vfoo@)
261            -> CoreExpr  -- ^ The original body of the binding.
262            -> VM CoreExpr
263 tryConvert var vect_var rhs
264   = do { globalScalar <- isGlobalScalar var
265        ; if globalScalar
266          then
267            return rhs
268          else
269            fromVect (idType var) (Var vect_var) `orElseV` return rhs
270        }