Refactoring
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   mkDataConTag,
5   splitClosureTy,
6
7   TyConRepr(..), mkTyConRepr,
8   mkToArrPRepr, mkFromArrPRepr,
9   mkPADictType, mkPArrayType, mkPReprType,
10
11   parrayCoerce, parrayReprTyCon, parrayReprDataCon, mkVScrut,
12   prDictOfType, prCoerce,
13   paDictArgType, paDictOfType, paDFunType,
14   paMethod, lengthPA, replicatePA, emptyPA, liftPA,
15   polyAbstract, polyApply, polyVApply,
16   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
17   buildClosure, buildClosures,
18   mkClosureApp
19 ) where
20
21 #include "HsVersions.h"
22
23 import VectCore
24 import VectMonad
25
26 import DsUtils
27 import CoreSyn
28 import CoreUtils
29 import Coercion
30 import Type
31 import TypeRep
32 import TyCon
33 import DataCon
34 import Var
35 import Id                 ( mkWildId )
36 import MkId               ( unwrapFamInstScrut )
37 import Name               ( Name )
38 import PrelNames
39 import TysWiredIn
40 import BasicTypes         ( Boxity(..) )
41
42 import Outputable
43 import FastString
44 import Maybes             ( orElse )
45
46 import Data.List             ( zipWith4 )
47 import Control.Monad         ( liftM, liftM2, zipWithM_ )
48
49 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
50 collectAnnTypeArgs expr = go expr []
51   where
52     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
53     go e                             tys = (e, tys)
54
55 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
56 collectAnnTypeBinders expr = go [] expr
57   where
58     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
59     go bs e                           = (reverse bs, e)
60
61 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
62 collectAnnValBinders expr = go [] expr
63   where
64     go bs (_, AnnLam b e) | isId b = go (b:bs) e
65     go bs e                        = (reverse bs, e)
66
67 isAnnTypeArg :: AnnExpr b ann -> Bool
68 isAnnTypeArg (_, AnnType t) = True
69 isAnnTypeArg _              = False
70
71 mkDataConTag :: DataCon -> CoreExpr
72 mkDataConTag dc = mkConApp intDataCon [mkIntLitInt $ dataConTag dc]
73
74 splitUnTy :: String -> Name -> Type -> Type
75 splitUnTy s name ty
76   | Just (tc, [ty']) <- splitTyConApp_maybe ty
77   , tyConName tc == name
78   = ty'
79
80   | otherwise = pprPanic s (ppr ty)
81
82 splitBinTy :: String -> Name -> Type -> (Type, Type)
83 splitBinTy s name ty
84   | Just (tc, [ty1, ty2]) <- splitTyConApp_maybe ty
85   , tyConName tc == name
86   = (ty1, ty2)
87
88   | otherwise = pprPanic s (ppr ty)
89
90 splitFixedTyConApp :: TyCon -> Type -> [Type]
91 splitFixedTyConApp tc ty
92   | Just (tc', tys) <- splitTyConApp_maybe ty
93   , tc == tc'
94   = tys
95
96   | otherwise = pprPanic "splitFixedTyConApp" (ppr tc <+> ppr ty)
97
98 splitClosureTy :: Type -> (Type, Type)
99 splitClosureTy = splitBinTy "splitClosureTy" closureTyConName
100
101 splitPArrayTy :: Type -> Type
102 splitPArrayTy = splitUnTy "splitPArrayTy" parrayTyConName
103
104 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
105 mkBuiltinTyConApp get_tc tys
106   = do
107       tc <- builtin get_tc
108       return $ mkTyConApp tc tys
109
110 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
111 mkBuiltinTyConApps get_tc tys ty
112   = do
113       tc <- builtin get_tc
114       return $ foldr (mk tc) ty tys
115   where
116     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
117
118 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
119 mkBuiltinTyConApps1 get_tc dft [] = return dft
120 mkBuiltinTyConApps1 get_tc dft tys
121   = do
122       tc <- builtin get_tc
123       case tys of
124         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
125         _  -> return $ foldr1 (mk tc) tys
126   where
127     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
128
129 data TyConRepr = TyConRepr {
130                    repr_tyvars         :: [TyVar]
131                  , repr_tys            :: [[Type]]
132
133                  , repr_prod_tycons    :: [Maybe TyCon]
134                  , repr_prod_data_cons :: [Maybe DataCon]
135                  , repr_prod_tys       :: [Type]
136                  , repr_sum_tycon      :: Maybe TyCon
137                  , repr_sum_data_cons  :: [DataCon]
138                  , repr_type           :: Type
139                  }
140
141 mkTyConRepr :: TyCon -> VM TyConRepr
142 mkTyConRepr vect_tc
143   = do
144       prod_tycons <- mapM (mk_tycon prodTyCon) rep_tys
145       let prod_tys = zipWith mk_tc_app_maybe prod_tycons rep_tys
146       sum_tycon   <- mk_tycon sumTyCon prod_tys
147
148       return $ TyConRepr {
149                  repr_tyvars         = tyvars
150                , repr_tys            = rep_tys
151
152                , repr_prod_tycons    = prod_tycons
153                , repr_prod_data_cons = map (fmap mk_single_datacon) prod_tycons
154                , repr_prod_tys       = prod_tys
155                , repr_sum_tycon      = sum_tycon
156                , repr_sum_data_cons  = fmap tyConDataCons sum_tycon `orElse` []
157                , repr_type           = mk_tc_app_maybe sum_tycon prod_tys
158                }
159   where
160     tyvars = tyConTyVars vect_tc
161     data_cons = tyConDataCons vect_tc
162     rep_tys   = map dataConRepArgTys data_cons
163
164     mk_tycon get_tc tys
165       | n > 1     = builtin (Just . get_tc n)
166       | otherwise = return Nothing
167       where n = length tys
168
169     mk_single_datacon tc | [dc] <- tyConDataCons tc = dc
170
171     mk_tc_app_maybe Nothing   []   = unitTy
172     mk_tc_app_maybe Nothing   [ty] = ty
173     mk_tc_app_maybe (Just tc) tys  = mkTyConApp tc tys
174
175 mkToArrPRepr :: CoreExpr -> CoreExpr -> [[CoreExpr]] -> VM CoreExpr
176 mkToArrPRepr len sel ess
177   = do
178       let mk_sum [(expr, ty)] = return (expr, ty)
179           mk_sum es
180             = do
181                 sum_tc <- builtin . sumTyCon $ length es
182                 (sum_rtc, _) <- parrayReprTyCon (mkTyConApp sum_tc tys)
183                 let [sum_rdc] = tyConDataCons sum_rtc
184
185                 return (mkConApp sum_rdc (map Type tys ++ (len : sel : exprs)),
186                         mkTyConApp sum_tc tys)
187             where
188               (exprs, tys) = unzip es
189
190           mk_prod [expr] = return (expr, splitPArrayTy (exprType expr))
191           mk_prod exprs
192             = do
193                 prod_tc <- builtin . prodTyCon $ length exprs
194                 (prod_rtc, _) <- parrayReprTyCon (mkTyConApp prod_tc tys)
195                 let [prod_rdc] = tyConDataCons prod_rtc
196
197                 return (mkConApp prod_rdc (map Type tys ++ (len : exprs)),
198                         mkTyConApp prod_tc tys)
199             where
200               tys = map (splitPArrayTy . exprType) exprs
201
202       liftM fst (mk_sum =<< mapM mk_prod ess)
203
204 mkFromArrPRepr :: CoreExpr -> Type -> Var -> Var -> [[Var]] -> CoreExpr
205                -> VM CoreExpr
206 mkFromArrPRepr scrut res_ty len sel vars res
207   = return (Var unitDataConId)
208
209 mkClosureType :: Type -> Type -> VM Type
210 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
211
212 mkClosureTypes :: [Type] -> Type -> VM Type
213 mkClosureTypes = mkBuiltinTyConApps closureTyCon
214
215 mkPReprType :: Type -> VM Type
216 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
217
218 mkPADictType :: Type -> VM Type
219 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
220
221 mkPArrayType :: Type -> VM Type
222 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
223
224 parrayCoerce :: TyCon -> [Type] -> CoreExpr -> VM CoreExpr
225 parrayCoerce repr_tc args expr
226   | Just arg_co <- tyConFamilyCoercion_maybe repr_tc
227   = do
228       parray <- builtin parrayTyCon
229
230       let co = mkAppCoercion (mkTyConApp parray [])
231                              (mkSymCoercion (mkTyConApp arg_co args))
232
233       return $ mkCoerce co expr
234
235 parrayReprTyCon :: Type -> VM (TyCon, [Type])
236 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
237
238 parrayReprDataCon :: Type -> VM (DataCon, [Type])
239 parrayReprDataCon ty
240   = do
241       (tc, arg_tys) <- parrayReprTyCon ty
242       let [dc] = tyConDataCons tc
243       return (dc, arg_tys)
244
245 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
246 mkVScrut (ve, le)
247   = do
248       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
249       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
250
251 prDictOfType :: Type -> VM CoreExpr
252 prDictOfType orig_ty
253   | Just (tycon, ty_args) <- splitTyConApp_maybe orig_ty
254   = do
255       dfun <- traceMaybeV "prDictOfType" (ppr tycon) (lookupTyConPR tycon)
256       prDFunApply (Var dfun) ty_args
257
258 prDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
259 prDFunApply dfun tys
260   = do
261       args <- mapM mkDFunArg arg_tys
262       return $ mkApps mono_dfun args
263   where
264     mono_dfun    = mkTyApps dfun tys
265     (arg_tys, _) = splitFunTys (exprType mono_dfun)
266
267 mkDFunArg :: Type -> VM CoreExpr
268 mkDFunArg ty
269   | Just (tycon, [arg]) <- splitTyConApp_maybe ty
270
271   = let name = tyConName tycon
272
273         get_dict | name == paTyConName = paDictOfType
274                  | name == prTyConName = prDictOfType
275                  | otherwise           = pprPanic "mkDFunArg" (ppr ty)
276
277     in get_dict arg
278
279 mkDFunArg ty = pprPanic "mkDFunArg" (ppr ty)
280
281 prCoerce :: TyCon -> [Type] -> CoreExpr -> VM CoreExpr
282 prCoerce repr_tc args expr
283   | Just arg_co <- tyConFamilyCoercion_maybe repr_tc
284   = do
285       pr_tc <- builtin prTyCon
286
287       let co = mkAppCoercion (mkTyConApp pr_tc [])
288                              (mkSymCoercion (mkTyConApp arg_co args))
289
290       return $ mkCoerce co expr
291
292 paDictArgType :: TyVar -> VM (Maybe Type)
293 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
294   where
295     go ty k | Just k' <- kindView k = go ty k'
296     go ty (FunTy k1 k2)
297       = do
298           tv   <- newTyVar FSLIT("a") k1
299           mty1 <- go (TyVarTy tv) k1
300           case mty1 of
301             Just ty1 -> do
302                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
303                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
304             Nothing  -> go ty k2
305
306     go ty k
307       | isLiftedTypeKind k
308       = liftM Just (mkPADictType ty)
309
310     go ty k = return Nothing
311
312 paDictOfType :: Type -> VM CoreExpr
313 paDictOfType ty = paDictOfTyApp ty_fn ty_args
314   where
315     (ty_fn, ty_args) = splitAppTys ty
316
317 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
318 paDictOfTyApp ty_fn ty_args
319   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
320 paDictOfTyApp (TyVarTy tv) ty_args
321   = do
322       dfun <- maybeV (lookupTyVarPA tv)
323       paDFunApply dfun ty_args
324 paDictOfTyApp (TyConApp tc _) ty_args
325   = do
326       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
327       paDFunApply (Var dfun) ty_args
328 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
329
330 paDFunType :: TyCon -> VM Type
331 paDFunType tc
332   = do
333       margs <- mapM paDictArgType tvs
334       res   <- mkPADictType (mkTyConApp tc arg_tys)
335       return . mkForAllTys tvs
336              $ mkFunTys [arg | Just arg <- margs] res
337   where
338     tvs = tyConTyVars tc
339     arg_tys = mkTyVarTys tvs
340
341 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
342 paDFunApply dfun tys
343   = do
344       dicts <- mapM paDictOfType tys
345       return $ mkApps (mkTyApps dfun tys) dicts
346
347 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
348 paMethod method ty
349   = do
350       fn   <- builtin method
351       dict <- paDictOfType ty
352       return $ mkApps (Var fn) [Type ty, dict]
353
354 mkPR :: Type -> VM CoreExpr
355 mkPR = paMethod mkPRVar
356
357 lengthPA :: CoreExpr -> VM CoreExpr
358 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
359   where
360     ty = splitPArrayTy (exprType x)
361
362 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
363 replicatePA len x = liftM (`mkApps` [len,x])
364                           (paMethod replicatePAVar (exprType x))
365
366 emptyPA :: Type -> VM CoreExpr
367 emptyPA = paMethod emptyPAVar
368
369 liftPA :: CoreExpr -> VM CoreExpr
370 liftPA x
371   = do
372       lc <- builtin liftingContext
373       replicatePA (Var lc) x
374
375 newLocalVVar :: FastString -> Type -> VM VVar
376 newLocalVVar fs vty
377   = do
378       lty <- mkPArrayType vty
379       vv  <- newLocalVar fs vty
380       lv  <- newLocalVar fs lty
381       return (vv,lv)
382
383 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
384 polyAbstract tvs p
385   = localV
386   $ do
387       mdicts <- mapM mk_dict_var tvs
388       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
389       p (mk_lams mdicts)
390   where
391     mk_dict_var tv = do
392                        r <- paDictArgType tv
393                        case r of
394                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
395                          Nothing -> return Nothing
396
397     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
398
399 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
400 polyApply expr tys
401   = do
402       dicts <- mapM paDictOfType tys
403       return $ expr `mkTyApps` tys `mkApps` dicts
404
405 polyVApply :: VExpr -> [Type] -> VM VExpr
406 polyVApply expr tys
407   = do
408       dicts <- mapM paDictOfType tys
409       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
410
411 hoistBinding :: Var -> CoreExpr -> VM ()
412 hoistBinding v e = updGEnv $ \env ->
413   env { global_bindings = (v,e) : global_bindings env }
414
415 hoistExpr :: FastString -> CoreExpr -> VM Var
416 hoistExpr fs expr
417   = do
418       var <- newLocalVar fs (exprType expr)
419       hoistBinding var expr
420       return var
421
422 hoistVExpr :: VExpr -> VM VVar
423 hoistVExpr (ve, le)
424   = do
425       fs <- getBindName
426       vv <- hoistExpr ('v' `consFS` fs) ve
427       lv <- hoistExpr ('l' `consFS` fs) le
428       return (vv, lv)
429
430 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
431 hoistPolyVExpr tvs p
432   = do
433       expr <- closedV . polyAbstract tvs $ \abstract ->
434               liftM (mapVect abstract) p
435       fn   <- hoistVExpr expr
436       polyVApply (vVar fn) (mkTyVarTys tvs)
437
438 takeHoisted :: VM [(Var, CoreExpr)]
439 takeHoisted
440   = do
441       env <- readGEnv id
442       setGEnv $ env { global_bindings = [] }
443       return $ global_bindings env
444
445 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
446 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
447   = do
448       dict <- paDictOfType env_ty
449       mkv  <- builtin mkClosureVar
450       mkl  <- builtin mkClosurePVar
451       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
452               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
453
454 mkClosureApp :: VExpr -> VExpr -> VM VExpr
455 mkClosureApp (vclo, lclo) (varg, larg)
456   = do
457       vapply <- builtin applyClosureVar
458       lapply <- builtin applyClosurePVar
459       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
460               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
461   where
462     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
463
464 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
465 buildClosures tvs vars [] res_ty mk_body
466   = mk_body
467 buildClosures tvs vars [arg_ty] res_ty mk_body
468   = buildClosure tvs vars arg_ty res_ty mk_body
469 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
470   = do
471       res_ty' <- mkClosureTypes arg_tys res_ty
472       arg <- newLocalVVar FSLIT("x") arg_ty
473       buildClosure tvs vars arg_ty res_ty'
474         . hoistPolyVExpr tvs
475         $ do
476             lc <- builtin liftingContext
477             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
478             return $ vLams lc (vars ++ [arg]) clo
479
480 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
481 --   where
482 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
483 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
484 --
485 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
486 buildClosure tvs vars arg_ty res_ty mk_body
487   = do
488       (env_ty, env, bind) <- buildEnv vars
489       env_bndr <- newLocalVVar FSLIT("env") env_ty
490       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
491
492       fn <- hoistPolyVExpr tvs
493           $ do
494               lc    <- builtin liftingContext
495               body  <- mk_body
496               body' <- bind (vVar env_bndr)
497                             (vVarApps lc body (vars ++ [arg_bndr]))
498               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
499
500       mkClosure arg_ty res_ty env_ty fn env
501
502 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
503 buildEnv vvs
504   = do
505       lc <- builtin liftingContext
506       let (ty, venv, vbind) = mkVectEnv tys vs
507       (lenv, lbind) <- mkLiftEnv lc tys ls
508       return (ty, (venv, lenv),
509               \(venv,lenv) (vbody,lbody) ->
510               do
511                 let vbody' = vbind venv vbody
512                 lbody' <- lbind lenv lbody
513                 return (vbody', lbody'))
514   where
515     (vs,ls) = unzip vvs
516     tys     = map idType vs
517
518 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
519 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
520 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
521 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
522                         \env body -> Case env (mkWildId ty) (exprType body)
523                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
524   where
525     ty = mkCoreTupTy tys
526
527 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
528 mkLiftEnv lc [ty] [v]
529   = return (Var v, \env body ->
530                    do
531                      len <- lengthPA (Var v)
532                      return . Let (NonRec v env)
533                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
534
535 -- NOTE: this transparently deals with empty environments
536 mkLiftEnv lc tys vs
537   = do
538       (env_tc, env_tyargs) <- parrayReprTyCon vty
539       let [env_con] = tyConDataCons env_tc
540           
541           env = Var (dataConWrapId env_con)
542                 `mkTyApps`  env_tyargs
543                 `mkVarApps` (lc : vs)
544
545           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
546                           in
547                           return $ Case scrut (mkWildId (exprType scrut))
548                                         (exprType body)
549                                         [(DataAlt env_con, lc : bndrs, body)]
550       return (env, bind)
551   where
552     vty = mkCoreTupTy tys
553
554     bndrs | null vs   = [mkWildId unitTy]
555           | otherwise = vs
556