Generate conversion from PRepr to original type
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   mkDataConTag,
5   splitClosureTy,
6   mkPRepr, mkToPRepr, mkFromPRepr,
7   mkPADictType, mkPArrayType, mkPReprType,
8   parrayReprTyCon, parrayReprDataCon, mkVScrut,
9   paDictArgType, paDictOfType, paDFunType,
10   paMethod, lengthPA, replicatePA, emptyPA, liftPA,
11   polyAbstract, polyApply, polyVApply,
12   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
13   buildClosure, buildClosures,
14   mkClosureApp
15 ) where
16
17 #include "HsVersions.h"
18
19 import VectCore
20 import VectMonad
21
22 import DsUtils
23 import CoreSyn
24 import CoreUtils
25 import Type
26 import TypeRep
27 import TyCon
28 import DataCon            ( DataCon, dataConWrapId, dataConTag )
29 import Var
30 import Id                 ( mkWildId )
31 import MkId               ( unwrapFamInstScrut )
32 import Name               ( Name )
33 import PrelNames
34 import TysWiredIn
35 import BasicTypes         ( Boxity(..) )
36
37 import Outputable
38 import FastString
39
40 import Control.Monad         ( liftM, liftM2, zipWithM_ )
41
42 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
43 collectAnnTypeArgs expr = go expr []
44   where
45     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
46     go e                             tys = (e, tys)
47
48 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
49 collectAnnTypeBinders expr = go [] expr
50   where
51     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
52     go bs e                           = (reverse bs, e)
53
54 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
55 collectAnnValBinders expr = go [] expr
56   where
57     go bs (_, AnnLam b e) | isId b = go (b:bs) e
58     go bs e                        = (reverse bs, e)
59
60 isAnnTypeArg :: AnnExpr b ann -> Bool
61 isAnnTypeArg (_, AnnType t) = True
62 isAnnTypeArg _              = False
63
64 mkDataConTag :: DataCon -> CoreExpr
65 mkDataConTag dc = mkConApp intDataCon [mkIntLitInt $ dataConTag dc]
66
67 splitUnTy :: String -> Name -> Type -> Type
68 splitUnTy s name ty
69   | Just (tc, [ty']) <- splitTyConApp_maybe ty
70   , tyConName tc == name
71   = ty'
72
73   | otherwise = pprPanic s (ppr ty)
74
75 splitBinTy :: String -> Name -> Type -> (Type, Type)
76 splitBinTy s name ty
77   | Just (tc, [ty1, ty2]) <- splitTyConApp_maybe ty
78   , tyConName tc == name
79   = (ty1, ty2)
80
81   | otherwise = pprPanic s (ppr ty)
82
83 splitCrossTy :: Type -> (Type, Type)
84 splitCrossTy = splitBinTy "splitCrossTy" ndpCrossTyConName
85
86 splitPlusTy :: Type -> (Type, Type)
87 splitPlusTy = splitBinTy "splitSumTy" ndpPlusTyConName
88
89 splitEmbedTy :: Type -> Type
90 splitEmbedTy = splitUnTy "splitEmbedTy" embedTyConName
91
92 splitClosureTy :: Type -> (Type, Type)
93 splitClosureTy = splitBinTy "splitClosureTy" closureTyConName
94
95 splitPArrayTy :: Type -> Type
96 splitPArrayTy = splitUnTy "splitPArrayTy" parrayTyConName
97
98 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
99 mkBuiltinTyConApp get_tc tys
100   = do
101       tc <- builtin get_tc
102       return $ mkTyConApp tc tys
103
104 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
105 mkBuiltinTyConApps get_tc tys ty
106   = do
107       tc <- builtin get_tc
108       return $ foldr (mk tc) ty tys
109   where
110     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
111
112 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
113 mkBuiltinTyConApps1 get_tc dft [] = return dft
114 mkBuiltinTyConApps1 get_tc dft tys
115   = do
116       tc <- builtin get_tc
117       case tys of
118         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
119         _  -> return $ foldr1 (mk tc) tys
120   where
121     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
122
123 mkPRepr :: [[Type]] -> VM Type
124 mkPRepr [] = return unitTy
125 mkPRepr tys
126   = do
127       embed <- builtin embedTyCon
128       cross <- builtin crossTyCon
129       plus  <- builtin plusTyCon
130
131       let mk_embed ty      = mkTyConApp embed [ty]
132           mk_cross ty1 ty2 = mkTyConApp cross [ty1, ty2]
133           mk_plus  ty1 ty2 = mkTyConApp plus  [ty1, ty2]
134
135           mk_tup   []      = unitTy
136           mk_tup   tys     = foldr1 mk_cross tys
137
138           mk_sum   []      = unitTy
139           mk_sum   tys     = foldr1 mk_plus  tys
140
141       return . mk_sum
142              . map (mk_tup . map mk_embed)
143              $ tys
144
145 mkToPRepr :: [[CoreExpr]] -> VM ([CoreExpr], Type)
146 mkToPRepr ess
147   = do
148       embed_tc <- builtin embedTyCon
149       embed_dc <- builtin embedDataCon
150       cross_tc <- builtin crossTyCon
151       cross_dc <- builtin crossDataCon
152       plus_tc  <- builtin plusTyCon
153       left_dc  <- builtin leftDataCon
154       right_dc <- builtin rightDataCon
155
156       let mk_embed (expr, ty, pa)
157             = (mkConApp   embed_dc [Type ty, pa, expr],
158                mkTyConApp embed_tc [ty])
159
160           mk_cross (expr1, ty1) (expr2, ty2)
161             = (mkConApp   cross_dc [Type ty1, Type ty2, expr1, expr2],
162                mkTyConApp cross_tc [ty1, ty2])
163
164           mk_tup [] = (Var unitDataConId, unitTy)
165           mk_tup es = foldr1 mk_cross es
166
167           mk_sum []           = ([Var unitDataConId], unitTy)
168           mk_sum [(expr, ty)] = ([expr], ty)
169           mk_sum ((expr, lty) : es)
170             = let (alts, rty) = mk_sum es
171               in
172               (mkConApp left_dc [Type lty, Type rty, expr]
173                  : [mkConApp right_dc [Type lty, Type rty, alt] | alt <- alts],
174                mkTyConApp plus_tc [lty, rty])
175       
176       liftM (mk_sum . map (mk_tup . map mk_embed))
177             (mapM (mapM init) ess)
178   where
179     init expr = let ty = exprType expr
180                 in do
181                      pa <- paDictOfType ty
182                      return (expr, ty, pa)
183
184 mkFromPRepr :: CoreExpr -> Type -> [([Var], CoreExpr)] -> VM CoreExpr
185 mkFromPRepr scrut res_ty alts
186   = do
187       embed_dc <- builtin embedDataCon
188       cross_dc <- builtin crossDataCon
189       left_dc  <- builtin leftDataCon
190       right_dc <- builtin rightDataCon
191       pa_tc    <- builtin paTyCon
192
193       let un_embed expr ty var res
194             = do
195                 pa <- newLocalVar FSLIT("pa") (mkTyConApp pa_tc [idType var])
196                 return $ Case expr (mkWildId ty) res_ty
197                          [(DataAlt embed_dc, [pa, var], res)]
198
199           un_cross expr ty var1 var2 res
200             = Case expr (mkWildId ty) res_ty
201                 [(DataAlt cross_dc, [var1, var2], res)]
202
203           un_tup expr ty []    res = return res
204           un_tup expr ty [var] res = un_embed expr ty var res
205           un_tup expr ty (var : vars) res
206             = do
207                 lv <- newLocalVar FSLIT("x") lty
208                 rv <- newLocalVar FSLIT("y") rty
209                 liftM (un_cross expr ty lv rv)
210                         (un_embed (Var lv) lty var
211                          =<< un_tup (Var rv) rty vars res)
212             where
213               (lty, rty) = splitCrossTy ty
214
215           un_plus expr ty var1 var2 res1 res2
216             = Case expr (mkWildId ty) res_ty
217                 [(DataAlt left_dc,  [var1], res1),
218                  (DataAlt right_dc, [var2], res2)]
219
220           un_sum expr ty [(vars, res)] = un_tup expr ty vars res
221           un_sum expr ty ((vars, res) : alts)
222             = do
223                 lv <- newLocalVar FSLIT("l") lty
224                 rv <- newLocalVar FSLIT("r") rty
225                 liftM2 (un_plus expr ty lv rv)
226                          (un_tup (Var lv) lty vars res)
227                          (un_sum (Var rv) rty alts)
228             where
229               (lty, rty) = splitPlusTy ty
230
231       un_sum scrut (exprType scrut) alts
232
233 mkClosureType :: Type -> Type -> VM Type
234 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
235
236 mkClosureTypes :: [Type] -> Type -> VM Type
237 mkClosureTypes = mkBuiltinTyConApps closureTyCon
238
239 mkPReprType :: Type -> VM Type
240 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
241
242 mkPADictType :: Type -> VM Type
243 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
244
245 mkPArrayType :: Type -> VM Type
246 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
247
248 parrayReprTyCon :: Type -> VM (TyCon, [Type])
249 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
250
251 parrayReprDataCon :: Type -> VM (DataCon, [Type])
252 parrayReprDataCon ty
253   = do
254       (tc, arg_tys) <- parrayReprTyCon ty
255       let [dc] = tyConDataCons tc
256       return (dc, arg_tys)
257
258 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
259 mkVScrut (ve, le)
260   = do
261       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
262       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
263
264 paDictArgType :: TyVar -> VM (Maybe Type)
265 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
266   where
267     go ty k | Just k' <- kindView k = go ty k'
268     go ty (FunTy k1 k2)
269       = do
270           tv   <- newTyVar FSLIT("a") k1
271           mty1 <- go (TyVarTy tv) k1
272           case mty1 of
273             Just ty1 -> do
274                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
275                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
276             Nothing  -> go ty k2
277
278     go ty k
279       | isLiftedTypeKind k
280       = liftM Just (mkPADictType ty)
281
282     go ty k = return Nothing
283
284 paDictOfType :: Type -> VM CoreExpr
285 paDictOfType ty = paDictOfTyApp ty_fn ty_args
286   where
287     (ty_fn, ty_args) = splitAppTys ty
288
289 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
290 paDictOfTyApp ty_fn ty_args
291   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
292 paDictOfTyApp (TyVarTy tv) ty_args
293   = do
294       dfun <- maybeV (lookupTyVarPA tv)
295       paDFunApply dfun ty_args
296 paDictOfTyApp (TyConApp tc _) ty_args
297   = do
298       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
299       paDFunApply (Var dfun) ty_args
300 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
301
302 paDFunType :: TyCon -> VM Type
303 paDFunType tc
304   = do
305       margs <- mapM paDictArgType tvs
306       res   <- mkPADictType (mkTyConApp tc arg_tys)
307       return . mkForAllTys tvs
308              $ mkFunTys [arg | Just arg <- margs] res
309   where
310     tvs = tyConTyVars tc
311     arg_tys = mkTyVarTys tvs
312
313 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
314 paDFunApply dfun tys
315   = do
316       dicts <- mapM paDictOfType tys
317       return $ mkApps (mkTyApps dfun tys) dicts
318
319 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
320 paMethod method ty
321   = do
322       fn   <- builtin method
323       dict <- paDictOfType ty
324       return $ mkApps (Var fn) [Type ty, dict]
325
326 lengthPA :: CoreExpr -> VM CoreExpr
327 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
328   where
329     ty = splitPArrayTy (exprType x)
330
331 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
332 replicatePA len x = liftM (`mkApps` [len,x])
333                           (paMethod replicatePAVar (exprType x))
334
335 emptyPA :: Type -> VM CoreExpr
336 emptyPA = paMethod emptyPAVar
337
338 liftPA :: CoreExpr -> VM CoreExpr
339 liftPA x
340   = do
341       lc <- builtin liftingContext
342       replicatePA (Var lc) x
343
344 newLocalVVar :: FastString -> Type -> VM VVar
345 newLocalVVar fs vty
346   = do
347       lty <- mkPArrayType vty
348       vv  <- newLocalVar fs vty
349       lv  <- newLocalVar fs lty
350       return (vv,lv)
351
352 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
353 polyAbstract tvs p
354   = localV
355   $ do
356       mdicts <- mapM mk_dict_var tvs
357       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
358       p (mk_lams mdicts)
359   where
360     mk_dict_var tv = do
361                        r <- paDictArgType tv
362                        case r of
363                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
364                          Nothing -> return Nothing
365
366     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
367
368 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
369 polyApply expr tys
370   = do
371       dicts <- mapM paDictOfType tys
372       return $ expr `mkTyApps` tys `mkApps` dicts
373
374 polyVApply :: VExpr -> [Type] -> VM VExpr
375 polyVApply expr tys
376   = do
377       dicts <- mapM paDictOfType tys
378       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
379
380 hoistBinding :: Var -> CoreExpr -> VM ()
381 hoistBinding v e = updGEnv $ \env ->
382   env { global_bindings = (v,e) : global_bindings env }
383
384 hoistExpr :: FastString -> CoreExpr -> VM Var
385 hoistExpr fs expr
386   = do
387       var <- newLocalVar fs (exprType expr)
388       hoistBinding var expr
389       return var
390
391 hoistVExpr :: VExpr -> VM VVar
392 hoistVExpr (ve, le)
393   = do
394       fs <- getBindName
395       vv <- hoistExpr ('v' `consFS` fs) ve
396       lv <- hoistExpr ('l' `consFS` fs) le
397       return (vv, lv)
398
399 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
400 hoistPolyVExpr tvs p
401   = do
402       expr <- closedV . polyAbstract tvs $ \abstract ->
403               liftM (mapVect abstract) p
404       fn   <- hoistVExpr expr
405       polyVApply (vVar fn) (mkTyVarTys tvs)
406
407 takeHoisted :: VM [(Var, CoreExpr)]
408 takeHoisted
409   = do
410       env <- readGEnv id
411       setGEnv $ env { global_bindings = [] }
412       return $ global_bindings env
413
414 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
415 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
416   = do
417       dict <- paDictOfType env_ty
418       mkv  <- builtin mkClosureVar
419       mkl  <- builtin mkClosurePVar
420       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
421               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
422
423 mkClosureApp :: VExpr -> VExpr -> VM VExpr
424 mkClosureApp (vclo, lclo) (varg, larg)
425   = do
426       vapply <- builtin applyClosureVar
427       lapply <- builtin applyClosurePVar
428       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
429               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
430   where
431     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
432
433 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
434 buildClosures tvs vars [] res_ty mk_body
435   = mk_body
436 buildClosures tvs vars [arg_ty] res_ty mk_body
437   = buildClosure tvs vars arg_ty res_ty mk_body
438 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
439   = do
440       res_ty' <- mkClosureTypes arg_tys res_ty
441       arg <- newLocalVVar FSLIT("x") arg_ty
442       buildClosure tvs vars arg_ty res_ty'
443         . hoistPolyVExpr tvs
444         $ do
445             lc <- builtin liftingContext
446             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
447             return $ vLams lc (vars ++ [arg]) clo
448
449 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
450 --   where
451 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
452 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
453 --
454 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
455 buildClosure tvs vars arg_ty res_ty mk_body
456   = do
457       (env_ty, env, bind) <- buildEnv vars
458       env_bndr <- newLocalVVar FSLIT("env") env_ty
459       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
460
461       fn <- hoistPolyVExpr tvs
462           $ do
463               lc    <- builtin liftingContext
464               body  <- mk_body
465               body' <- bind (vVar env_bndr)
466                             (vVarApps lc body (vars ++ [arg_bndr]))
467               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
468
469       mkClosure arg_ty res_ty env_ty fn env
470
471 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
472 buildEnv vvs
473   = do
474       lc <- builtin liftingContext
475       let (ty, venv, vbind) = mkVectEnv tys vs
476       (lenv, lbind) <- mkLiftEnv lc tys ls
477       return (ty, (venv, lenv),
478               \(venv,lenv) (vbody,lbody) ->
479               do
480                 let vbody' = vbind venv vbody
481                 lbody' <- lbind lenv lbody
482                 return (vbody', lbody'))
483   where
484     (vs,ls) = unzip vvs
485     tys     = map idType vs
486
487 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
488 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
489 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
490 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
491                         \env body -> Case env (mkWildId ty) (exprType body)
492                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
493   where
494     ty = mkCoreTupTy tys
495
496 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
497 mkLiftEnv lc [ty] [v]
498   = return (Var v, \env body ->
499                    do
500                      len <- lengthPA (Var v)
501                      return . Let (NonRec v env)
502                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
503
504 -- NOTE: this transparently deals with empty environments
505 mkLiftEnv lc tys vs
506   = do
507       (env_tc, env_tyargs) <- parrayReprTyCon vty
508       let [env_con] = tyConDataCons env_tc
509           
510           env = Var (dataConWrapId env_con)
511                 `mkTyApps`  env_tyargs
512                 `mkVarApps` (lc : vs)
513
514           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
515                           in
516                           return $ Case scrut (mkWildId (exprType scrut))
517                                         (exprType body)
518                                         [(DataAlt env_con, lc : bndrs, body)]
519       return (env, bind)
520   where
521     vty = mkCoreTupTy tys
522
523     bndrs | null vs   = [mkWildId unitTy]
524           | otherwise = vs
525