77e037f4779d8a4905de463360b0f7772dfb82a1
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   mkDataConTag,
5   splitClosureTy,
6
7   TyConRepr(..), mkTyConRepr,
8   mkToArrPRepr, mkFromArrPRepr,
9   mkPADictType, mkPArrayType, mkPReprType,
10
11   parrayCoerce, parrayReprTyCon, parrayReprDataCon, mkVScrut,
12   prDictOfType, prCoerce,
13   paDictArgType, paDictOfType, paDFunType,
14   paMethod, lengthPA, replicatePA, emptyPA, liftPA,
15   polyAbstract, polyApply, polyVApply,
16   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
17   buildClosure, buildClosures,
18   mkClosureApp
19 ) where
20
21 #include "HsVersions.h"
22
23 import VectCore
24 import VectMonad
25
26 import DsUtils
27 import CoreSyn
28 import CoreUtils
29 import Coercion
30 import Type
31 import TypeRep
32 import TyCon
33 import DataCon
34 import Var
35 import Id                 ( mkWildId )
36 import MkId               ( unwrapFamInstScrut )
37 import Name               ( Name )
38 import PrelNames
39 import TysWiredIn
40 import BasicTypes         ( Boxity(..) )
41
42 import Outputable
43 import FastString
44
45 import Data.List             ( zipWith4 )
46 import Control.Monad         ( liftM, liftM2, zipWithM_ )
47
48 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
49 collectAnnTypeArgs expr = go expr []
50   where
51     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
52     go e                             tys = (e, tys)
53
54 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
55 collectAnnTypeBinders expr = go [] expr
56   where
57     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
58     go bs e                           = (reverse bs, e)
59
60 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
61 collectAnnValBinders expr = go [] expr
62   where
63     go bs (_, AnnLam b e) | isId b = go (b:bs) e
64     go bs e                        = (reverse bs, e)
65
66 isAnnTypeArg :: AnnExpr b ann -> Bool
67 isAnnTypeArg (_, AnnType t) = True
68 isAnnTypeArg _              = False
69
70 mkDataConTag :: DataCon -> CoreExpr
71 mkDataConTag dc = mkConApp intDataCon [mkIntLitInt $ dataConTag dc]
72
73 splitUnTy :: String -> Name -> Type -> Type
74 splitUnTy s name ty
75   | Just (tc, [ty']) <- splitTyConApp_maybe ty
76   , tyConName tc == name
77   = ty'
78
79   | otherwise = pprPanic s (ppr ty)
80
81 splitBinTy :: String -> Name -> Type -> (Type, Type)
82 splitBinTy s name ty
83   | Just (tc, [ty1, ty2]) <- splitTyConApp_maybe ty
84   , tyConName tc == name
85   = (ty1, ty2)
86
87   | otherwise = pprPanic s (ppr ty)
88
89 splitFixedTyConApp :: TyCon -> Type -> [Type]
90 splitFixedTyConApp tc ty
91   | Just (tc', tys) <- splitTyConApp_maybe ty
92   , tc == tc'
93   = tys
94
95   | otherwise = pprPanic "splitFixedTyConApp" (ppr tc <+> ppr ty)
96
97 splitClosureTy :: Type -> (Type, Type)
98 splitClosureTy = splitBinTy "splitClosureTy" closureTyConName
99
100 splitPArrayTy :: Type -> Type
101 splitPArrayTy = splitUnTy "splitPArrayTy" parrayTyConName
102
103 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
104 mkBuiltinTyConApp get_tc tys
105   = do
106       tc <- builtin get_tc
107       return $ mkTyConApp tc tys
108
109 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
110 mkBuiltinTyConApps get_tc tys ty
111   = do
112       tc <- builtin get_tc
113       return $ foldr (mk tc) ty tys
114   where
115     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
116
117 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
118 mkBuiltinTyConApps1 get_tc dft [] = return dft
119 mkBuiltinTyConApps1 get_tc dft tys
120   = do
121       tc <- builtin get_tc
122       case tys of
123         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
124         _  -> return $ foldr1 (mk tc) tys
125   where
126     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
127
128 data TyConRepr = TyConRepr {
129                    repr_tyvars      :: [TyVar]
130                  , repr_tys         :: [[Type]]
131
132                  , repr_prod_tycons :: [Maybe TyCon]
133                  , repr_prod_tys    :: [Type]
134                  , repr_sum_tycon   :: Maybe TyCon
135                  , repr_type        :: Type
136                  }
137
138 mkTyConRepr :: TyCon -> VM TyConRepr
139 mkTyConRepr vect_tc
140   = do
141       prod_tycons <- mapM (mk_tycon prodTyCon) rep_tys
142       let prod_tys = zipWith mk_tc_app_maybe prod_tycons rep_tys
143       sum_tycon   <- mk_tycon sumTyCon prod_tys
144
145       return $ TyConRepr {
146                  repr_tyvars      = tyvars
147                , repr_tys         = rep_tys
148
149                , repr_prod_tycons = prod_tycons
150                , repr_prod_tys    = prod_tys
151                , repr_sum_tycon   = sum_tycon
152                , repr_type        = mk_tc_app_maybe sum_tycon prod_tys
153                }
154   where
155     tyvars = tyConTyVars vect_tc
156     data_cons = tyConDataCons vect_tc
157     rep_tys   = map dataConRepArgTys data_cons
158
159     mk_tycon get_tc tys
160       | n > 1     = builtin (Just . get_tc n)
161       | otherwise = return Nothing
162       where n = length tys
163
164     mk_tc_app_maybe Nothing   []   = unitTy
165     mk_tc_app_maybe Nothing   [ty] = ty
166     mk_tc_app_maybe (Just tc) tys  = mkTyConApp tc tys
167
168 mkToArrPRepr :: CoreExpr -> CoreExpr -> [[CoreExpr]] -> VM CoreExpr
169 mkToArrPRepr len sel ess
170   = do
171       let mk_sum [(expr, ty)] = return (expr, ty)
172           mk_sum es
173             = do
174                 sum_tc <- builtin . sumTyCon $ length es
175                 (sum_rtc, _) <- parrayReprTyCon (mkTyConApp sum_tc tys)
176                 let [sum_rdc] = tyConDataCons sum_rtc
177
178                 return (mkConApp sum_rdc (map Type tys ++ (len : sel : exprs)),
179                         mkTyConApp sum_tc tys)
180             where
181               (exprs, tys) = unzip es
182
183           mk_prod [expr] = return (expr, splitPArrayTy (exprType expr))
184           mk_prod exprs
185             = do
186                 prod_tc <- builtin . prodTyCon $ length exprs
187                 (prod_rtc, _) <- parrayReprTyCon (mkTyConApp prod_tc tys)
188                 let [prod_rdc] = tyConDataCons prod_rtc
189
190                 return (mkConApp prod_rdc (map Type tys ++ (len : exprs)),
191                         mkTyConApp prod_tc tys)
192             where
193               tys = map (splitPArrayTy . exprType) exprs
194
195       liftM fst (mk_sum =<< mapM mk_prod ess)
196
197 mkFromArrPRepr :: CoreExpr -> Type -> Var -> Var -> [[Var]] -> CoreExpr
198                -> VM CoreExpr
199 mkFromArrPRepr scrut res_ty len sel vars res
200   = return (Var unitDataConId)
201
202 mkClosureType :: Type -> Type -> VM Type
203 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
204
205 mkClosureTypes :: [Type] -> Type -> VM Type
206 mkClosureTypes = mkBuiltinTyConApps closureTyCon
207
208 mkPReprType :: Type -> VM Type
209 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
210
211 mkPADictType :: Type -> VM Type
212 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
213
214 mkPArrayType :: Type -> VM Type
215 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
216
217 parrayCoerce :: TyCon -> [Type] -> CoreExpr -> VM CoreExpr
218 parrayCoerce repr_tc args expr
219   | Just arg_co <- tyConFamilyCoercion_maybe repr_tc
220   = do
221       parray <- builtin parrayTyCon
222
223       let co = mkAppCoercion (mkTyConApp parray [])
224                              (mkSymCoercion (mkTyConApp arg_co args))
225
226       return $ mkCoerce co expr
227
228 parrayReprTyCon :: Type -> VM (TyCon, [Type])
229 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
230
231 parrayReprDataCon :: Type -> VM (DataCon, [Type])
232 parrayReprDataCon ty
233   = do
234       (tc, arg_tys) <- parrayReprTyCon ty
235       let [dc] = tyConDataCons tc
236       return (dc, arg_tys)
237
238 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
239 mkVScrut (ve, le)
240   = do
241       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
242       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
243
244 prDictOfType :: Type -> VM CoreExpr
245 prDictOfType orig_ty
246   | Just (tycon, ty_args) <- splitTyConApp_maybe orig_ty
247   = do
248       dfun <- traceMaybeV "prDictOfType" (ppr tycon) (lookupTyConPR tycon)
249       prDFunApply (Var dfun) ty_args
250
251 prDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
252 prDFunApply dfun tys
253   = do
254       args <- mapM mkDFunArg arg_tys
255       return $ mkApps mono_dfun args
256   where
257     mono_dfun    = mkTyApps dfun tys
258     (arg_tys, _) = splitFunTys (exprType mono_dfun)
259
260 mkDFunArg :: Type -> VM CoreExpr
261 mkDFunArg ty
262   | Just (tycon, [arg]) <- splitTyConApp_maybe ty
263
264   = let name = tyConName tycon
265
266         get_dict | name == paTyConName = paDictOfType
267                  | name == prTyConName = prDictOfType
268                  | otherwise           = pprPanic "mkDFunArg" (ppr ty)
269
270     in get_dict arg
271
272 mkDFunArg ty = pprPanic "mkDFunArg" (ppr ty)
273
274 prCoerce :: TyCon -> [Type] -> CoreExpr -> VM CoreExpr
275 prCoerce repr_tc args expr
276   | Just arg_co <- tyConFamilyCoercion_maybe repr_tc
277   = do
278       pr_tc <- builtin prTyCon
279
280       let co = mkAppCoercion (mkTyConApp pr_tc [])
281                              (mkSymCoercion (mkTyConApp arg_co args))
282
283       return $ mkCoerce co expr
284
285 paDictArgType :: TyVar -> VM (Maybe Type)
286 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
287   where
288     go ty k | Just k' <- kindView k = go ty k'
289     go ty (FunTy k1 k2)
290       = do
291           tv   <- newTyVar FSLIT("a") k1
292           mty1 <- go (TyVarTy tv) k1
293           case mty1 of
294             Just ty1 -> do
295                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
296                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
297             Nothing  -> go ty k2
298
299     go ty k
300       | isLiftedTypeKind k
301       = liftM Just (mkPADictType ty)
302
303     go ty k = return Nothing
304
305 paDictOfType :: Type -> VM CoreExpr
306 paDictOfType ty = paDictOfTyApp ty_fn ty_args
307   where
308     (ty_fn, ty_args) = splitAppTys ty
309
310 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
311 paDictOfTyApp ty_fn ty_args
312   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
313 paDictOfTyApp (TyVarTy tv) ty_args
314   = do
315       dfun <- maybeV (lookupTyVarPA tv)
316       paDFunApply dfun ty_args
317 paDictOfTyApp (TyConApp tc _) ty_args
318   = do
319       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
320       paDFunApply (Var dfun) ty_args
321 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
322
323 paDFunType :: TyCon -> VM Type
324 paDFunType tc
325   = do
326       margs <- mapM paDictArgType tvs
327       res   <- mkPADictType (mkTyConApp tc arg_tys)
328       return . mkForAllTys tvs
329              $ mkFunTys [arg | Just arg <- margs] res
330   where
331     tvs = tyConTyVars tc
332     arg_tys = mkTyVarTys tvs
333
334 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
335 paDFunApply dfun tys
336   = do
337       dicts <- mapM paDictOfType tys
338       return $ mkApps (mkTyApps dfun tys) dicts
339
340 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
341 paMethod method ty
342   = do
343       fn   <- builtin method
344       dict <- paDictOfType ty
345       return $ mkApps (Var fn) [Type ty, dict]
346
347 mkPR :: Type -> VM CoreExpr
348 mkPR = paMethod mkPRVar
349
350 lengthPA :: CoreExpr -> VM CoreExpr
351 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
352   where
353     ty = splitPArrayTy (exprType x)
354
355 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
356 replicatePA len x = liftM (`mkApps` [len,x])
357                           (paMethod replicatePAVar (exprType x))
358
359 emptyPA :: Type -> VM CoreExpr
360 emptyPA = paMethod emptyPAVar
361
362 liftPA :: CoreExpr -> VM CoreExpr
363 liftPA x
364   = do
365       lc <- builtin liftingContext
366       replicatePA (Var lc) x
367
368 newLocalVVar :: FastString -> Type -> VM VVar
369 newLocalVVar fs vty
370   = do
371       lty <- mkPArrayType vty
372       vv  <- newLocalVar fs vty
373       lv  <- newLocalVar fs lty
374       return (vv,lv)
375
376 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
377 polyAbstract tvs p
378   = localV
379   $ do
380       mdicts <- mapM mk_dict_var tvs
381       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
382       p (mk_lams mdicts)
383   where
384     mk_dict_var tv = do
385                        r <- paDictArgType tv
386                        case r of
387                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
388                          Nothing -> return Nothing
389
390     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
391
392 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
393 polyApply expr tys
394   = do
395       dicts <- mapM paDictOfType tys
396       return $ expr `mkTyApps` tys `mkApps` dicts
397
398 polyVApply :: VExpr -> [Type] -> VM VExpr
399 polyVApply expr tys
400   = do
401       dicts <- mapM paDictOfType tys
402       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
403
404 hoistBinding :: Var -> CoreExpr -> VM ()
405 hoistBinding v e = updGEnv $ \env ->
406   env { global_bindings = (v,e) : global_bindings env }
407
408 hoistExpr :: FastString -> CoreExpr -> VM Var
409 hoistExpr fs expr
410   = do
411       var <- newLocalVar fs (exprType expr)
412       hoistBinding var expr
413       return var
414
415 hoistVExpr :: VExpr -> VM VVar
416 hoistVExpr (ve, le)
417   = do
418       fs <- getBindName
419       vv <- hoistExpr ('v' `consFS` fs) ve
420       lv <- hoistExpr ('l' `consFS` fs) le
421       return (vv, lv)
422
423 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
424 hoistPolyVExpr tvs p
425   = do
426       expr <- closedV . polyAbstract tvs $ \abstract ->
427               liftM (mapVect abstract) p
428       fn   <- hoistVExpr expr
429       polyVApply (vVar fn) (mkTyVarTys tvs)
430
431 takeHoisted :: VM [(Var, CoreExpr)]
432 takeHoisted
433   = do
434       env <- readGEnv id
435       setGEnv $ env { global_bindings = [] }
436       return $ global_bindings env
437
438 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
439 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
440   = do
441       dict <- paDictOfType env_ty
442       mkv  <- builtin mkClosureVar
443       mkl  <- builtin mkClosurePVar
444       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
445               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
446
447 mkClosureApp :: VExpr -> VExpr -> VM VExpr
448 mkClosureApp (vclo, lclo) (varg, larg)
449   = do
450       vapply <- builtin applyClosureVar
451       lapply <- builtin applyClosurePVar
452       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
453               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
454   where
455     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
456
457 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
458 buildClosures tvs vars [] res_ty mk_body
459   = mk_body
460 buildClosures tvs vars [arg_ty] res_ty mk_body
461   = buildClosure tvs vars arg_ty res_ty mk_body
462 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
463   = do
464       res_ty' <- mkClosureTypes arg_tys res_ty
465       arg <- newLocalVVar FSLIT("x") arg_ty
466       buildClosure tvs vars arg_ty res_ty'
467         . hoistPolyVExpr tvs
468         $ do
469             lc <- builtin liftingContext
470             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
471             return $ vLams lc (vars ++ [arg]) clo
472
473 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
474 --   where
475 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
476 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
477 --
478 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
479 buildClosure tvs vars arg_ty res_ty mk_body
480   = do
481       (env_ty, env, bind) <- buildEnv vars
482       env_bndr <- newLocalVVar FSLIT("env") env_ty
483       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
484
485       fn <- hoistPolyVExpr tvs
486           $ do
487               lc    <- builtin liftingContext
488               body  <- mk_body
489               body' <- bind (vVar env_bndr)
490                             (vVarApps lc body (vars ++ [arg_bndr]))
491               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
492
493       mkClosure arg_ty res_ty env_ty fn env
494
495 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
496 buildEnv vvs
497   = do
498       lc <- builtin liftingContext
499       let (ty, venv, vbind) = mkVectEnv tys vs
500       (lenv, lbind) <- mkLiftEnv lc tys ls
501       return (ty, (venv, lenv),
502               \(venv,lenv) (vbody,lbody) ->
503               do
504                 let vbody' = vbind venv vbody
505                 lbody' <- lbind lenv lbody
506                 return (vbody', lbody'))
507   where
508     (vs,ls) = unzip vvs
509     tys     = map idType vs
510
511 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
512 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
513 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
514 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
515                         \env body -> Case env (mkWildId ty) (exprType body)
516                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
517   where
518     ty = mkCoreTupTy tys
519
520 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
521 mkLiftEnv lc [ty] [v]
522   = return (Var v, \env body ->
523                    do
524                      len <- lengthPA (Var v)
525                      return . Let (NonRec v env)
526                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
527
528 -- NOTE: this transparently deals with empty environments
529 mkLiftEnv lc tys vs
530   = do
531       (env_tc, env_tyargs) <- parrayReprTyCon vty
532       let [env_con] = tyConDataCons env_tc
533           
534           env = Var (dataConWrapId env_con)
535                 `mkTyApps`  env_tyargs
536                 `mkVarApps` (lc : vs)
537
538           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
539                           in
540                           return $ Case scrut (mkWildId (exprType scrut))
541                                         (exprType body)
542                                         [(DataAlt env_con, lc : bndrs, body)]
543       return (env, bind)
544   where
545     vty = mkCoreTupTy tys
546
547     bndrs | null vs   = [mkWildId unitTy]
548           | otherwise = vs
549