Embed doesn't store a PA dictionary any more
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   mkDataConTag,
5   splitClosureTy,
6   mkPRepr, mkToPRepr, mkFromPRepr,
7   mkPADictType, mkPArrayType, mkPReprType,
8   parrayReprTyCon, parrayReprDataCon, mkVScrut,
9   paDictArgType, paDictOfType, paDFunType,
10   paMethod, lengthPA, replicatePA, emptyPA, liftPA,
11   polyAbstract, polyApply, polyVApply,
12   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
13   buildClosure, buildClosures,
14   mkClosureApp
15 ) where
16
17 #include "HsVersions.h"
18
19 import VectCore
20 import VectMonad
21
22 import DsUtils
23 import CoreSyn
24 import CoreUtils
25 import Type
26 import TypeRep
27 import TyCon
28 import DataCon            ( DataCon, dataConWrapId, dataConTag )
29 import Var
30 import Id                 ( mkWildId )
31 import MkId               ( unwrapFamInstScrut )
32 import Name               ( Name )
33 import PrelNames
34 import TysWiredIn
35 import BasicTypes         ( Boxity(..) )
36
37 import Outputable
38 import FastString
39
40 import Control.Monad         ( liftM, liftM2, zipWithM_ )
41
42 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
43 collectAnnTypeArgs expr = go expr []
44   where
45     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
46     go e                             tys = (e, tys)
47
48 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
49 collectAnnTypeBinders expr = go [] expr
50   where
51     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
52     go bs e                           = (reverse bs, e)
53
54 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
55 collectAnnValBinders expr = go [] expr
56   where
57     go bs (_, AnnLam b e) | isId b = go (b:bs) e
58     go bs e                        = (reverse bs, e)
59
60 isAnnTypeArg :: AnnExpr b ann -> Bool
61 isAnnTypeArg (_, AnnType t) = True
62 isAnnTypeArg _              = False
63
64 mkDataConTag :: DataCon -> CoreExpr
65 mkDataConTag dc = mkConApp intDataCon [mkIntLitInt $ dataConTag dc]
66
67 splitUnTy :: String -> Name -> Type -> Type
68 splitUnTy s name ty
69   | Just (tc, [ty']) <- splitTyConApp_maybe ty
70   , tyConName tc == name
71   = ty'
72
73   | otherwise = pprPanic s (ppr ty)
74
75 splitBinTy :: String -> Name -> Type -> (Type, Type)
76 splitBinTy s name ty
77   | Just (tc, [ty1, ty2]) <- splitTyConApp_maybe ty
78   , tyConName tc == name
79   = (ty1, ty2)
80
81   | otherwise = pprPanic s (ppr ty)
82
83 splitCrossTy :: Type -> (Type, Type)
84 splitCrossTy = splitBinTy "splitCrossTy" ndpCrossTyConName
85
86 splitPlusTy :: Type -> (Type, Type)
87 splitPlusTy = splitBinTy "splitSumTy" ndpPlusTyConName
88
89 splitEmbedTy :: Type -> Type
90 splitEmbedTy = splitUnTy "splitEmbedTy" embedTyConName
91
92 splitClosureTy :: Type -> (Type, Type)
93 splitClosureTy = splitBinTy "splitClosureTy" closureTyConName
94
95 splitPArrayTy :: Type -> Type
96 splitPArrayTy = splitUnTy "splitPArrayTy" parrayTyConName
97
98 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
99 mkBuiltinTyConApp get_tc tys
100   = do
101       tc <- builtin get_tc
102       return $ mkTyConApp tc tys
103
104 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
105 mkBuiltinTyConApps get_tc tys ty
106   = do
107       tc <- builtin get_tc
108       return $ foldr (mk tc) ty tys
109   where
110     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
111
112 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
113 mkBuiltinTyConApps1 get_tc dft [] = return dft
114 mkBuiltinTyConApps1 get_tc dft tys
115   = do
116       tc <- builtin get_tc
117       case tys of
118         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
119         _  -> return $ foldr1 (mk tc) tys
120   where
121     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
122
123 mkPRepr :: [[Type]] -> VM Type
124 mkPRepr [] = return unitTy
125 mkPRepr tys
126   = do
127       embed <- builtin embedTyCon
128       cross <- builtin crossTyCon
129       plus  <- builtin plusTyCon
130
131       let mk_embed ty      = mkTyConApp embed [ty]
132           mk_cross ty1 ty2 = mkTyConApp cross [ty1, ty2]
133           mk_plus  ty1 ty2 = mkTyConApp plus  [ty1, ty2]
134
135           mk_tup   []      = unitTy
136           mk_tup   tys     = foldr1 mk_cross tys
137
138           mk_sum   []      = unitTy
139           mk_sum   tys     = foldr1 mk_plus  tys
140
141       return . mk_sum
142              . map (mk_tup . map mk_embed)
143              $ tys
144
145 mkToPRepr :: [[CoreExpr]] -> VM ([CoreExpr], Type)
146 mkToPRepr ess
147   = do
148       embed_tc <- builtin embedTyCon
149       embed_dc <- builtin embedDataCon
150       cross_tc <- builtin crossTyCon
151       cross_dc <- builtin crossDataCon
152       plus_tc  <- builtin plusTyCon
153       left_dc  <- builtin leftDataCon
154       right_dc <- builtin rightDataCon
155
156       let mk_embed expr
157             = (mkConApp   embed_dc [Type ty, expr],
158                mkTyConApp embed_tc [ty])
159             where ty = exprType expr
160
161           mk_cross (expr1, ty1) (expr2, ty2)
162             = (mkConApp   cross_dc [Type ty1, Type ty2, expr1, expr2],
163                mkTyConApp cross_tc [ty1, ty2])
164
165           mk_tup [] = (Var unitDataConId, unitTy)
166           mk_tup es = foldr1 mk_cross es
167
168           mk_sum []           = ([Var unitDataConId], unitTy)
169           mk_sum [(expr, ty)] = ([expr], ty)
170           mk_sum ((expr, lty) : es)
171             = let (alts, rty) = mk_sum es
172               in
173               (mkConApp left_dc [Type lty, Type rty, expr]
174                  : [mkConApp right_dc [Type lty, Type rty, alt] | alt <- alts],
175                mkTyConApp plus_tc [lty, rty])
176
177       return . mk_sum $ map (mk_tup . map mk_embed) ess
178
179 mkFromPRepr :: CoreExpr -> Type -> [([Var], CoreExpr)] -> VM CoreExpr
180 mkFromPRepr scrut res_ty alts
181   = do
182       embed_dc <- builtin embedDataCon
183       cross_dc <- builtin crossDataCon
184       left_dc  <- builtin leftDataCon
185       right_dc <- builtin rightDataCon
186       pa_tc    <- builtin paTyCon
187
188       let un_embed expr ty var res
189             = Case expr (mkWildId ty) res_ty
190                    [(DataAlt embed_dc, [var], res)]
191
192           un_cross expr ty var1 var2 res
193             = Case expr (mkWildId ty) res_ty
194                 [(DataAlt cross_dc, [var1, var2], res)]
195
196           un_tup expr ty []    res = return res
197           un_tup expr ty [var] res = return $ un_embed expr ty var res
198           un_tup expr ty (var : vars) res
199             = do
200                 lv <- newLocalVar FSLIT("x") lty
201                 rv <- newLocalVar FSLIT("y") rty
202                 liftM (un_cross expr ty lv rv
203                       . un_embed (Var lv) lty var)
204                       (un_tup (Var rv) rty vars res)
205             where
206               (lty, rty) = splitCrossTy ty
207
208           un_plus expr ty var1 var2 res1 res2
209             = Case expr (mkWildId ty) res_ty
210                 [(DataAlt left_dc,  [var1], res1),
211                  (DataAlt right_dc, [var2], res2)]
212
213           un_sum expr ty [(vars, res)] = un_tup expr ty vars res
214           un_sum expr ty ((vars, res) : alts)
215             = do
216                 lv <- newLocalVar FSLIT("l") lty
217                 rv <- newLocalVar FSLIT("r") rty
218                 liftM2 (un_plus expr ty lv rv)
219                          (un_tup (Var lv) lty vars res)
220                          (un_sum (Var rv) rty alts)
221             where
222               (lty, rty) = splitPlusTy ty
223
224       un_sum scrut (exprType scrut) alts
225
226 mkClosureType :: Type -> Type -> VM Type
227 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
228
229 mkClosureTypes :: [Type] -> Type -> VM Type
230 mkClosureTypes = mkBuiltinTyConApps closureTyCon
231
232 mkPReprType :: Type -> VM Type
233 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
234
235 mkPADictType :: Type -> VM Type
236 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
237
238 mkPArrayType :: Type -> VM Type
239 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
240
241 parrayReprTyCon :: Type -> VM (TyCon, [Type])
242 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
243
244 parrayReprDataCon :: Type -> VM (DataCon, [Type])
245 parrayReprDataCon ty
246   = do
247       (tc, arg_tys) <- parrayReprTyCon ty
248       let [dc] = tyConDataCons tc
249       return (dc, arg_tys)
250
251 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
252 mkVScrut (ve, le)
253   = do
254       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
255       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
256
257 paDictArgType :: TyVar -> VM (Maybe Type)
258 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
259   where
260     go ty k | Just k' <- kindView k = go ty k'
261     go ty (FunTy k1 k2)
262       = do
263           tv   <- newTyVar FSLIT("a") k1
264           mty1 <- go (TyVarTy tv) k1
265           case mty1 of
266             Just ty1 -> do
267                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
268                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
269             Nothing  -> go ty k2
270
271     go ty k
272       | isLiftedTypeKind k
273       = liftM Just (mkPADictType ty)
274
275     go ty k = return Nothing
276
277 paDictOfType :: Type -> VM CoreExpr
278 paDictOfType ty = paDictOfTyApp ty_fn ty_args
279   where
280     (ty_fn, ty_args) = splitAppTys ty
281
282 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
283 paDictOfTyApp ty_fn ty_args
284   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
285 paDictOfTyApp (TyVarTy tv) ty_args
286   = do
287       dfun <- maybeV (lookupTyVarPA tv)
288       paDFunApply dfun ty_args
289 paDictOfTyApp (TyConApp tc _) ty_args
290   = do
291       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
292       paDFunApply (Var dfun) ty_args
293 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
294
295 paDFunType :: TyCon -> VM Type
296 paDFunType tc
297   = do
298       margs <- mapM paDictArgType tvs
299       res   <- mkPADictType (mkTyConApp tc arg_tys)
300       return . mkForAllTys tvs
301              $ mkFunTys [arg | Just arg <- margs] res
302   where
303     tvs = tyConTyVars tc
304     arg_tys = mkTyVarTys tvs
305
306 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
307 paDFunApply dfun tys
308   = do
309       dicts <- mapM paDictOfType tys
310       return $ mkApps (mkTyApps dfun tys) dicts
311
312 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
313 paMethod method ty
314   = do
315       fn   <- builtin method
316       dict <- paDictOfType ty
317       return $ mkApps (Var fn) [Type ty, dict]
318
319 lengthPA :: CoreExpr -> VM CoreExpr
320 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
321   where
322     ty = splitPArrayTy (exprType x)
323
324 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
325 replicatePA len x = liftM (`mkApps` [len,x])
326                           (paMethod replicatePAVar (exprType x))
327
328 emptyPA :: Type -> VM CoreExpr
329 emptyPA = paMethod emptyPAVar
330
331 liftPA :: CoreExpr -> VM CoreExpr
332 liftPA x
333   = do
334       lc <- builtin liftingContext
335       replicatePA (Var lc) x
336
337 newLocalVVar :: FastString -> Type -> VM VVar
338 newLocalVVar fs vty
339   = do
340       lty <- mkPArrayType vty
341       vv  <- newLocalVar fs vty
342       lv  <- newLocalVar fs lty
343       return (vv,lv)
344
345 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
346 polyAbstract tvs p
347   = localV
348   $ do
349       mdicts <- mapM mk_dict_var tvs
350       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
351       p (mk_lams mdicts)
352   where
353     mk_dict_var tv = do
354                        r <- paDictArgType tv
355                        case r of
356                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
357                          Nothing -> return Nothing
358
359     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
360
361 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
362 polyApply expr tys
363   = do
364       dicts <- mapM paDictOfType tys
365       return $ expr `mkTyApps` tys `mkApps` dicts
366
367 polyVApply :: VExpr -> [Type] -> VM VExpr
368 polyVApply expr tys
369   = do
370       dicts <- mapM paDictOfType tys
371       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
372
373 hoistBinding :: Var -> CoreExpr -> VM ()
374 hoistBinding v e = updGEnv $ \env ->
375   env { global_bindings = (v,e) : global_bindings env }
376
377 hoistExpr :: FastString -> CoreExpr -> VM Var
378 hoistExpr fs expr
379   = do
380       var <- newLocalVar fs (exprType expr)
381       hoistBinding var expr
382       return var
383
384 hoistVExpr :: VExpr -> VM VVar
385 hoistVExpr (ve, le)
386   = do
387       fs <- getBindName
388       vv <- hoistExpr ('v' `consFS` fs) ve
389       lv <- hoistExpr ('l' `consFS` fs) le
390       return (vv, lv)
391
392 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
393 hoistPolyVExpr tvs p
394   = do
395       expr <- closedV . polyAbstract tvs $ \abstract ->
396               liftM (mapVect abstract) p
397       fn   <- hoistVExpr expr
398       polyVApply (vVar fn) (mkTyVarTys tvs)
399
400 takeHoisted :: VM [(Var, CoreExpr)]
401 takeHoisted
402   = do
403       env <- readGEnv id
404       setGEnv $ env { global_bindings = [] }
405       return $ global_bindings env
406
407 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
408 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
409   = do
410       dict <- paDictOfType env_ty
411       mkv  <- builtin mkClosureVar
412       mkl  <- builtin mkClosurePVar
413       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
414               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
415
416 mkClosureApp :: VExpr -> VExpr -> VM VExpr
417 mkClosureApp (vclo, lclo) (varg, larg)
418   = do
419       vapply <- builtin applyClosureVar
420       lapply <- builtin applyClosurePVar
421       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
422               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
423   where
424     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
425
426 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
427 buildClosures tvs vars [] res_ty mk_body
428   = mk_body
429 buildClosures tvs vars [arg_ty] res_ty mk_body
430   = buildClosure tvs vars arg_ty res_ty mk_body
431 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
432   = do
433       res_ty' <- mkClosureTypes arg_tys res_ty
434       arg <- newLocalVVar FSLIT("x") arg_ty
435       buildClosure tvs vars arg_ty res_ty'
436         . hoistPolyVExpr tvs
437         $ do
438             lc <- builtin liftingContext
439             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
440             return $ vLams lc (vars ++ [arg]) clo
441
442 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
443 --   where
444 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
445 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
446 --
447 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
448 buildClosure tvs vars arg_ty res_ty mk_body
449   = do
450       (env_ty, env, bind) <- buildEnv vars
451       env_bndr <- newLocalVVar FSLIT("env") env_ty
452       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
453
454       fn <- hoistPolyVExpr tvs
455           $ do
456               lc    <- builtin liftingContext
457               body  <- mk_body
458               body' <- bind (vVar env_bndr)
459                             (vVarApps lc body (vars ++ [arg_bndr]))
460               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
461
462       mkClosure arg_ty res_ty env_ty fn env
463
464 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
465 buildEnv vvs
466   = do
467       lc <- builtin liftingContext
468       let (ty, venv, vbind) = mkVectEnv tys vs
469       (lenv, lbind) <- mkLiftEnv lc tys ls
470       return (ty, (venv, lenv),
471               \(venv,lenv) (vbody,lbody) ->
472               do
473                 let vbody' = vbind venv vbody
474                 lbody' <- lbind lenv lbody
475                 return (vbody', lbody'))
476   where
477     (vs,ls) = unzip vvs
478     tys     = map idType vs
479
480 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
481 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
482 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
483 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
484                         \env body -> Case env (mkWildId ty) (exprType body)
485                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
486   where
487     ty = mkCoreTupTy tys
488
489 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
490 mkLiftEnv lc [ty] [v]
491   = return (Var v, \env body ->
492                    do
493                      len <- lengthPA (Var v)
494                      return . Let (NonRec v env)
495                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
496
497 -- NOTE: this transparently deals with empty environments
498 mkLiftEnv lc tys vs
499   = do
500       (env_tc, env_tyargs) <- parrayReprTyCon vty
501       let [env_con] = tyConDataCons env_tc
502           
503           env = Var (dataConWrapId env_con)
504                 `mkTyApps`  env_tyargs
505                 `mkVarApps` (lc : vs)
506
507           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
508                           in
509                           return $ Case scrut (mkWildId (exprType scrut))
510                                         (exprType body)
511                                         [(DataAlt env_con, lc : bndrs, body)]
512       return (env, bind)
513   where
514     vty = mkCoreTupTy tys
515
516     bndrs | null vs   = [mkWildId unitTy]
517           | otherwise = vs
518