Refactoring
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   mkDataConTag,
5   splitClosureTy,
6   mkPReprType, mkPReprAlts,
7   mkPADictType, mkPArrayType,
8   parrayReprTyCon, parrayReprDataCon, mkVScrut,
9   paDictArgType, paDictOfType, paDFunType,
10   paMethod, lengthPA, replicatePA, emptyPA, liftPA,
11   polyAbstract, polyApply, polyVApply,
12   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
13   buildClosure, buildClosures,
14   mkClosureApp
15 ) where
16
17 #include "HsVersions.h"
18
19 import VectCore
20 import VectMonad
21
22 import DsUtils
23 import CoreSyn
24 import CoreUtils
25 import Type
26 import TypeRep
27 import TyCon
28 import DataCon            ( DataCon, dataConWrapId, dataConTag )
29 import Var
30 import Id                 ( mkWildId )
31 import MkId               ( unwrapFamInstScrut )
32 import Name               ( Name )
33 import PrelNames
34 import TysWiredIn
35 import BasicTypes         ( Boxity(..) )
36
37 import Outputable
38 import FastString
39
40 import Control.Monad         ( liftM, zipWithM_ )
41
42 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
43 collectAnnTypeArgs expr = go expr []
44   where
45     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
46     go e                             tys = (e, tys)
47
48 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
49 collectAnnTypeBinders expr = go [] expr
50   where
51     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
52     go bs e                           = (reverse bs, e)
53
54 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
55 collectAnnValBinders expr = go [] expr
56   where
57     go bs (_, AnnLam b e) | isId b = go (b:bs) e
58     go bs e                        = (reverse bs, e)
59
60 isAnnTypeArg :: AnnExpr b ann -> Bool
61 isAnnTypeArg (_, AnnType t) = True
62 isAnnTypeArg _              = False
63
64 mkDataConTag :: DataCon -> CoreExpr
65 mkDataConTag dc = mkConApp intDataCon [mkIntLitInt $ dataConTag dc]
66
67 splitUnTy :: String -> Name -> Type -> Type
68 splitUnTy s name ty
69   | Just (tc, [ty']) <- splitTyConApp_maybe ty
70   , tyConName tc == name
71   = ty'
72
73   | otherwise = pprPanic s (ppr ty)
74
75 splitBinTy :: String -> Name -> Type -> (Type, Type)
76 splitBinTy s name ty
77   | Just (tc, [ty1, ty2]) <- splitTyConApp_maybe ty
78   , tyConName tc == name
79   = (ty1, ty2)
80
81   | otherwise = pprPanic s (ppr ty)
82
83 splitCrossTy :: Type -> (Type, Type)
84 splitCrossTy = splitBinTy "splitCrossTy" ndpCrossTyConName
85
86 splitPlusTy :: Type -> (Type, Type)
87 splitPlusTy = splitBinTy "splitSumTy" ndpPlusTyConName
88
89 splitEmbedTy :: Type -> Type
90 splitEmbedTy = splitUnTy "splitEmbedTy" embedTyConName
91
92 splitClosureTy :: Type -> (Type, Type)
93 splitClosureTy = splitBinTy "splitClosureTy" closureTyConName
94
95 splitPArrayTy :: Type -> Type
96 splitPArrayTy = splitUnTy "splitPArrayTy" parrayTyConName
97
98 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
99 mkBuiltinTyConApp get_tc tys
100   = do
101       tc <- builtin get_tc
102       return $ mkTyConApp tc tys
103
104 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
105 mkBuiltinTyConApps get_tc tys ty
106   = do
107       tc <- builtin get_tc
108       return $ foldr (mk tc) ty tys
109   where
110     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
111
112 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
113 mkBuiltinTyConApps1 get_tc dft [] = return dft
114 mkBuiltinTyConApps1 get_tc dft tys
115   = do
116       tc <- builtin get_tc
117       case tys of
118         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
119         _  -> return $ foldr1 (mk tc) tys
120   where
121     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
122
123 mkPReprType :: [[Type]] -> VM Type
124 mkPReprType [] = return unitTy
125 mkPReprType tys
126   = do
127       embed <- builtin embedTyCon
128       cross <- builtin crossTyCon
129       plus  <- builtin plusTyCon
130
131       let mk_embed ty      = mkTyConApp embed [ty]
132           mk_cross ty1 ty2 = mkTyConApp cross [ty1, ty2]
133           mk_plus  ty1 ty2 = mkTyConApp plus  [ty1, ty2]
134
135           mk_tup   []      = unitTy
136           mk_tup   tys     = foldr1 mk_cross tys
137
138           mk_sum   []      = unitTy
139           mk_sum   tys     = foldr1 mk_plus  tys
140
141       return . mk_sum
142              . map (mk_tup . map mk_embed)
143              $ tys
144
145 mkPReprAlts :: [[CoreExpr]] -> VM ([CoreExpr], Type)
146 mkPReprAlts ess
147   = do
148       embed_tc <- builtin embedTyCon
149       embed_dc <- builtin embedDataCon
150       cross_tc <- builtin crossTyCon
151       cross_dc <- builtin crossDataCon
152       plus_tc  <- builtin plusTyCon
153       left_dc  <- builtin leftDataCon
154       right_dc <- builtin rightDataCon
155
156       let mk_embed (expr, ty, pa)
157             = (mkConApp   embed_dc [Type ty, pa, expr],
158                mkTyConApp embed_tc [ty])
159
160           mk_cross (expr1, ty1) (expr2, ty2)
161             = (mkConApp   cross_dc [Type ty1, Type ty2, expr1, expr2],
162                mkTyConApp cross_tc [ty1, ty2])
163
164           mk_tup [] = (Var unitDataConId, unitTy)
165           mk_tup es = foldr1 mk_cross es
166
167           mk_sum []           = ([Var unitDataConId], unitTy)
168           mk_sum [(expr, ty)] = ([expr], ty)
169           mk_sum ((expr, lty) : es)
170             = let (alts, rty) = mk_sum es
171               in
172               (mkConApp left_dc [Type lty, Type rty, expr]
173                  : [mkConApp right_dc [Type lty, Type rty, alt] | alt <- alts],
174                mkTyConApp plus_tc [lty, rty])
175       
176       liftM (mk_sum . map (mk_tup . map mk_embed))
177             (mapM (mapM init) ess)
178   where
179     init expr = let ty = exprType expr
180                 in do
181                      pa <- paDictOfType ty
182                      return (expr, ty, pa)
183
184 mkClosureType :: Type -> Type -> VM Type
185 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
186
187 mkClosureTypes :: [Type] -> Type -> VM Type
188 mkClosureTypes = mkBuiltinTyConApps closureTyCon
189
190 mkPADictType :: Type -> VM Type
191 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
192
193 mkPArrayType :: Type -> VM Type
194 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
195
196 parrayReprTyCon :: Type -> VM (TyCon, [Type])
197 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
198
199 parrayReprDataCon :: Type -> VM (DataCon, [Type])
200 parrayReprDataCon ty
201   = do
202       (tc, arg_tys) <- parrayReprTyCon ty
203       let [dc] = tyConDataCons tc
204       return (dc, arg_tys)
205
206 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
207 mkVScrut (ve, le)
208   = do
209       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
210       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
211
212 paDictArgType :: TyVar -> VM (Maybe Type)
213 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
214   where
215     go ty k | Just k' <- kindView k = go ty k'
216     go ty (FunTy k1 k2)
217       = do
218           tv   <- newTyVar FSLIT("a") k1
219           mty1 <- go (TyVarTy tv) k1
220           case mty1 of
221             Just ty1 -> do
222                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
223                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
224             Nothing  -> go ty k2
225
226     go ty k
227       | isLiftedTypeKind k
228       = liftM Just (mkPADictType ty)
229
230     go ty k = return Nothing
231
232 paDictOfType :: Type -> VM CoreExpr
233 paDictOfType ty = paDictOfTyApp ty_fn ty_args
234   where
235     (ty_fn, ty_args) = splitAppTys ty
236
237 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
238 paDictOfTyApp ty_fn ty_args
239   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
240 paDictOfTyApp (TyVarTy tv) ty_args
241   = do
242       dfun <- maybeV (lookupTyVarPA tv)
243       paDFunApply dfun ty_args
244 paDictOfTyApp (TyConApp tc _) ty_args
245   = do
246       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
247       paDFunApply (Var dfun) ty_args
248 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
249
250 paDFunType :: TyCon -> VM Type
251 paDFunType tc
252   = do
253       margs <- mapM paDictArgType tvs
254       res   <- mkPADictType (mkTyConApp tc arg_tys)
255       return . mkForAllTys tvs
256              $ mkFunTys [arg | Just arg <- margs] res
257   where
258     tvs = tyConTyVars tc
259     arg_tys = mkTyVarTys tvs
260
261 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
262 paDFunApply dfun tys
263   = do
264       dicts <- mapM paDictOfType tys
265       return $ mkApps (mkTyApps dfun tys) dicts
266
267 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
268 paMethod method ty
269   = do
270       fn   <- builtin method
271       dict <- paDictOfType ty
272       return $ mkApps (Var fn) [Type ty, dict]
273
274 lengthPA :: CoreExpr -> VM CoreExpr
275 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
276   where
277     ty = splitPArrayTy (exprType x)
278
279 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
280 replicatePA len x = liftM (`mkApps` [len,x])
281                           (paMethod replicatePAVar (exprType x))
282
283 emptyPA :: Type -> VM CoreExpr
284 emptyPA = paMethod emptyPAVar
285
286 liftPA :: CoreExpr -> VM CoreExpr
287 liftPA x
288   = do
289       lc <- builtin liftingContext
290       replicatePA (Var lc) x
291
292 newLocalVVar :: FastString -> Type -> VM VVar
293 newLocalVVar fs vty
294   = do
295       lty <- mkPArrayType vty
296       vv  <- newLocalVar fs vty
297       lv  <- newLocalVar fs lty
298       return (vv,lv)
299
300 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
301 polyAbstract tvs p
302   = localV
303   $ do
304       mdicts <- mapM mk_dict_var tvs
305       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
306       p (mk_lams mdicts)
307   where
308     mk_dict_var tv = do
309                        r <- paDictArgType tv
310                        case r of
311                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
312                          Nothing -> return Nothing
313
314     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
315
316 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
317 polyApply expr tys
318   = do
319       dicts <- mapM paDictOfType tys
320       return $ expr `mkTyApps` tys `mkApps` dicts
321
322 polyVApply :: VExpr -> [Type] -> VM VExpr
323 polyVApply expr tys
324   = do
325       dicts <- mapM paDictOfType tys
326       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
327
328 hoistBinding :: Var -> CoreExpr -> VM ()
329 hoistBinding v e = updGEnv $ \env ->
330   env { global_bindings = (v,e) : global_bindings env }
331
332 hoistExpr :: FastString -> CoreExpr -> VM Var
333 hoistExpr fs expr
334   = do
335       var <- newLocalVar fs (exprType expr)
336       hoistBinding var expr
337       return var
338
339 hoistVExpr :: VExpr -> VM VVar
340 hoistVExpr (ve, le)
341   = do
342       fs <- getBindName
343       vv <- hoistExpr ('v' `consFS` fs) ve
344       lv <- hoistExpr ('l' `consFS` fs) le
345       return (vv, lv)
346
347 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
348 hoistPolyVExpr tvs p
349   = do
350       expr <- closedV . polyAbstract tvs $ \abstract ->
351               liftM (mapVect abstract) p
352       fn   <- hoistVExpr expr
353       polyVApply (vVar fn) (mkTyVarTys tvs)
354
355 takeHoisted :: VM [(Var, CoreExpr)]
356 takeHoisted
357   = do
358       env <- readGEnv id
359       setGEnv $ env { global_bindings = [] }
360       return $ global_bindings env
361
362 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
363 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
364   = do
365       dict <- paDictOfType env_ty
366       mkv  <- builtin mkClosureVar
367       mkl  <- builtin mkClosurePVar
368       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
369               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
370
371 mkClosureApp :: VExpr -> VExpr -> VM VExpr
372 mkClosureApp (vclo, lclo) (varg, larg)
373   = do
374       vapply <- builtin applyClosureVar
375       lapply <- builtin applyClosurePVar
376       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
377               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
378   where
379     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
380
381 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
382 buildClosures tvs vars [] res_ty mk_body
383   = mk_body
384 buildClosures tvs vars [arg_ty] res_ty mk_body
385   = buildClosure tvs vars arg_ty res_ty mk_body
386 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
387   = do
388       res_ty' <- mkClosureTypes arg_tys res_ty
389       arg <- newLocalVVar FSLIT("x") arg_ty
390       buildClosure tvs vars arg_ty res_ty'
391         . hoistPolyVExpr tvs
392         $ do
393             lc <- builtin liftingContext
394             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
395             return $ vLams lc (vars ++ [arg]) clo
396
397 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
398 --   where
399 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
400 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
401 --
402 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
403 buildClosure tvs vars arg_ty res_ty mk_body
404   = do
405       (env_ty, env, bind) <- buildEnv vars
406       env_bndr <- newLocalVVar FSLIT("env") env_ty
407       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
408
409       fn <- hoistPolyVExpr tvs
410           $ do
411               lc    <- builtin liftingContext
412               body  <- mk_body
413               body' <- bind (vVar env_bndr)
414                             (vVarApps lc body (vars ++ [arg_bndr]))
415               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
416
417       mkClosure arg_ty res_ty env_ty fn env
418
419 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
420 buildEnv vvs
421   = do
422       lc <- builtin liftingContext
423       let (ty, venv, vbind) = mkVectEnv tys vs
424       (lenv, lbind) <- mkLiftEnv lc tys ls
425       return (ty, (venv, lenv),
426               \(venv,lenv) (vbody,lbody) ->
427               do
428                 let vbody' = vbind venv vbody
429                 lbody' <- lbind lenv lbody
430                 return (vbody', lbody'))
431   where
432     (vs,ls) = unzip vvs
433     tys     = map idType vs
434
435 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
436 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
437 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
438 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
439                         \env body -> Case env (mkWildId ty) (exprType body)
440                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
441   where
442     ty = mkCoreTupTy tys
443
444 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
445 mkLiftEnv lc [ty] [v]
446   = return (Var v, \env body ->
447                    do
448                      len <- lengthPA (Var v)
449                      return . Let (NonRec v env)
450                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
451
452 -- NOTE: this transparently deals with empty environments
453 mkLiftEnv lc tys vs
454   = do
455       (env_tc, env_tyargs) <- parrayReprTyCon vty
456       let [env_con] = tyConDataCons env_tc
457           
458           env = Var (dataConWrapId env_con)
459                 `mkTyApps`  env_tyargs
460                 `mkVarApps` (lc : vs)
461
462           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
463                           in
464                           return $ Case scrut (mkWildId (exprType scrut))
465                                         (exprType body)
466                                         [(DataAlt env_con, lc : bndrs, body)]
467       return (env, bind)
468   where
469     vty = mkCoreTupTy tys
470
471     bndrs | null vs   = [mkWildId unitTy]
472           | otherwise = vs
473