Fix bugs in vectorisation of case expressions
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 {-# OPTIONS -w #-}
2 -- The above warning supression flag is a temporary kludge.
3 -- While working on this module you are encouraged to remove it and fix
4 -- any warnings in the module. See
5 --     http://hackage.haskell.org/trac/ghc/wiki/Commentary/CodingStyle#Warnings
6 -- for details
7
8 module VectUtils (
9   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
10   collectAnnValBinders,
11   dataConTagZ, mkDataConTag, mkDataConTagLit,
12
13   newLocalVVar,
14
15   mkBuiltinCo,
16   mkPADictType, mkPArrayType, mkPReprType,
17
18   parrayReprTyCon, parrayReprDataCon, mkVScrut,
19   prDFunOfTyCon,
20   paDictArgType, paDictOfType, paDFunType,
21   paMethod, mkPR, lengthPA, replicatePA, emptyPA, packPA, combinePA, liftPA,
22   polyAbstract, polyApply, polyVApply,
23   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
24   buildClosure, buildClosures,
25   mkClosureApp
26 ) where
27
28 #include "HsVersions.h"
29
30 import VectCore
31 import VectMonad
32
33 import DsUtils
34 import CoreSyn
35 import CoreUtils
36 import Coercion
37 import Type
38 import TypeRep
39 import TyCon
40 import DataCon
41 import Var
42 import Id                 ( mkWildId )
43 import MkId               ( unwrapFamInstScrut )
44 import Name               ( Name )
45 import PrelNames
46 import TysWiredIn
47 import TysPrim            ( intPrimTy )
48 import BasicTypes         ( Boxity(..) )
49 import Literal            ( Literal, mkMachInt )
50
51 import Outputable
52 import FastString
53
54 import Data.List             ( zipWith4 )
55 import Control.Monad         ( liftM, liftM2, zipWithM_ )
56
57 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
58 collectAnnTypeArgs expr = go expr []
59   where
60     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
61     go e                             tys = (e, tys)
62
63 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
64 collectAnnTypeBinders expr = go [] expr
65   where
66     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
67     go bs e                           = (reverse bs, e)
68
69 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
70 collectAnnValBinders expr = go [] expr
71   where
72     go bs (_, AnnLam b e) | isId b = go (b:bs) e
73     go bs e                        = (reverse bs, e)
74
75 isAnnTypeArg :: AnnExpr b ann -> Bool
76 isAnnTypeArg (_, AnnType t) = True
77 isAnnTypeArg _              = False
78
79 dataConTagZ :: DataCon -> Int
80 dataConTagZ con = dataConTag con - fIRST_TAG
81
82 mkDataConTagLit :: DataCon -> Literal
83 mkDataConTagLit = mkMachInt . toInteger . dataConTagZ
84
85 mkDataConTag :: DataCon -> CoreExpr
86 mkDataConTag = mkIntLitInt . dataConTagZ
87
88 splitPrimTyCon :: Type -> Maybe TyCon
89 splitPrimTyCon ty
90   | Just (tycon, []) <- splitTyConApp_maybe ty
91   , isPrimTyCon tycon
92   = Just tycon
93
94   | otherwise = Nothing
95
96 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
97 mkBuiltinTyConApp get_tc tys
98   = do
99       tc <- builtin get_tc
100       return $ mkTyConApp tc tys
101
102 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
103 mkBuiltinTyConApps get_tc tys ty
104   = do
105       tc <- builtin get_tc
106       return $ foldr (mk tc) ty tys
107   where
108     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
109
110 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
111 mkBuiltinTyConApps1 get_tc dft [] = return dft
112 mkBuiltinTyConApps1 get_tc dft tys
113   = do
114       tc <- builtin get_tc
115       case tys of
116         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
117         _  -> return $ foldr1 (mk tc) tys
118   where
119     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
120
121 mkClosureType :: Type -> Type -> VM Type
122 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
123
124 mkClosureTypes :: [Type] -> Type -> VM Type
125 mkClosureTypes = mkBuiltinTyConApps closureTyCon
126
127 mkPReprType :: Type -> VM Type
128 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
129
130 mkPADictType :: Type -> VM Type
131 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
132
133 mkPArrayType :: Type -> VM Type
134 mkPArrayType ty
135   | Just tycon <- splitPrimTyCon ty
136   = do
137       arr <- traceMaybeV "mkPArrayType" (ppr tycon)
138            $ lookupPrimPArray tycon
139       return $ mkTyConApp arr []
140 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
141
142 mkBuiltinCo :: (Builtins -> TyCon) -> VM Coercion
143 mkBuiltinCo get_tc
144   = do
145       tc <- builtin get_tc
146       return $ mkTyConApp tc []
147
148 parrayReprTyCon :: Type -> VM (TyCon, [Type])
149 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
150
151 parrayReprDataCon :: Type -> VM (DataCon, [Type])
152 parrayReprDataCon ty
153   = do
154       (tc, arg_tys) <- parrayReprTyCon ty
155       let [dc] = tyConDataCons tc
156       return (dc, arg_tys)
157
158 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
159 mkVScrut (ve, le)
160   = do
161       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
162       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
163
164 prDFunOfTyCon :: TyCon -> VM CoreExpr
165 prDFunOfTyCon tycon
166   = liftM Var (traceMaybeV "prDictOfTyCon" (ppr tycon) (lookupTyConPR tycon))
167
168 paDictArgType :: TyVar -> VM (Maybe Type)
169 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
170   where
171     go ty k | Just k' <- kindView k = go ty k'
172     go ty (FunTy k1 k2)
173       = do
174           tv   <- newTyVar FSLIT("a") k1
175           mty1 <- go (TyVarTy tv) k1
176           case mty1 of
177             Just ty1 -> do
178                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
179                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
180             Nothing  -> go ty k2
181
182     go ty k
183       | isLiftedTypeKind k
184       = liftM Just (mkPADictType ty)
185
186     go ty k = return Nothing
187
188 paDictOfType :: Type -> VM CoreExpr
189 paDictOfType ty = paDictOfTyApp ty_fn ty_args
190   where
191     (ty_fn, ty_args) = splitAppTys ty
192
193 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
194 paDictOfTyApp ty_fn ty_args
195   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
196 paDictOfTyApp (TyVarTy tv) ty_args
197   = do
198       dfun <- maybeV (lookupTyVarPA tv)
199       paDFunApply dfun ty_args
200 paDictOfTyApp (TyConApp tc _) ty_args
201   = do
202       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
203       paDFunApply (Var dfun) ty_args
204 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
205
206 paDFunType :: TyCon -> VM Type
207 paDFunType tc
208   = do
209       margs <- mapM paDictArgType tvs
210       res   <- mkPADictType (mkTyConApp tc arg_tys)
211       return . mkForAllTys tvs
212              $ mkFunTys [arg | Just arg <- margs] res
213   where
214     tvs = tyConTyVars tc
215     arg_tys = mkTyVarTys tvs
216
217 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
218 paDFunApply dfun tys
219   = do
220       dicts <- mapM paDictOfType tys
221       return $ mkApps (mkTyApps dfun tys) dicts
222
223 type PAMethod = (Builtins -> Var, String)
224
225 pa_length    = (lengthPAVar,    "lengthPA")
226 pa_replicate = (replicatePAVar, "replicatePA")
227 pa_empty     = (emptyPAVar,     "emptyPA")
228 pa_pack      = (packPAVar,      "packPA")
229
230 paMethod :: PAMethod -> Type -> VM CoreExpr
231 paMethod (method, name) ty
232   | Just tycon <- splitPrimTyCon ty
233   = do
234       fn <- traceMaybeV "paMethod" (ppr tycon <+> text name)
235           $ lookupPrimMethod tycon name
236       return (Var fn)
237
238 paMethod (method, name) ty
239   = do
240       fn   <- builtin method
241       dict <- paDictOfType ty
242       return $ mkApps (Var fn) [Type ty, dict]
243
244 mkPR :: Type -> VM CoreExpr
245 mkPR ty
246   = do
247       fn   <- builtin mkPRVar
248       dict <- paDictOfType ty
249       return $ mkApps (Var fn) [Type ty, dict]
250
251 lengthPA :: Type -> CoreExpr -> VM CoreExpr
252 lengthPA ty x = liftM (`App` x) (paMethod pa_length ty)
253
254 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
255 replicatePA len x = liftM (`mkApps` [len,x])
256                           (paMethod pa_replicate (exprType x))
257
258 emptyPA :: Type -> VM CoreExpr
259 emptyPA = paMethod pa_empty
260
261 packPA :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> VM CoreExpr
262 packPA ty xs len sel = liftM (`mkApps` [xs, len, sel])
263                              (paMethod pa_pack ty)
264
265 combinePA :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> [CoreExpr]
266           -> VM CoreExpr
267 combinePA ty len sel is xs
268   = liftM (`mkApps` (len : sel : is : xs))
269           (paMethod (combinePAVar n, "combine" ++ show n ++ "PA") ty)
270   where
271     n = length xs
272
273 liftPA :: CoreExpr -> VM CoreExpr
274 liftPA x
275   = do
276       lc <- builtin liftingContext
277       replicatePA (Var lc) x
278
279 newLocalVVar :: FastString -> Type -> VM VVar
280 newLocalVVar fs vty
281   = do
282       lty <- mkPArrayType vty
283       vv  <- newLocalVar fs vty
284       lv  <- newLocalVar fs lty
285       return (vv,lv)
286
287 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
288 polyAbstract tvs p
289   = localV
290   $ do
291       mdicts <- mapM mk_dict_var tvs
292       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
293       p (mk_lams mdicts)
294   where
295     mk_dict_var tv = do
296                        r <- paDictArgType tv
297                        case r of
298                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
299                          Nothing -> return Nothing
300
301     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
302
303 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
304 polyApply expr tys
305   = do
306       dicts <- mapM paDictOfType tys
307       return $ expr `mkTyApps` tys `mkApps` dicts
308
309 polyVApply :: VExpr -> [Type] -> VM VExpr
310 polyVApply expr tys
311   = do
312       dicts <- mapM paDictOfType tys
313       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
314
315 hoistBinding :: Var -> CoreExpr -> VM ()
316 hoistBinding v e = updGEnv $ \env ->
317   env { global_bindings = (v,e) : global_bindings env }
318
319 hoistExpr :: FastString -> CoreExpr -> VM Var
320 hoistExpr fs expr
321   = do
322       var <- newLocalVar fs (exprType expr)
323       hoistBinding var expr
324       return var
325
326 hoistVExpr :: VExpr -> VM VVar
327 hoistVExpr (ve, le)
328   = do
329       fs <- getBindName
330       vv <- hoistExpr ('v' `consFS` fs) ve
331       lv <- hoistExpr ('l' `consFS` fs) le
332       return (vv, lv)
333
334 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
335 hoistPolyVExpr tvs p
336   = do
337       expr <- closedV . polyAbstract tvs $ \abstract ->
338               liftM (mapVect abstract) p
339       fn   <- hoistVExpr expr
340       polyVApply (vVar fn) (mkTyVarTys tvs)
341
342 takeHoisted :: VM [(Var, CoreExpr)]
343 takeHoisted
344   = do
345       env <- readGEnv id
346       setGEnv $ env { global_bindings = [] }
347       return $ global_bindings env
348
349 boxExpr :: Type -> VExpr -> VM VExpr
350 boxExpr ty (vexpr, lexpr)
351   | Just (tycon, []) <- splitTyConApp_maybe ty
352   , isUnLiftedTyCon tycon
353   = do
354       r <- lookupBoxedTyCon tycon
355       case r of
356         Just tycon' -> let [dc] = tyConDataCons tycon'
357                        in
358                        return (mkConApp dc [vexpr], lexpr)
359         Nothing     -> return (vexpr, lexpr)
360
361
362 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
363 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
364   = do
365       dict <- paDictOfType env_ty
366       mkv  <- builtin mkClosureVar
367       mkl  <- builtin mkClosurePVar
368       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
369               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
370
371 mkClosureApp :: Type -> Type -> VExpr -> VExpr -> VM VExpr
372 mkClosureApp arg_ty res_ty (vclo, lclo) (varg, larg)
373   = do
374       vapply <- builtin applyClosureVar
375       lapply <- builtin applyClosurePVar
376       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
377               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
378
379 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
380 buildClosures tvs vars [] res_ty mk_body
381   = mk_body
382 buildClosures tvs vars [arg_ty] res_ty mk_body
383   = buildClosure tvs vars arg_ty res_ty mk_body
384 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
385   = do
386       res_ty' <- mkClosureTypes arg_tys res_ty
387       arg <- newLocalVVar FSLIT("x") arg_ty
388       buildClosure tvs vars arg_ty res_ty'
389         . hoistPolyVExpr tvs
390         $ do
391             lc <- builtin liftingContext
392             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
393             return $ vLams lc (vars ++ [arg]) clo
394
395 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
396 --   where
397 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
398 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
399 --
400 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
401 buildClosure tvs vars arg_ty res_ty mk_body
402   = do
403       (env_ty, env, bind) <- buildEnv vars
404       env_bndr <- newLocalVVar FSLIT("env") env_ty
405       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
406
407       fn <- hoistPolyVExpr tvs
408           $ do
409               lc    <- builtin liftingContext
410               body  <- mk_body
411               body' <- bind (vVar env_bndr)
412                             (vVarApps lc body (vars ++ [arg_bndr]))
413               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
414
415       mkClosure arg_ty res_ty env_ty fn env
416
417 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
418 buildEnv vvs
419   = do
420       lc <- builtin liftingContext
421       let (ty, venv, vbind) = mkVectEnv tys vs
422       (lenv, lbind) <- mkLiftEnv lc tys ls
423       return (ty, (venv, lenv),
424               \(venv,lenv) (vbody,lbody) ->
425               do
426                 let vbody' = vbind venv vbody
427                 lbody' <- lbind lenv lbody
428                 return (vbody', lbody'))
429   where
430     (vs,ls) = unzip vvs
431     tys     = map idType vs
432
433 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
434 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
435 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
436 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
437                         \env body -> Case env (mkWildId ty) (exprType body)
438                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
439   where
440     ty = mkCoreTupTy tys
441
442 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
443 mkLiftEnv lc [ty] [v]
444   = return (Var v, \env body ->
445                    do
446                      len <- lengthPA ty (Var v)
447                      return . Let (NonRec v env)
448                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
449
450 -- NOTE: this transparently deals with empty environments
451 mkLiftEnv lc tys vs
452   = do
453       (env_tc, env_tyargs) <- parrayReprTyCon vty
454       let [env_con] = tyConDataCons env_tc
455           
456           env = Var (dataConWrapId env_con)
457                 `mkTyApps`  env_tyargs
458                 `mkApps`    (Var lc : args)
459
460           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
461                           in
462                           return $ Case scrut (mkWildId (exprType scrut))
463                                         (exprType body)
464                                         [(DataAlt env_con, lc : bndrs, body)]
465       return (env, bind)
466   where
467     vty = mkCoreTupTy tys
468
469     args  | null vs   = [Var unitDataConId]
470           | otherwise = map Var vs
471
472     bndrs | null vs   = [mkWildId unitTy]
473           | otherwise = vs
474