Modify generation of PR dictionaries for new scheme
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   mkDataConTag,
5   splitClosureTy,
6
7   TyConRepr(..), mkTyConRepr,
8   mkToArrPRepr, mkFromArrPRepr,
9   mkPADictType, mkPArrayType, mkPReprType,
10
11   parrayCoerce, parrayReprTyCon, parrayReprDataCon, mkVScrut,
12   prDFunOfTyCon, prCoerce,
13   paDictArgType, paDictOfType, paDFunType,
14   paMethod, mkPR, lengthPA, replicatePA, emptyPA, liftPA,
15   polyAbstract, polyApply, polyVApply,
16   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
17   buildClosure, buildClosures,
18   mkClosureApp
19 ) where
20
21 #include "HsVersions.h"
22
23 import VectCore
24 import VectMonad
25
26 import DsUtils
27 import CoreSyn
28 import CoreUtils
29 import Coercion
30 import Type
31 import TypeRep
32 import TyCon
33 import DataCon
34 import Var
35 import Id                 ( mkWildId )
36 import MkId               ( unwrapFamInstScrut )
37 import Name               ( Name )
38 import PrelNames
39 import TysWiredIn
40 import BasicTypes         ( Boxity(..) )
41
42 import Outputable
43 import FastString
44 import Maybes             ( orElse )
45
46 import Data.List             ( zipWith4 )
47 import Control.Monad         ( liftM, liftM2, zipWithM_ )
48
49 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
50 collectAnnTypeArgs expr = go expr []
51   where
52     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
53     go e                             tys = (e, tys)
54
55 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
56 collectAnnTypeBinders expr = go [] expr
57   where
58     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
59     go bs e                           = (reverse bs, e)
60
61 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
62 collectAnnValBinders expr = go [] expr
63   where
64     go bs (_, AnnLam b e) | isId b = go (b:bs) e
65     go bs e                        = (reverse bs, e)
66
67 isAnnTypeArg :: AnnExpr b ann -> Bool
68 isAnnTypeArg (_, AnnType t) = True
69 isAnnTypeArg _              = False
70
71 mkDataConTag :: DataCon -> CoreExpr
72 mkDataConTag dc = mkConApp intDataCon [mkIntLitInt $ dataConTag dc]
73
74 splitUnTy :: String -> Name -> Type -> Type
75 splitUnTy s name ty
76   | Just (tc, [ty']) <- splitTyConApp_maybe ty
77   , tyConName tc == name
78   = ty'
79
80   | otherwise = pprPanic s (ppr ty)
81
82 splitBinTy :: String -> Name -> Type -> (Type, Type)
83 splitBinTy s name ty
84   | Just (tc, [ty1, ty2]) <- splitTyConApp_maybe ty
85   , tyConName tc == name
86   = (ty1, ty2)
87
88   | otherwise = pprPanic s (ppr ty)
89
90 splitFixedTyConApp :: TyCon -> Type -> [Type]
91 splitFixedTyConApp tc ty
92   | Just (tc', tys) <- splitTyConApp_maybe ty
93   , tc == tc'
94   = tys
95
96   | otherwise = pprPanic "splitFixedTyConApp" (ppr tc <+> ppr ty)
97
98 splitClosureTy :: Type -> (Type, Type)
99 splitClosureTy = splitBinTy "splitClosureTy" closureTyConName
100
101 splitPArrayTy :: Type -> Type
102 splitPArrayTy = splitUnTy "splitPArrayTy" parrayTyConName
103
104 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
105 mkBuiltinTyConApp get_tc tys
106   = do
107       tc <- builtin get_tc
108       return $ mkTyConApp tc tys
109
110 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
111 mkBuiltinTyConApps get_tc tys ty
112   = do
113       tc <- builtin get_tc
114       return $ foldr (mk tc) ty tys
115   where
116     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
117
118 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
119 mkBuiltinTyConApps1 get_tc dft [] = return dft
120 mkBuiltinTyConApps1 get_tc dft tys
121   = do
122       tc <- builtin get_tc
123       case tys of
124         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
125         _  -> return $ foldr1 (mk tc) tys
126   where
127     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
128
129 data TyConRepr = TyConRepr {
130                    repr_tyvars         :: [TyVar]
131                  , repr_tys            :: [[Type]]
132
133                  , repr_prod_tycons    :: [Maybe TyCon]
134                  , repr_prod_data_cons :: [Maybe DataCon]
135                  , repr_prod_tys       :: [Type]
136                  , repr_sum_tycon      :: Maybe TyCon
137                  , repr_sum_data_cons  :: [DataCon]
138                  , repr_type           :: Type
139                  }
140
141 mkTyConRepr :: TyCon -> VM TyConRepr
142 mkTyConRepr vect_tc
143   = do
144       prod_tycons <- mapM (mk_tycon prodTyCon) rep_tys
145       let prod_tys = zipWith mk_tc_app_maybe prod_tycons rep_tys
146       sum_tycon   <- mk_tycon sumTyCon prod_tys
147
148       return $ TyConRepr {
149                  repr_tyvars         = tyvars
150                , repr_tys            = rep_tys
151
152                , repr_prod_tycons    = prod_tycons
153                , repr_prod_data_cons = map (fmap mk_single_datacon) prod_tycons
154                , repr_prod_tys       = prod_tys
155                , repr_sum_tycon      = sum_tycon
156                , repr_sum_data_cons  = fmap tyConDataCons sum_tycon `orElse` []
157                , repr_type           = mk_tc_app_maybe sum_tycon prod_tys
158                }
159   where
160     tyvars = tyConTyVars vect_tc
161     data_cons = tyConDataCons vect_tc
162     rep_tys   = map dataConRepArgTys data_cons
163
164     mk_tycon get_tc tys
165       | n > 1     = builtin (Just . get_tc n)
166       | otherwise = return Nothing
167       where n = length tys
168
169     mk_single_datacon tc | [dc] <- tyConDataCons tc = dc
170
171     mk_tc_app_maybe Nothing   []   = unitTy
172     mk_tc_app_maybe Nothing   [ty] = ty
173     mk_tc_app_maybe (Just tc) tys  = mkTyConApp tc tys
174
175 mkToArrPRepr :: CoreExpr -> CoreExpr -> [[CoreExpr]] -> VM CoreExpr
176 mkToArrPRepr len sel ess
177   = do
178       let mk_sum [(expr, ty)] = return (expr, ty)
179           mk_sum es
180             = do
181                 sum_tc <- builtin . sumTyCon $ length es
182                 (sum_rtc, _) <- parrayReprTyCon (mkTyConApp sum_tc tys)
183                 let [sum_rdc] = tyConDataCons sum_rtc
184
185                 return (mkConApp sum_rdc (map Type tys ++ (len : sel : exprs)),
186                         mkTyConApp sum_tc tys)
187             where
188               (exprs, tys) = unzip es
189
190           mk_prod [expr] = return (expr, splitPArrayTy (exprType expr))
191           mk_prod exprs
192             = do
193                 prod_tc <- builtin . prodTyCon $ length exprs
194                 (prod_rtc, _) <- parrayReprTyCon (mkTyConApp prod_tc tys)
195                 let [prod_rdc] = tyConDataCons prod_rtc
196
197                 return (mkConApp prod_rdc (map Type tys ++ (len : exprs)),
198                         mkTyConApp prod_tc tys)
199             where
200               tys = map (splitPArrayTy . exprType) exprs
201
202       liftM fst (mk_sum =<< mapM mk_prod ess)
203
204 mkFromArrPRepr :: CoreExpr -> Type -> Var -> Var -> [[Var]] -> CoreExpr
205                -> VM CoreExpr
206 mkFromArrPRepr scrut res_ty len sel vars res
207   = return (Var unitDataConId)
208
209 mkClosureType :: Type -> Type -> VM Type
210 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
211
212 mkClosureTypes :: [Type] -> Type -> VM Type
213 mkClosureTypes = mkBuiltinTyConApps closureTyCon
214
215 mkPReprType :: Type -> VM Type
216 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
217
218 mkPADictType :: Type -> VM Type
219 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
220
221 mkPArrayType :: Type -> VM Type
222 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
223
224 parrayCoerce :: TyCon -> [Type] -> CoreExpr -> VM CoreExpr
225 parrayCoerce repr_tc args expr
226   | Just arg_co <- tyConFamilyCoercion_maybe repr_tc
227   = do
228       parray <- builtin parrayTyCon
229
230       let co = mkAppCoercion (mkTyConApp parray [])
231                              (mkSymCoercion (mkTyConApp arg_co args))
232
233       return $ mkCoerce co expr
234
235 parrayReprTyCon :: Type -> VM (TyCon, [Type])
236 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
237
238 parrayReprDataCon :: Type -> VM (DataCon, [Type])
239 parrayReprDataCon ty
240   = do
241       (tc, arg_tys) <- parrayReprTyCon ty
242       let [dc] = tyConDataCons tc
243       return (dc, arg_tys)
244
245 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
246 mkVScrut (ve, le)
247   = do
248       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
249       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
250
251 prDFunOfTyCon :: TyCon -> VM CoreExpr
252 prDFunOfTyCon tycon
253   = liftM Var (traceMaybeV "prDictOfTyCon" (ppr tycon) (lookupTyConPR tycon))
254
255 prCoerce :: TyCon -> [Type] -> CoreExpr -> VM CoreExpr
256 prCoerce repr_tc args expr
257   | Just arg_co <- tyConFamilyCoercion_maybe repr_tc
258   = do
259       pr_tc <- builtin prTyCon
260
261       let co = mkAppCoercion (mkTyConApp pr_tc [])
262                              (mkSymCoercion (mkTyConApp arg_co args))
263
264       return $ mkCoerce co expr
265
266 paDictArgType :: TyVar -> VM (Maybe Type)
267 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
268   where
269     go ty k | Just k' <- kindView k = go ty k'
270     go ty (FunTy k1 k2)
271       = do
272           tv   <- newTyVar FSLIT("a") k1
273           mty1 <- go (TyVarTy tv) k1
274           case mty1 of
275             Just ty1 -> do
276                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
277                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
278             Nothing  -> go ty k2
279
280     go ty k
281       | isLiftedTypeKind k
282       = liftM Just (mkPADictType ty)
283
284     go ty k = return Nothing
285
286 paDictOfType :: Type -> VM CoreExpr
287 paDictOfType ty = paDictOfTyApp ty_fn ty_args
288   where
289     (ty_fn, ty_args) = splitAppTys ty
290
291 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
292 paDictOfTyApp ty_fn ty_args
293   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
294 paDictOfTyApp (TyVarTy tv) ty_args
295   = do
296       dfun <- maybeV (lookupTyVarPA tv)
297       paDFunApply dfun ty_args
298 paDictOfTyApp (TyConApp tc _) ty_args
299   = do
300       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
301       paDFunApply (Var dfun) ty_args
302 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
303
304 paDFunType :: TyCon -> VM Type
305 paDFunType tc
306   = do
307       margs <- mapM paDictArgType tvs
308       res   <- mkPADictType (mkTyConApp tc arg_tys)
309       return . mkForAllTys tvs
310              $ mkFunTys [arg | Just arg <- margs] res
311   where
312     tvs = tyConTyVars tc
313     arg_tys = mkTyVarTys tvs
314
315 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
316 paDFunApply dfun tys
317   = do
318       dicts <- mapM paDictOfType tys
319       return $ mkApps (mkTyApps dfun tys) dicts
320
321 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
322 paMethod method ty
323   = do
324       fn   <- builtin method
325       dict <- paDictOfType ty
326       return $ mkApps (Var fn) [Type ty, dict]
327
328 mkPR :: Type -> VM CoreExpr
329 mkPR = paMethod mkPRVar
330
331 lengthPA :: CoreExpr -> VM CoreExpr
332 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
333   where
334     ty = splitPArrayTy (exprType x)
335
336 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
337 replicatePA len x = liftM (`mkApps` [len,x])
338                           (paMethod replicatePAVar (exprType x))
339
340 emptyPA :: Type -> VM CoreExpr
341 emptyPA = paMethod emptyPAVar
342
343 liftPA :: CoreExpr -> VM CoreExpr
344 liftPA x
345   = do
346       lc <- builtin liftingContext
347       replicatePA (Var lc) x
348
349 newLocalVVar :: FastString -> Type -> VM VVar
350 newLocalVVar fs vty
351   = do
352       lty <- mkPArrayType vty
353       vv  <- newLocalVar fs vty
354       lv  <- newLocalVar fs lty
355       return (vv,lv)
356
357 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
358 polyAbstract tvs p
359   = localV
360   $ do
361       mdicts <- mapM mk_dict_var tvs
362       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
363       p (mk_lams mdicts)
364   where
365     mk_dict_var tv = do
366                        r <- paDictArgType tv
367                        case r of
368                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
369                          Nothing -> return Nothing
370
371     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
372
373 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
374 polyApply expr tys
375   = do
376       dicts <- mapM paDictOfType tys
377       return $ expr `mkTyApps` tys `mkApps` dicts
378
379 polyVApply :: VExpr -> [Type] -> VM VExpr
380 polyVApply expr tys
381   = do
382       dicts <- mapM paDictOfType tys
383       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
384
385 hoistBinding :: Var -> CoreExpr -> VM ()
386 hoistBinding v e = updGEnv $ \env ->
387   env { global_bindings = (v,e) : global_bindings env }
388
389 hoistExpr :: FastString -> CoreExpr -> VM Var
390 hoistExpr fs expr
391   = do
392       var <- newLocalVar fs (exprType expr)
393       hoistBinding var expr
394       return var
395
396 hoistVExpr :: VExpr -> VM VVar
397 hoistVExpr (ve, le)
398   = do
399       fs <- getBindName
400       vv <- hoistExpr ('v' `consFS` fs) ve
401       lv <- hoistExpr ('l' `consFS` fs) le
402       return (vv, lv)
403
404 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
405 hoistPolyVExpr tvs p
406   = do
407       expr <- closedV . polyAbstract tvs $ \abstract ->
408               liftM (mapVect abstract) p
409       fn   <- hoistVExpr expr
410       polyVApply (vVar fn) (mkTyVarTys tvs)
411
412 takeHoisted :: VM [(Var, CoreExpr)]
413 takeHoisted
414   = do
415       env <- readGEnv id
416       setGEnv $ env { global_bindings = [] }
417       return $ global_bindings env
418
419 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
420 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
421   = do
422       dict <- paDictOfType env_ty
423       mkv  <- builtin mkClosureVar
424       mkl  <- builtin mkClosurePVar
425       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
426               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
427
428 mkClosureApp :: VExpr -> VExpr -> VM VExpr
429 mkClosureApp (vclo, lclo) (varg, larg)
430   = do
431       vapply <- builtin applyClosureVar
432       lapply <- builtin applyClosurePVar
433       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
434               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
435   where
436     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
437
438 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
439 buildClosures tvs vars [] res_ty mk_body
440   = mk_body
441 buildClosures tvs vars [arg_ty] res_ty mk_body
442   = buildClosure tvs vars arg_ty res_ty mk_body
443 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
444   = do
445       res_ty' <- mkClosureTypes arg_tys res_ty
446       arg <- newLocalVVar FSLIT("x") arg_ty
447       buildClosure tvs vars arg_ty res_ty'
448         . hoistPolyVExpr tvs
449         $ do
450             lc <- builtin liftingContext
451             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
452             return $ vLams lc (vars ++ [arg]) clo
453
454 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
455 --   where
456 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
457 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
458 --
459 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
460 buildClosure tvs vars arg_ty res_ty mk_body
461   = do
462       (env_ty, env, bind) <- buildEnv vars
463       env_bndr <- newLocalVVar FSLIT("env") env_ty
464       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
465
466       fn <- hoistPolyVExpr tvs
467           $ do
468               lc    <- builtin liftingContext
469               body  <- mk_body
470               body' <- bind (vVar env_bndr)
471                             (vVarApps lc body (vars ++ [arg_bndr]))
472               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
473
474       mkClosure arg_ty res_ty env_ty fn env
475
476 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
477 buildEnv vvs
478   = do
479       lc <- builtin liftingContext
480       let (ty, venv, vbind) = mkVectEnv tys vs
481       (lenv, lbind) <- mkLiftEnv lc tys ls
482       return (ty, (venv, lenv),
483               \(venv,lenv) (vbody,lbody) ->
484               do
485                 let vbody' = vbind venv vbody
486                 lbody' <- lbind lenv lbody
487                 return (vbody', lbody'))
488   where
489     (vs,ls) = unzip vvs
490     tys     = map idType vs
491
492 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
493 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
494 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
495 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
496                         \env body -> Case env (mkWildId ty) (exprType body)
497                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
498   where
499     ty = mkCoreTupTy tys
500
501 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
502 mkLiftEnv lc [ty] [v]
503   = return (Var v, \env body ->
504                    do
505                      len <- lengthPA (Var v)
506                      return . Let (NonRec v env)
507                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
508
509 -- NOTE: this transparently deals with empty environments
510 mkLiftEnv lc tys vs
511   = do
512       (env_tc, env_tyargs) <- parrayReprTyCon vty
513       let [env_con] = tyConDataCons env_tc
514           
515           env = Var (dataConWrapId env_con)
516                 `mkTyApps`  env_tyargs
517                 `mkVarApps` (lc : vs)
518
519           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
520                           in
521                           return $ Case scrut (mkWildId (exprType scrut))
522                                         (exprType body)
523                                         [(DataAlt env_con, lc : bndrs, body)]
524       return (env, bind)
525   where
526     vty = mkCoreTupTy tys
527
528     bndrs | null vs   = [mkWildId unitTy]
529           | otherwise = vs
530