Utility functions for vectorisation
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   mkDataConTag,
5   splitClosureTy,
6   mkPlusType, mkPlusTypes, mkCrossType, mkCrossTypes, mkEmbedType,
7   mkPlusAlts, mkCrosses, mkEmbed,
8   mkPADictType, mkPArrayType,
9   parrayReprTyCon, parrayReprDataCon, mkVScrut,
10   paDictArgType, paDictOfType, paDFunType,
11   paMethod, lengthPA, replicatePA, emptyPA, liftPA,
12   polyAbstract, polyApply, polyVApply,
13   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
14   buildClosure, buildClosures,
15   mkClosureApp
16 ) where
17
18 #include "HsVersions.h"
19
20 import VectCore
21 import VectMonad
22
23 import DsUtils
24 import CoreSyn
25 import CoreUtils
26 import Type
27 import TypeRep
28 import TyCon
29 import DataCon            ( DataCon, dataConWrapId, dataConTag )
30 import Var
31 import Id                 ( mkWildId )
32 import MkId               ( unwrapFamInstScrut )
33 import PrelNames
34 import TysWiredIn
35 import BasicTypes         ( Boxity(..) )
36
37 import Outputable
38 import FastString
39
40 import Control.Monad         ( liftM, zipWithM_ )
41
42 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
43 collectAnnTypeArgs expr = go expr []
44   where
45     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
46     go e                             tys = (e, tys)
47
48 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
49 collectAnnTypeBinders expr = go [] expr
50   where
51     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
52     go bs e                           = (reverse bs, e)
53
54 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
55 collectAnnValBinders expr = go [] expr
56   where
57     go bs (_, AnnLam b e) | isId b = go (b:bs) e
58     go bs e                        = (reverse bs, e)
59
60 isAnnTypeArg :: AnnExpr b ann -> Bool
61 isAnnTypeArg (_, AnnType t) = True
62 isAnnTypeArg _              = False
63
64 mkDataConTag :: DataCon -> CoreExpr
65 mkDataConTag dc = mkConApp intDataCon [mkIntLitInt $ dataConTag dc]
66
67 isClosureTyCon :: TyCon -> Bool
68 isClosureTyCon tc = tyConName tc == closureTyConName
69
70 splitClosureTy :: Type -> (Type, Type)
71 splitClosureTy ty
72   | Just (tc, [arg_ty, res_ty]) <- splitTyConApp_maybe ty
73   , isClosureTyCon tc
74   = (arg_ty, res_ty)
75
76   | otherwise = pprPanic "splitClosureTy" (ppr ty)
77
78 isPArrayTyCon :: TyCon -> Bool
79 isPArrayTyCon tc = tyConName tc == parrayTyConName
80
81 splitPArrayTy :: Type -> Type
82 splitPArrayTy ty
83   | Just (tc, [arg_ty]) <- splitTyConApp_maybe ty
84   , isPArrayTyCon tc
85   = arg_ty
86
87   | otherwise = pprPanic "splitPArrayTy" (ppr ty)
88
89 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
90 mkBuiltinTyConApp get_tc tys
91   = do
92       tc <- builtin get_tc
93       return $ mkTyConApp tc tys
94
95 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
96 mkBuiltinTyConApps get_tc tys ty
97   = do
98       tc <- builtin get_tc
99       return $ foldr (mk tc) ty tys
100   where
101     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
102
103 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
104 mkBuiltinTyConApps1 get_tc dft [] = return dft
105 mkBuiltinTyConApps1 get_tc dft tys
106   = do
107       tc <- builtin get_tc
108       case tys of
109         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
110         _  -> return $ foldr1 (mk tc) tys
111   where
112     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
113
114 mkBuiltinDataConApp :: (Builtins -> DataCon) -> [CoreExpr] -> VM CoreExpr
115 mkBuiltinDataConApp get_dc args
116   = do
117       dc <- builtin get_dc
118       return $ mkConApp dc args
119
120 mkPlusType :: Type -> Type -> VM Type
121 mkPlusType ty1 ty2 = mkBuiltinTyConApp plusTyCon [ty1, ty2]
122
123 mkPlusTypes :: Type -> [Type] -> VM Type
124 mkPlusTypes = mkBuiltinTyConApps1 plusTyCon
125
126 mkPlusAlts :: [CoreExpr] -> VM [CoreExpr]
127 mkPlusAlts [] = return []
128 mkPlusAlts exprs
129   = do
130       plus_tc  <- builtin plusTyCon
131       left_dc  <- builtin leftDataCon
132       right_dc <- builtin rightDataCon
133
134       let go [expr] = ([expr], exprType expr)
135           go (expr : exprs)
136             | (alts, right_ty) <- go exprs
137             = (mkConApp left_dc [Type left_ty, Type right_ty, expr]
138                : [mkConApp right_dc [Type left_ty, Type right_ty, alt]
139                     | alt <- alts],
140                mkTyConApp plus_tc [left_ty, right_ty])
141             where
142               left_ty = exprType expr
143
144       return . fst $ go exprs
145
146 mkCrossType :: Type -> Type -> VM Type
147 mkCrossType ty1 ty2 = mkBuiltinTyConApp crossTyCon [ty1, ty2]
148
149 mkCrossTypes :: Type -> [Type] -> VM Type
150 mkCrossTypes = mkBuiltinTyConApps1 crossTyCon
151
152 mkCrosses :: [CoreExpr] -> VM CoreExpr
153 mkCrosses [] = return (Var unitDataConId)
154 mkCrosses exprs
155   = do
156       cross_tc <- builtin crossTyCon
157       cross_dc <- builtin crossDataCon
158
159       let mk (left, left_ty) (right, right_ty)
160             = (mkConApp   cross_dc [Type left_ty, Type right_ty, left, right],
161                mkTyConApp cross_tc [left_ty, right_ty])
162
163       return . fst
164              $ foldr1 mk [(expr, exprType expr) | expr <- exprs]
165
166 mkEmbedType :: Type -> VM Type
167 mkEmbedType ty = mkBuiltinTyConApp embedTyCon [ty]
168
169 mkEmbed :: CoreExpr -> VM CoreExpr
170 mkEmbed expr = mkBuiltinDataConApp embedDataCon
171                                    [Type $ exprType expr, expr]
172
173 mkClosureType :: Type -> Type -> VM Type
174 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
175
176 mkClosureTypes :: [Type] -> Type -> VM Type
177 mkClosureTypes = mkBuiltinTyConApps closureTyCon
178
179 mkPADictType :: Type -> VM Type
180 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
181
182 mkPArrayType :: Type -> VM Type
183 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
184
185 parrayReprTyCon :: Type -> VM (TyCon, [Type])
186 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
187
188 parrayReprDataCon :: Type -> VM (DataCon, [Type])
189 parrayReprDataCon ty
190   = do
191       (tc, arg_tys) <- parrayReprTyCon ty
192       let [dc] = tyConDataCons tc
193       return (dc, arg_tys)
194
195 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
196 mkVScrut (ve, le)
197   = do
198       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
199       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
200
201 paDictArgType :: TyVar -> VM (Maybe Type)
202 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
203   where
204     go ty k | Just k' <- kindView k = go ty k'
205     go ty (FunTy k1 k2)
206       = do
207           tv   <- newTyVar FSLIT("a") k1
208           mty1 <- go (TyVarTy tv) k1
209           case mty1 of
210             Just ty1 -> do
211                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
212                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
213             Nothing  -> go ty k2
214
215     go ty k
216       | isLiftedTypeKind k
217       = liftM Just (mkPADictType ty)
218
219     go ty k = return Nothing
220
221 paDictOfType :: Type -> VM CoreExpr
222 paDictOfType ty = paDictOfTyApp ty_fn ty_args
223   where
224     (ty_fn, ty_args) = splitAppTys ty
225
226 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
227 paDictOfTyApp ty_fn ty_args
228   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
229 paDictOfTyApp (TyVarTy tv) ty_args
230   = do
231       dfun <- maybeV (lookupTyVarPA tv)
232       paDFunApply dfun ty_args
233 paDictOfTyApp (TyConApp tc _) ty_args
234   = do
235       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
236       paDFunApply (Var dfun) ty_args
237 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
238
239 paDFunType :: TyCon -> VM Type
240 paDFunType tc
241   = do
242       margs <- mapM paDictArgType tvs
243       res   <- mkPADictType (mkTyConApp tc arg_tys)
244       return . mkForAllTys tvs
245              $ mkFunTys [arg | Just arg <- margs] res
246   where
247     tvs = tyConTyVars tc
248     arg_tys = mkTyVarTys tvs
249
250 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
251 paDFunApply dfun tys
252   = do
253       dicts <- mapM paDictOfType tys
254       return $ mkApps (mkTyApps dfun tys) dicts
255
256 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
257 paMethod method ty
258   = do
259       fn   <- builtin method
260       dict <- paDictOfType ty
261       return $ mkApps (Var fn) [Type ty, dict]
262
263 lengthPA :: CoreExpr -> VM CoreExpr
264 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
265   where
266     ty = splitPArrayTy (exprType x)
267
268 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
269 replicatePA len x = liftM (`mkApps` [len,x])
270                           (paMethod replicatePAVar (exprType x))
271
272 emptyPA :: Type -> VM CoreExpr
273 emptyPA = paMethod emptyPAVar
274
275 liftPA :: CoreExpr -> VM CoreExpr
276 liftPA x
277   = do
278       lc <- builtin liftingContext
279       replicatePA (Var lc) x
280
281 newLocalVVar :: FastString -> Type -> VM VVar
282 newLocalVVar fs vty
283   = do
284       lty <- mkPArrayType vty
285       vv  <- newLocalVar fs vty
286       lv  <- newLocalVar fs lty
287       return (vv,lv)
288
289 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
290 polyAbstract tvs p
291   = localV
292   $ do
293       mdicts <- mapM mk_dict_var tvs
294       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
295       p (mk_lams mdicts)
296   where
297     mk_dict_var tv = do
298                        r <- paDictArgType tv
299                        case r of
300                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
301                          Nothing -> return Nothing
302
303     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
304
305 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
306 polyApply expr tys
307   = do
308       dicts <- mapM paDictOfType tys
309       return $ expr `mkTyApps` tys `mkApps` dicts
310
311 polyVApply :: VExpr -> [Type] -> VM VExpr
312 polyVApply expr tys
313   = do
314       dicts <- mapM paDictOfType tys
315       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
316
317 hoistBinding :: Var -> CoreExpr -> VM ()
318 hoistBinding v e = updGEnv $ \env ->
319   env { global_bindings = (v,e) : global_bindings env }
320
321 hoistExpr :: FastString -> CoreExpr -> VM Var
322 hoistExpr fs expr
323   = do
324       var <- newLocalVar fs (exprType expr)
325       hoistBinding var expr
326       return var
327
328 hoistVExpr :: VExpr -> VM VVar
329 hoistVExpr (ve, le)
330   = do
331       fs <- getBindName
332       vv <- hoistExpr ('v' `consFS` fs) ve
333       lv <- hoistExpr ('l' `consFS` fs) le
334       return (vv, lv)
335
336 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
337 hoistPolyVExpr tvs p
338   = do
339       expr <- closedV . polyAbstract tvs $ \abstract ->
340               liftM (mapVect abstract) p
341       fn   <- hoistVExpr expr
342       polyVApply (vVar fn) (mkTyVarTys tvs)
343
344 takeHoisted :: VM [(Var, CoreExpr)]
345 takeHoisted
346   = do
347       env <- readGEnv id
348       setGEnv $ env { global_bindings = [] }
349       return $ global_bindings env
350
351 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
352 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
353   = do
354       dict <- paDictOfType env_ty
355       mkv  <- builtin mkClosureVar
356       mkl  <- builtin mkClosurePVar
357       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
358               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
359
360 mkClosureApp :: VExpr -> VExpr -> VM VExpr
361 mkClosureApp (vclo, lclo) (varg, larg)
362   = do
363       vapply <- builtin applyClosureVar
364       lapply <- builtin applyClosurePVar
365       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
366               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
367   where
368     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
369
370 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
371 buildClosures tvs vars [] res_ty mk_body
372   = mk_body
373 buildClosures tvs vars [arg_ty] res_ty mk_body
374   = buildClosure tvs vars arg_ty res_ty mk_body
375 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
376   = do
377       res_ty' <- mkClosureTypes arg_tys res_ty
378       arg <- newLocalVVar FSLIT("x") arg_ty
379       buildClosure tvs vars arg_ty res_ty'
380         . hoistPolyVExpr tvs
381         $ do
382             lc <- builtin liftingContext
383             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
384             return $ vLams lc (vars ++ [arg]) clo
385
386 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
387 --   where
388 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
389 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
390 --
391 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
392 buildClosure tvs vars arg_ty res_ty mk_body
393   = do
394       (env_ty, env, bind) <- buildEnv vars
395       env_bndr <- newLocalVVar FSLIT("env") env_ty
396       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
397
398       fn <- hoistPolyVExpr tvs
399           $ do
400               lc    <- builtin liftingContext
401               body  <- mk_body
402               body' <- bind (vVar env_bndr)
403                             (vVarApps lc body (vars ++ [arg_bndr]))
404               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
405
406       mkClosure arg_ty res_ty env_ty fn env
407
408 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
409 buildEnv vvs
410   = do
411       lc <- builtin liftingContext
412       let (ty, venv, vbind) = mkVectEnv tys vs
413       (lenv, lbind) <- mkLiftEnv lc tys ls
414       return (ty, (venv, lenv),
415               \(venv,lenv) (vbody,lbody) ->
416               do
417                 let vbody' = vbind venv vbody
418                 lbody' <- lbind lenv lbody
419                 return (vbody', lbody'))
420   where
421     (vs,ls) = unzip vvs
422     tys     = map idType vs
423
424 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
425 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
426 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
427 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
428                         \env body -> Case env (mkWildId ty) (exprType body)
429                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
430   where
431     ty = mkCoreTupTy tys
432
433 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
434 mkLiftEnv lc [ty] [v]
435   = return (Var v, \env body ->
436                    do
437                      len <- lengthPA (Var v)
438                      return . Let (NonRec v env)
439                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
440
441 -- NOTE: this transparently deals with empty environments
442 mkLiftEnv lc tys vs
443   = do
444       (env_tc, env_tyargs) <- parrayReprTyCon vty
445       let [env_con] = tyConDataCons env_tc
446           
447           env = Var (dataConWrapId env_con)
448                 `mkTyApps`  env_tyargs
449                 `mkVarApps` (lc : vs)
450
451           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
452                           in
453                           return $ Case scrut (mkWildId (exprType scrut))
454                                         (exprType body)
455                                         [(DataAlt env_con, lc : bndrs, body)]
456       return (env, bind)
457   where
458     vty = mkCoreTupTy tys
459
460     bndrs | null vs   = [mkWildId unitTy]
461           | otherwise = vs
462