Use n-ary sums and products for NDP's generic representation
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   mkDataConTag,
5   splitClosureTy,
6   mkPRepr, mkToPRepr, mkFromPRepr,
7   mkPADictType, mkPArrayType, mkPReprType,
8   parrayReprTyCon, parrayReprDataCon, mkVScrut,
9   prDictOfType, prCoerce,
10   paDictArgType, paDictOfType, paDFunType,
11   paMethod, lengthPA, replicatePA, emptyPA, liftPA,
12   polyAbstract, polyApply, polyVApply,
13   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
14   buildClosure, buildClosures,
15   mkClosureApp
16 ) where
17
18 #include "HsVersions.h"
19
20 import VectCore
21 import VectMonad
22
23 import DsUtils
24 import CoreSyn
25 import CoreUtils
26 import Coercion
27 import Type
28 import TypeRep
29 import TyCon
30 import DataCon            ( DataCon, dataConWrapId, dataConTag )
31 import Var
32 import Id                 ( mkWildId )
33 import MkId               ( unwrapFamInstScrut )
34 import Name               ( Name )
35 import PrelNames
36 import TysWiredIn
37 import BasicTypes         ( Boxity(..) )
38
39 import Outputable
40 import FastString
41
42 import Data.List             ( zipWith4 )
43 import Control.Monad         ( liftM, liftM2, zipWithM_ )
44
45 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
46 collectAnnTypeArgs expr = go expr []
47   where
48     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
49     go e                             tys = (e, tys)
50
51 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
52 collectAnnTypeBinders expr = go [] expr
53   where
54     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
55     go bs e                           = (reverse bs, e)
56
57 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
58 collectAnnValBinders expr = go [] expr
59   where
60     go bs (_, AnnLam b e) | isId b = go (b:bs) e
61     go bs e                        = (reverse bs, e)
62
63 isAnnTypeArg :: AnnExpr b ann -> Bool
64 isAnnTypeArg (_, AnnType t) = True
65 isAnnTypeArg _              = False
66
67 mkDataConTag :: DataCon -> CoreExpr
68 mkDataConTag dc = mkConApp intDataCon [mkIntLitInt $ dataConTag dc]
69
70 splitUnTy :: String -> Name -> Type -> Type
71 splitUnTy s name ty
72   | Just (tc, [ty']) <- splitTyConApp_maybe ty
73   , tyConName tc == name
74   = ty'
75
76   | otherwise = pprPanic s (ppr ty)
77
78 splitBinTy :: String -> Name -> Type -> (Type, Type)
79 splitBinTy s name ty
80   | Just (tc, [ty1, ty2]) <- splitTyConApp_maybe ty
81   , tyConName tc == name
82   = (ty1, ty2)
83
84   | otherwise = pprPanic s (ppr ty)
85
86 splitFixedTyConApp :: TyCon -> Type -> [Type]
87 splitFixedTyConApp tc ty
88   | Just (tc', tys) <- splitTyConApp_maybe ty
89   , tc == tc'
90   = tys
91
92   | otherwise = pprPanic "splitFixedTyConApp" (ppr tc <+> ppr ty)
93
94 splitEmbedTy :: Type -> Type
95 splitEmbedTy = splitUnTy "splitEmbedTy" embedTyConName
96
97 splitClosureTy :: Type -> (Type, Type)
98 splitClosureTy = splitBinTy "splitClosureTy" closureTyConName
99
100 splitPArrayTy :: Type -> Type
101 splitPArrayTy = splitUnTy "splitPArrayTy" parrayTyConName
102
103 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
104 mkBuiltinTyConApp get_tc tys
105   = do
106       tc <- builtin get_tc
107       return $ mkTyConApp tc tys
108
109 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
110 mkBuiltinTyConApps get_tc tys ty
111   = do
112       tc <- builtin get_tc
113       return $ foldr (mk tc) ty tys
114   where
115     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
116
117 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
118 mkBuiltinTyConApps1 get_tc dft [] = return dft
119 mkBuiltinTyConApps1 get_tc dft tys
120   = do
121       tc <- builtin get_tc
122       case tys of
123         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
124         _  -> return $ foldr1 (mk tc) tys
125   where
126     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
127
128 mkPRepr :: [[Type]] -> VM Type
129 mkPRepr tys
130   = do
131       embed_tc <- builtin embedTyCon
132       sum_tcs  <- builtins sumTyCon
133       prod_tcs <- builtins prodTyCon
134
135       let mk_sum []   = unitTy
136           mk_sum [ty] = ty
137           mk_sum tys  = mkTyConApp (sum_tcs $ length tys) tys
138
139           mk_prod []   = unitTy
140           mk_prod [ty] = ty
141           mk_prod tys  = mkTyConApp (prod_tcs $ length tys) tys
142
143           mk_embed ty = mkTyConApp embed_tc [ty]
144
145       return . mk_sum
146              . map (mk_prod . map mk_embed)
147              $ tys
148
149 mkToPRepr :: [[CoreExpr]] -> VM ([CoreExpr], Type)
150 mkToPRepr ess
151   = do
152       embed_tc <- builtin embedTyCon
153       embed_dc <- builtin embedDataCon
154       sum_tcs  <- builtins sumTyCon
155       prod_tcs <- builtins prodTyCon
156
157       let mk_sum [] = ([Var unitDataConId], unitTy)
158           mk_sum [(expr, ty)] = ([expr], ty)
159           mk_sum es = (zipWith mk_alt (tyConDataCons sum_tc) exprs,
160                        mkTyConApp sum_tc tys)
161             where
162               (exprs, tys)   = unzip es
163               sum_tc         = sum_tcs (length es)
164               mk_alt dc expr = mkConApp dc (map Type tys ++ [expr])
165
166           mk_prod [] = (Var unitDataConId, unitTy)
167           mk_prod [(expr, ty)] = (expr, ty)
168           mk_prod es = (mkConApp prod_dc (map Type tys ++ exprs),
169                         mkTyConApp prod_tc tys)
170             where
171               (exprs, tys) = unzip es
172               prod_tc      = prod_tcs (length es)
173               [prod_dc]    = tyConDataCons prod_tc
174
175           mk_embed expr = (mkConApp embed_dc [Type ty, expr],
176                            mkTyConApp embed_tc [ty])
177             where ty = exprType expr
178
179       return . mk_sum $ map (mk_prod . map mk_embed) ess
180
181 mkFromPRepr :: CoreExpr -> Type -> [([Var], CoreExpr)] -> VM CoreExpr
182 mkFromPRepr scrut res_ty alts
183   = do
184       embed_dc <- builtin embedDataCon
185       sum_tcs  <- builtins sumTyCon
186       prod_tcs <- builtins prodTyCon
187
188       let un_sum expr ty [(vars, res)] = un_prod expr ty vars res
189           un_sum expr ty bs
190             = do
191                 ps     <- mapM (newLocalVar FSLIT("p")) tys
192                 bodies <- sequence
193                         $ zipWith4 un_prod (map Var ps) tys vars rs
194                 return . Case expr (mkWildId ty) res_ty
195                        $ zipWith3 mk_alt sum_dcs ps bodies
196             where
197               (vars, rs) = unzip bs
198               tys        = splitFixedTyConApp sum_tc ty
199               sum_tc     = sum_tcs $ length bs
200               sum_dcs    = tyConDataCons sum_tc
201
202               mk_alt dc p body = (DataAlt dc, [p], body)
203
204           un_prod expr ty []    r = return r
205           un_prod expr ty [var] r = return $ un_embed expr ty var r
206           un_prod expr ty vars  r
207             = do
208                 xs <- mapM (newLocalVar FSLIT("x")) tys
209                 let body = foldr (\(e,t,v) r -> un_embed e t v r) r
210                          $ zip3 (map Var xs) tys vars
211                 return $ Case expr (mkWildId ty) res_ty
212                          [(DataAlt prod_dc, xs, body)]
213             where
214               tys       = splitFixedTyConApp prod_tc ty
215               prod_tc   = prod_tcs $ length vars
216               [prod_dc] = tyConDataCons prod_tc
217
218           un_embed expr ty var r
219             = Case expr (mkWildId ty) res_ty
220                 [(DataAlt embed_dc, [var], r)]
221
222       un_sum scrut (exprType scrut) alts
223
224 mkClosureType :: Type -> Type -> VM Type
225 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
226
227 mkClosureTypes :: [Type] -> Type -> VM Type
228 mkClosureTypes = mkBuiltinTyConApps closureTyCon
229
230 mkPReprType :: Type -> VM Type
231 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
232
233 mkPADictType :: Type -> VM Type
234 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
235
236 mkPArrayType :: Type -> VM Type
237 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
238
239 parrayReprTyCon :: Type -> VM (TyCon, [Type])
240 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
241
242 parrayReprDataCon :: Type -> VM (DataCon, [Type])
243 parrayReprDataCon ty
244   = do
245       (tc, arg_tys) <- parrayReprTyCon ty
246       let [dc] = tyConDataCons tc
247       return (dc, arg_tys)
248
249 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
250 mkVScrut (ve, le)
251   = do
252       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
253       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
254
255 prDictOfType :: Type -> VM CoreExpr
256 prDictOfType orig_ty
257   | Just (tycon, ty_args) <- splitTyConApp_maybe orig_ty
258   = do
259       dfun <- traceMaybeV "prDictOfType" (ppr tycon) (lookupTyConPR tycon)
260       prDFunApply (Var dfun) ty_args
261
262 prDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
263 prDFunApply dfun tys
264   = do
265       args <- mapM mkDFunArg arg_tys
266       return $ mkApps mono_dfun args
267   where
268     mono_dfun    = mkTyApps dfun tys
269     (arg_tys, _) = splitFunTys (exprType mono_dfun)
270
271 mkDFunArg :: Type -> VM CoreExpr
272 mkDFunArg ty
273   | Just (tycon, [arg]) <- splitTyConApp_maybe ty
274
275   = let name = tyConName tycon
276
277         get_dict | name == paTyConName = paDictOfType
278                  | name == prTyConName = prDictOfType
279                  | otherwise           = pprPanic "mkDFunArg" (ppr ty)
280
281     in get_dict arg
282
283 mkDFunArg ty = pprPanic "mkDFunArg" (ppr ty)
284
285 prCoerce :: TyCon -> [Type] -> CoreExpr -> VM CoreExpr
286 prCoerce repr_tc args expr
287   | Just arg_co <- tyConFamilyCoercion_maybe repr_tc
288   = do
289       pr_tc <- builtin prTyCon
290
291       let co = mkAppCoercion (mkTyConApp pr_tc [])
292                              (mkSymCoercion (mkTyConApp arg_co args))
293
294       return $ mkCoerce co expr
295
296 paDictArgType :: TyVar -> VM (Maybe Type)
297 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
298   where
299     go ty k | Just k' <- kindView k = go ty k'
300     go ty (FunTy k1 k2)
301       = do
302           tv   <- newTyVar FSLIT("a") k1
303           mty1 <- go (TyVarTy tv) k1
304           case mty1 of
305             Just ty1 -> do
306                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
307                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
308             Nothing  -> go ty k2
309
310     go ty k
311       | isLiftedTypeKind k
312       = liftM Just (mkPADictType ty)
313
314     go ty k = return Nothing
315
316 paDictOfType :: Type -> VM CoreExpr
317 paDictOfType ty = paDictOfTyApp ty_fn ty_args
318   where
319     (ty_fn, ty_args) = splitAppTys ty
320
321 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
322 paDictOfTyApp ty_fn ty_args
323   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
324 paDictOfTyApp (TyVarTy tv) ty_args
325   = do
326       dfun <- maybeV (lookupTyVarPA tv)
327       paDFunApply dfun ty_args
328 paDictOfTyApp (TyConApp tc _) ty_args
329   = do
330       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
331       paDFunApply (Var dfun) ty_args
332 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
333
334 paDFunType :: TyCon -> VM Type
335 paDFunType tc
336   = do
337       margs <- mapM paDictArgType tvs
338       res   <- mkPADictType (mkTyConApp tc arg_tys)
339       return . mkForAllTys tvs
340              $ mkFunTys [arg | Just arg <- margs] res
341   where
342     tvs = tyConTyVars tc
343     arg_tys = mkTyVarTys tvs
344
345 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
346 paDFunApply dfun tys
347   = do
348       dicts <- mapM paDictOfType tys
349       return $ mkApps (mkTyApps dfun tys) dicts
350
351 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
352 paMethod method ty
353   = do
354       fn   <- builtin method
355       dict <- paDictOfType ty
356       return $ mkApps (Var fn) [Type ty, dict]
357
358 lengthPA :: CoreExpr -> VM CoreExpr
359 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
360   where
361     ty = splitPArrayTy (exprType x)
362
363 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
364 replicatePA len x = liftM (`mkApps` [len,x])
365                           (paMethod replicatePAVar (exprType x))
366
367 emptyPA :: Type -> VM CoreExpr
368 emptyPA = paMethod emptyPAVar
369
370 liftPA :: CoreExpr -> VM CoreExpr
371 liftPA x
372   = do
373       lc <- builtin liftingContext
374       replicatePA (Var lc) x
375
376 newLocalVVar :: FastString -> Type -> VM VVar
377 newLocalVVar fs vty
378   = do
379       lty <- mkPArrayType vty
380       vv  <- newLocalVar fs vty
381       lv  <- newLocalVar fs lty
382       return (vv,lv)
383
384 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
385 polyAbstract tvs p
386   = localV
387   $ do
388       mdicts <- mapM mk_dict_var tvs
389       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
390       p (mk_lams mdicts)
391   where
392     mk_dict_var tv = do
393                        r <- paDictArgType tv
394                        case r of
395                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
396                          Nothing -> return Nothing
397
398     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
399
400 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
401 polyApply expr tys
402   = do
403       dicts <- mapM paDictOfType tys
404       return $ expr `mkTyApps` tys `mkApps` dicts
405
406 polyVApply :: VExpr -> [Type] -> VM VExpr
407 polyVApply expr tys
408   = do
409       dicts <- mapM paDictOfType tys
410       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
411
412 hoistBinding :: Var -> CoreExpr -> VM ()
413 hoistBinding v e = updGEnv $ \env ->
414   env { global_bindings = (v,e) : global_bindings env }
415
416 hoistExpr :: FastString -> CoreExpr -> VM Var
417 hoistExpr fs expr
418   = do
419       var <- newLocalVar fs (exprType expr)
420       hoistBinding var expr
421       return var
422
423 hoistVExpr :: VExpr -> VM VVar
424 hoistVExpr (ve, le)
425   = do
426       fs <- getBindName
427       vv <- hoistExpr ('v' `consFS` fs) ve
428       lv <- hoistExpr ('l' `consFS` fs) le
429       return (vv, lv)
430
431 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
432 hoistPolyVExpr tvs p
433   = do
434       expr <- closedV . polyAbstract tvs $ \abstract ->
435               liftM (mapVect abstract) p
436       fn   <- hoistVExpr expr
437       polyVApply (vVar fn) (mkTyVarTys tvs)
438
439 takeHoisted :: VM [(Var, CoreExpr)]
440 takeHoisted
441   = do
442       env <- readGEnv id
443       setGEnv $ env { global_bindings = [] }
444       return $ global_bindings env
445
446 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
447 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
448   = do
449       dict <- paDictOfType env_ty
450       mkv  <- builtin mkClosureVar
451       mkl  <- builtin mkClosurePVar
452       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
453               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
454
455 mkClosureApp :: VExpr -> VExpr -> VM VExpr
456 mkClosureApp (vclo, lclo) (varg, larg)
457   = do
458       vapply <- builtin applyClosureVar
459       lapply <- builtin applyClosurePVar
460       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
461               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
462   where
463     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
464
465 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
466 buildClosures tvs vars [] res_ty mk_body
467   = mk_body
468 buildClosures tvs vars [arg_ty] res_ty mk_body
469   = buildClosure tvs vars arg_ty res_ty mk_body
470 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
471   = do
472       res_ty' <- mkClosureTypes arg_tys res_ty
473       arg <- newLocalVVar FSLIT("x") arg_ty
474       buildClosure tvs vars arg_ty res_ty'
475         . hoistPolyVExpr tvs
476         $ do
477             lc <- builtin liftingContext
478             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
479             return $ vLams lc (vars ++ [arg]) clo
480
481 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
482 --   where
483 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
484 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
485 --
486 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
487 buildClosure tvs vars arg_ty res_ty mk_body
488   = do
489       (env_ty, env, bind) <- buildEnv vars
490       env_bndr <- newLocalVVar FSLIT("env") env_ty
491       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
492
493       fn <- hoistPolyVExpr tvs
494           $ do
495               lc    <- builtin liftingContext
496               body  <- mk_body
497               body' <- bind (vVar env_bndr)
498                             (vVarApps lc body (vars ++ [arg_bndr]))
499               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
500
501       mkClosure arg_ty res_ty env_ty fn env
502
503 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
504 buildEnv vvs
505   = do
506       lc <- builtin liftingContext
507       let (ty, venv, vbind) = mkVectEnv tys vs
508       (lenv, lbind) <- mkLiftEnv lc tys ls
509       return (ty, (venv, lenv),
510               \(venv,lenv) (vbody,lbody) ->
511               do
512                 let vbody' = vbind venv vbody
513                 lbody' <- lbind lenv lbody
514                 return (vbody', lbody'))
515   where
516     (vs,ls) = unzip vvs
517     tys     = map idType vs
518
519 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
520 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
521 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
522 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
523                         \env body -> Case env (mkWildId ty) (exprType body)
524                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
525   where
526     ty = mkCoreTupTy tys
527
528 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
529 mkLiftEnv lc [ty] [v]
530   = return (Var v, \env body ->
531                    do
532                      len <- lengthPA (Var v)
533                      return . Let (NonRec v env)
534                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
535
536 -- NOTE: this transparently deals with empty environments
537 mkLiftEnv lc tys vs
538   = do
539       (env_tc, env_tyargs) <- parrayReprTyCon vty
540       let [env_con] = tyConDataCons env_tc
541           
542           env = Var (dataConWrapId env_con)
543                 `mkTyApps`  env_tyargs
544                 `mkVarApps` (lc : vs)
545
546           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
547                           in
548                           return $ Case scrut (mkWildId (exprType scrut))
549                                         (exprType body)
550                                         [(DataAlt env_con, lc : bndrs, body)]
551       return (env, bind)
552   where
553     vty = mkCoreTupTy tys
554
555     bndrs | null vs   = [mkWildId unitTy]
556           | otherwise = vs
557