b00206d6ad98da3734e117fc555fb9d1c92d4f9c
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   mkDataConTag,
5   splitClosureTy,
6   mkPRepr, mkToPRepr, mkFromPRepr,
7   mkPADictType, mkPArrayType, mkPReprType,
8   parrayReprTyCon, parrayReprDataCon, mkVScrut,
9   prDictOfType,
10   paDictArgType, paDictOfType, paDFunType,
11   paMethod, lengthPA, replicatePA, emptyPA, liftPA,
12   polyAbstract, polyApply, polyVApply,
13   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
14   buildClosure, buildClosures,
15   mkClosureApp
16 ) where
17
18 #include "HsVersions.h"
19
20 import VectCore
21 import VectMonad
22
23 import DsUtils
24 import CoreSyn
25 import CoreUtils
26 import Type
27 import TypeRep
28 import TyCon
29 import DataCon            ( DataCon, dataConWrapId, dataConTag )
30 import Var
31 import Id                 ( mkWildId )
32 import MkId               ( unwrapFamInstScrut )
33 import Name               ( Name )
34 import PrelNames
35 import TysWiredIn
36 import BasicTypes         ( Boxity(..) )
37
38 import Outputable
39 import FastString
40
41 import Control.Monad         ( liftM, liftM2, zipWithM_ )
42
43 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
44 collectAnnTypeArgs expr = go expr []
45   where
46     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
47     go e                             tys = (e, tys)
48
49 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
50 collectAnnTypeBinders expr = go [] expr
51   where
52     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
53     go bs e                           = (reverse bs, e)
54
55 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
56 collectAnnValBinders expr = go [] expr
57   where
58     go bs (_, AnnLam b e) | isId b = go (b:bs) e
59     go bs e                        = (reverse bs, e)
60
61 isAnnTypeArg :: AnnExpr b ann -> Bool
62 isAnnTypeArg (_, AnnType t) = True
63 isAnnTypeArg _              = False
64
65 mkDataConTag :: DataCon -> CoreExpr
66 mkDataConTag dc = mkConApp intDataCon [mkIntLitInt $ dataConTag dc]
67
68 splitUnTy :: String -> Name -> Type -> Type
69 splitUnTy s name ty
70   | Just (tc, [ty']) <- splitTyConApp_maybe ty
71   , tyConName tc == name
72   = ty'
73
74   | otherwise = pprPanic s (ppr ty)
75
76 splitBinTy :: String -> Name -> Type -> (Type, Type)
77 splitBinTy s name ty
78   | Just (tc, [ty1, ty2]) <- splitTyConApp_maybe ty
79   , tyConName tc == name
80   = (ty1, ty2)
81
82   | otherwise = pprPanic s (ppr ty)
83
84 splitCrossTy :: Type -> (Type, Type)
85 splitCrossTy = splitBinTy "splitCrossTy" ndpCrossTyConName
86
87 splitPlusTy :: Type -> (Type, Type)
88 splitPlusTy = splitBinTy "splitSumTy" ndpPlusTyConName
89
90 splitEmbedTy :: Type -> Type
91 splitEmbedTy = splitUnTy "splitEmbedTy" embedTyConName
92
93 splitClosureTy :: Type -> (Type, Type)
94 splitClosureTy = splitBinTy "splitClosureTy" closureTyConName
95
96 splitPArrayTy :: Type -> Type
97 splitPArrayTy = splitUnTy "splitPArrayTy" parrayTyConName
98
99 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
100 mkBuiltinTyConApp get_tc tys
101   = do
102       tc <- builtin get_tc
103       return $ mkTyConApp tc tys
104
105 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
106 mkBuiltinTyConApps get_tc tys ty
107   = do
108       tc <- builtin get_tc
109       return $ foldr (mk tc) ty tys
110   where
111     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
112
113 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
114 mkBuiltinTyConApps1 get_tc dft [] = return dft
115 mkBuiltinTyConApps1 get_tc dft tys
116   = do
117       tc <- builtin get_tc
118       case tys of
119         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
120         _  -> return $ foldr1 (mk tc) tys
121   where
122     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
123
124 mkPRepr :: [[Type]] -> VM Type
125 mkPRepr [] = return unitTy
126 mkPRepr tys
127   = do
128       embed <- builtin embedTyCon
129       cross <- builtin crossTyCon
130       plus  <- builtin plusTyCon
131
132       let mk_embed ty      = mkTyConApp embed [ty]
133           mk_cross ty1 ty2 = mkTyConApp cross [ty1, ty2]
134           mk_plus  ty1 ty2 = mkTyConApp plus  [ty1, ty2]
135
136           mk_tup   []      = unitTy
137           mk_tup   tys     = foldr1 mk_cross tys
138
139           mk_sum   []      = unitTy
140           mk_sum   tys     = foldr1 mk_plus  tys
141
142       return . mk_sum
143              . map (mk_tup . map mk_embed)
144              $ tys
145
146 mkToPRepr :: [[CoreExpr]] -> VM ([CoreExpr], Type)
147 mkToPRepr ess
148   = do
149       embed_tc <- builtin embedTyCon
150       embed_dc <- builtin embedDataCon
151       cross_tc <- builtin crossTyCon
152       cross_dc <- builtin crossDataCon
153       plus_tc  <- builtin plusTyCon
154       left_dc  <- builtin leftDataCon
155       right_dc <- builtin rightDataCon
156
157       let mk_embed expr
158             = (mkConApp   embed_dc [Type ty, expr],
159                mkTyConApp embed_tc [ty])
160             where ty = exprType expr
161
162           mk_cross (expr1, ty1) (expr2, ty2)
163             = (mkConApp   cross_dc [Type ty1, Type ty2, expr1, expr2],
164                mkTyConApp cross_tc [ty1, ty2])
165
166           mk_tup [] = (Var unitDataConId, unitTy)
167           mk_tup es = foldr1 mk_cross es
168
169           mk_sum []           = ([Var unitDataConId], unitTy)
170           mk_sum [(expr, ty)] = ([expr], ty)
171           mk_sum ((expr, lty) : es)
172             = let (alts, rty) = mk_sum es
173               in
174               (mkConApp left_dc [Type lty, Type rty, expr]
175                  : [mkConApp right_dc [Type lty, Type rty, alt] | alt <- alts],
176                mkTyConApp plus_tc [lty, rty])
177
178       return . mk_sum $ map (mk_tup . map mk_embed) ess
179
180 mkFromPRepr :: CoreExpr -> Type -> [([Var], CoreExpr)] -> VM CoreExpr
181 mkFromPRepr scrut res_ty alts
182   = do
183       embed_dc <- builtin embedDataCon
184       cross_dc <- builtin crossDataCon
185       left_dc  <- builtin leftDataCon
186       right_dc <- builtin rightDataCon
187       pa_tc    <- builtin paTyCon
188
189       let un_embed expr ty var res
190             = Case expr (mkWildId ty) res_ty
191                    [(DataAlt embed_dc, [var], res)]
192
193           un_cross expr ty var1 var2 res
194             = Case expr (mkWildId ty) res_ty
195                 [(DataAlt cross_dc, [var1, var2], res)]
196
197           un_tup expr ty []    res = return res
198           un_tup expr ty [var] res = return $ un_embed expr ty var res
199           un_tup expr ty (var : vars) res
200             = do
201                 lv <- newLocalVar FSLIT("x") lty
202                 rv <- newLocalVar FSLIT("y") rty
203                 liftM (un_cross expr ty lv rv
204                       . un_embed (Var lv) lty var)
205                       (un_tup (Var rv) rty vars res)
206             where
207               (lty, rty) = splitCrossTy ty
208
209           un_plus expr ty var1 var2 res1 res2
210             = Case expr (mkWildId ty) res_ty
211                 [(DataAlt left_dc,  [var1], res1),
212                  (DataAlt right_dc, [var2], res2)]
213
214           un_sum expr ty [(vars, res)] = un_tup expr ty vars res
215           un_sum expr ty ((vars, res) : alts)
216             = do
217                 lv <- newLocalVar FSLIT("l") lty
218                 rv <- newLocalVar FSLIT("r") rty
219                 liftM2 (un_plus expr ty lv rv)
220                          (un_tup (Var lv) lty vars res)
221                          (un_sum (Var rv) rty alts)
222             where
223               (lty, rty) = splitPlusTy ty
224
225       un_sum scrut (exprType scrut) alts
226
227 mkClosureType :: Type -> Type -> VM Type
228 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
229
230 mkClosureTypes :: [Type] -> Type -> VM Type
231 mkClosureTypes = mkBuiltinTyConApps closureTyCon
232
233 mkPReprType :: Type -> VM Type
234 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
235
236 mkPADictType :: Type -> VM Type
237 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
238
239 mkPArrayType :: Type -> VM Type
240 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
241
242 parrayReprTyCon :: Type -> VM (TyCon, [Type])
243 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
244
245 parrayReprDataCon :: Type -> VM (DataCon, [Type])
246 parrayReprDataCon ty
247   = do
248       (tc, arg_tys) <- parrayReprTyCon ty
249       let [dc] = tyConDataCons tc
250       return (dc, arg_tys)
251
252 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
253 mkVScrut (ve, le)
254   = do
255       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
256       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
257
258
259 prDictOfType :: Type -> VM CoreExpr
260 prDictOfType orig_ty
261   | Just (tycon, ty_args) <- splitTyConApp_maybe orig_ty
262   = do
263       dfun <- traceMaybeV "prDictOfType" (ppr tycon) (lookupTyConPR tycon)
264       prDFunApply (Var dfun) ty_args
265
266 prDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
267 prDFunApply dfun tys
268   = do
269       args <- mapM mkDFunArg arg_tys
270       return $ mkApps mono_dfun args
271   where
272     mono_dfun    = mkTyApps dfun tys
273     (arg_tys, _) = splitFunTys (exprType mono_dfun)
274
275 mkDFunArg :: Type -> VM CoreExpr
276 mkDFunArg ty
277   | Just (tycon, [arg]) <- splitTyConApp_maybe ty
278
279   = let name = tyConName tycon
280
281         get_dict | name == paTyConName = paDictOfType
282                  | name == prTyConName = prDictOfType
283                  | otherwise           = pprPanic "mkDFunArg" (ppr ty)
284
285     in get_dict arg
286
287 mkDFunArg ty = pprPanic "mkDFunArg" (ppr ty)
288
289 paDictArgType :: TyVar -> VM (Maybe Type)
290 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
291   where
292     go ty k | Just k' <- kindView k = go ty k'
293     go ty (FunTy k1 k2)
294       = do
295           tv   <- newTyVar FSLIT("a") k1
296           mty1 <- go (TyVarTy tv) k1
297           case mty1 of
298             Just ty1 -> do
299                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
300                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
301             Nothing  -> go ty k2
302
303     go ty k
304       | isLiftedTypeKind k
305       = liftM Just (mkPADictType ty)
306
307     go ty k = return Nothing
308
309 paDictOfType :: Type -> VM CoreExpr
310 paDictOfType ty = paDictOfTyApp ty_fn ty_args
311   where
312     (ty_fn, ty_args) = splitAppTys ty
313
314 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
315 paDictOfTyApp ty_fn ty_args
316   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
317 paDictOfTyApp (TyVarTy tv) ty_args
318   = do
319       dfun <- maybeV (lookupTyVarPA tv)
320       paDFunApply dfun ty_args
321 paDictOfTyApp (TyConApp tc _) ty_args
322   = do
323       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
324       paDFunApply (Var dfun) ty_args
325 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
326
327 paDFunType :: TyCon -> VM Type
328 paDFunType tc
329   = do
330       margs <- mapM paDictArgType tvs
331       res   <- mkPADictType (mkTyConApp tc arg_tys)
332       return . mkForAllTys tvs
333              $ mkFunTys [arg | Just arg <- margs] res
334   where
335     tvs = tyConTyVars tc
336     arg_tys = mkTyVarTys tvs
337
338 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
339 paDFunApply dfun tys
340   = do
341       dicts <- mapM paDictOfType tys
342       return $ mkApps (mkTyApps dfun tys) dicts
343
344 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
345 paMethod method ty
346   = do
347       fn   <- builtin method
348       dict <- paDictOfType ty
349       return $ mkApps (Var fn) [Type ty, dict]
350
351 lengthPA :: CoreExpr -> VM CoreExpr
352 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
353   where
354     ty = splitPArrayTy (exprType x)
355
356 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
357 replicatePA len x = liftM (`mkApps` [len,x])
358                           (paMethod replicatePAVar (exprType x))
359
360 emptyPA :: Type -> VM CoreExpr
361 emptyPA = paMethod emptyPAVar
362
363 liftPA :: CoreExpr -> VM CoreExpr
364 liftPA x
365   = do
366       lc <- builtin liftingContext
367       replicatePA (Var lc) x
368
369 newLocalVVar :: FastString -> Type -> VM VVar
370 newLocalVVar fs vty
371   = do
372       lty <- mkPArrayType vty
373       vv  <- newLocalVar fs vty
374       lv  <- newLocalVar fs lty
375       return (vv,lv)
376
377 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
378 polyAbstract tvs p
379   = localV
380   $ do
381       mdicts <- mapM mk_dict_var tvs
382       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
383       p (mk_lams mdicts)
384   where
385     mk_dict_var tv = do
386                        r <- paDictArgType tv
387                        case r of
388                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
389                          Nothing -> return Nothing
390
391     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
392
393 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
394 polyApply expr tys
395   = do
396       dicts <- mapM paDictOfType tys
397       return $ expr `mkTyApps` tys `mkApps` dicts
398
399 polyVApply :: VExpr -> [Type] -> VM VExpr
400 polyVApply expr tys
401   = do
402       dicts <- mapM paDictOfType tys
403       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
404
405 hoistBinding :: Var -> CoreExpr -> VM ()
406 hoistBinding v e = updGEnv $ \env ->
407   env { global_bindings = (v,e) : global_bindings env }
408
409 hoistExpr :: FastString -> CoreExpr -> VM Var
410 hoistExpr fs expr
411   = do
412       var <- newLocalVar fs (exprType expr)
413       hoistBinding var expr
414       return var
415
416 hoistVExpr :: VExpr -> VM VVar
417 hoistVExpr (ve, le)
418   = do
419       fs <- getBindName
420       vv <- hoistExpr ('v' `consFS` fs) ve
421       lv <- hoistExpr ('l' `consFS` fs) le
422       return (vv, lv)
423
424 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
425 hoistPolyVExpr tvs p
426   = do
427       expr <- closedV . polyAbstract tvs $ \abstract ->
428               liftM (mapVect abstract) p
429       fn   <- hoistVExpr expr
430       polyVApply (vVar fn) (mkTyVarTys tvs)
431
432 takeHoisted :: VM [(Var, CoreExpr)]
433 takeHoisted
434   = do
435       env <- readGEnv id
436       setGEnv $ env { global_bindings = [] }
437       return $ global_bindings env
438
439 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
440 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
441   = do
442       dict <- paDictOfType env_ty
443       mkv  <- builtin mkClosureVar
444       mkl  <- builtin mkClosurePVar
445       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
446               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
447
448 mkClosureApp :: VExpr -> VExpr -> VM VExpr
449 mkClosureApp (vclo, lclo) (varg, larg)
450   = do
451       vapply <- builtin applyClosureVar
452       lapply <- builtin applyClosurePVar
453       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
454               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
455   where
456     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
457
458 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
459 buildClosures tvs vars [] res_ty mk_body
460   = mk_body
461 buildClosures tvs vars [arg_ty] res_ty mk_body
462   = buildClosure tvs vars arg_ty res_ty mk_body
463 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
464   = do
465       res_ty' <- mkClosureTypes arg_tys res_ty
466       arg <- newLocalVVar FSLIT("x") arg_ty
467       buildClosure tvs vars arg_ty res_ty'
468         . hoistPolyVExpr tvs
469         $ do
470             lc <- builtin liftingContext
471             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
472             return $ vLams lc (vars ++ [arg]) clo
473
474 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
475 --   where
476 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
477 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
478 --
479 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
480 buildClosure tvs vars arg_ty res_ty mk_body
481   = do
482       (env_ty, env, bind) <- buildEnv vars
483       env_bndr <- newLocalVVar FSLIT("env") env_ty
484       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
485
486       fn <- hoistPolyVExpr tvs
487           $ do
488               lc    <- builtin liftingContext
489               body  <- mk_body
490               body' <- bind (vVar env_bndr)
491                             (vVarApps lc body (vars ++ [arg_bndr]))
492               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
493
494       mkClosure arg_ty res_ty env_ty fn env
495
496 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
497 buildEnv vvs
498   = do
499       lc <- builtin liftingContext
500       let (ty, venv, vbind) = mkVectEnv tys vs
501       (lenv, lbind) <- mkLiftEnv lc tys ls
502       return (ty, (venv, lenv),
503               \(venv,lenv) (vbody,lbody) ->
504               do
505                 let vbody' = vbind venv vbody
506                 lbody' <- lbind lenv lbody
507                 return (vbody', lbody'))
508   where
509     (vs,ls) = unzip vvs
510     tys     = map idType vs
511
512 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
513 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
514 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
515 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
516                         \env body -> Case env (mkWildId ty) (exprType body)
517                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
518   where
519     ty = mkCoreTupTy tys
520
521 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
522 mkLiftEnv lc [ty] [v]
523   = return (Var v, \env body ->
524                    do
525                      len <- lengthPA (Var v)
526                      return . Let (NonRec v env)
527                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
528
529 -- NOTE: this transparently deals with empty environments
530 mkLiftEnv lc tys vs
531   = do
532       (env_tc, env_tyargs) <- parrayReprTyCon vty
533       let [env_con] = tyConDataCons env_tc
534           
535           env = Var (dataConWrapId env_con)
536                 `mkTyApps`  env_tyargs
537                 `mkVarApps` (lc : vs)
538
539           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
540                           in
541                           return $ Case scrut (mkWildId (exprType scrut))
542                                         (exprType body)
543                                         [(DataAlt env_con, lc : bndrs, body)]
544       return (env, bind)
545   where
546     vty = mkCoreTupTy tys
547
548     bndrs | null vs   = [mkWildId unitTy]
549           | otherwise = vs
550