Trace more vectorisation failures
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   splitClosureTy,
5   mkPADictType, mkPArrayType,
6   paDictArgType, paDictOfType, paDFunType,
7   paMethod, lengthPA, replicatePA, emptyPA, liftPA,
8   polyAbstract, polyApply, polyVApply,
9   lookupPArrayFamInst,
10   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
11   buildClosure, buildClosures,
12   mkClosureApp
13 ) where
14
15 #include "HsVersions.h"
16
17 import VectCore
18 import VectMonad
19
20 import DsUtils
21 import CoreSyn
22 import CoreUtils
23 import Type
24 import TypeRep
25 import TyCon
26 import DataCon            ( dataConWrapId )
27 import Var
28 import Id                 ( mkWildId )
29 import MkId               ( unwrapFamInstScrut )
30 import PrelNames
31 import TysWiredIn
32 import BasicTypes         ( Boxity(..) )
33
34 import Outputable
35 import FastString
36
37 import Control.Monad         ( liftM, zipWithM_ )
38
39 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
40 collectAnnTypeArgs expr = go expr []
41   where
42     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
43     go e                             tys = (e, tys)
44
45 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
46 collectAnnTypeBinders expr = go [] expr
47   where
48     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
49     go bs e                           = (reverse bs, e)
50
51 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
52 collectAnnValBinders expr = go [] expr
53   where
54     go bs (_, AnnLam b e) | isId b = go (b:bs) e
55     go bs e                        = (reverse bs, e)
56
57 isAnnTypeArg :: AnnExpr b ann -> Bool
58 isAnnTypeArg (_, AnnType t) = True
59 isAnnTypeArg _              = False
60
61 isClosureTyCon :: TyCon -> Bool
62 isClosureTyCon tc = tyConName tc == closureTyConName
63
64 splitClosureTy :: Type -> (Type, Type)
65 splitClosureTy ty
66   | Just (tc, [arg_ty, res_ty]) <- splitTyConApp_maybe ty
67   , isClosureTyCon tc
68   = (arg_ty, res_ty)
69
70   | otherwise = pprPanic "splitClosureTy" (ppr ty)
71
72 isPArrayTyCon :: TyCon -> Bool
73 isPArrayTyCon tc = tyConName tc == parrayTyConName
74
75 splitPArrayTy :: Type -> Type
76 splitPArrayTy ty
77   | Just (tc, [arg_ty]) <- splitTyConApp_maybe ty
78   , isPArrayTyCon tc
79   = arg_ty
80
81   | otherwise = pprPanic "splitPArrayTy" (ppr ty)
82
83 mkClosureType :: Type -> Type -> VM Type
84 mkClosureType arg_ty res_ty
85   = do
86       tc <- builtin closureTyCon
87       return $ mkTyConApp tc [arg_ty, res_ty]
88
89 mkClosureTypes :: [Type] -> Type -> VM Type
90 mkClosureTypes arg_tys res_ty
91   = do
92       tc <- builtin closureTyCon
93       return $ foldr (mk tc) res_ty arg_tys
94   where
95     mk tc arg_ty res_ty = mkTyConApp tc [arg_ty, res_ty]
96
97 mkPADictType :: Type -> VM Type
98 mkPADictType ty
99   = do
100       tc <- builtin paTyCon
101       return $ TyConApp tc [ty]
102
103 mkPArrayType :: Type -> VM Type
104 mkPArrayType ty
105   = do
106       tc <- builtin parrayTyCon
107       return $ TyConApp tc [ty]
108
109 paDictArgType :: TyVar -> VM (Maybe Type)
110 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
111   where
112     go ty k | Just k' <- kindView k = go ty k'
113     go ty (FunTy k1 k2)
114       = do
115           tv   <- newTyVar FSLIT("a") k1
116           mty1 <- go (TyVarTy tv) k1
117           case mty1 of
118             Just ty1 -> do
119                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
120                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
121             Nothing  -> go ty k2
122
123     go ty k
124       | isLiftedTypeKind k
125       = liftM Just (mkPADictType ty)
126
127     go ty k = return Nothing
128
129 paDictOfType :: Type -> VM CoreExpr
130 paDictOfType ty = paDictOfTyApp ty_fn ty_args
131   where
132     (ty_fn, ty_args) = splitAppTys ty
133
134 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
135 paDictOfTyApp ty_fn ty_args
136   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
137 paDictOfTyApp (TyVarTy tv) ty_args
138   = do
139       dfun <- maybeV (lookupTyVarPA tv)
140       paDFunApply dfun ty_args
141 paDictOfTyApp (TyConApp tc _) ty_args
142   = do
143       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
144       paDFunApply (Var dfun) ty_args
145 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
146
147 paDFunType :: TyCon -> VM Type
148 paDFunType tc
149   = do
150       margs <- mapM paDictArgType tvs
151       res   <- mkPADictType (mkTyConApp tc arg_tys)
152       return . mkForAllTys tvs
153              $ mkFunTys [arg | Just arg <- margs] res
154   where
155     tvs = tyConTyVars tc
156     arg_tys = mkTyVarTys tvs
157
158 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
159 paDFunApply dfun tys
160   = do
161       dicts <- mapM paDictOfType tys
162       return $ mkApps (mkTyApps dfun tys) dicts
163
164 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
165 paMethod method ty
166   = do
167       fn   <- builtin method
168       dict <- paDictOfType ty
169       return $ mkApps (Var fn) [Type ty, dict]
170
171 lengthPA :: CoreExpr -> VM CoreExpr
172 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
173   where
174     ty = splitPArrayTy (exprType x)
175
176 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
177 replicatePA len x = liftM (`mkApps` [len,x])
178                           (paMethod replicatePAVar (exprType x))
179
180 emptyPA :: Type -> VM CoreExpr
181 emptyPA = paMethod emptyPAVar
182
183 liftPA :: CoreExpr -> VM CoreExpr
184 liftPA x
185   = do
186       lc <- builtin liftingContext
187       replicatePA (Var lc) x
188
189 newLocalVVar :: FastString -> Type -> VM VVar
190 newLocalVVar fs vty
191   = do
192       lty <- mkPArrayType vty
193       vv  <- newLocalVar fs vty
194       lv  <- newLocalVar fs lty
195       return (vv,lv)
196
197 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
198 polyAbstract tvs p
199   = localV
200   $ do
201       mdicts <- mapM mk_dict_var tvs
202       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
203       p (mk_lams mdicts)
204   where
205     mk_dict_var tv = do
206                        r <- paDictArgType tv
207                        case r of
208                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
209                          Nothing -> return Nothing
210
211     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
212
213 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
214 polyApply expr tys
215   = do
216       dicts <- mapM paDictOfType tys
217       return $ expr `mkTyApps` tys `mkApps` dicts
218
219 polyVApply :: VExpr -> [Type] -> VM VExpr
220 polyVApply expr tys
221   = do
222       dicts <- mapM paDictOfType tys
223       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
224
225 lookupPArrayFamInst :: Type -> VM (TyCon, [Type])
226 lookupPArrayFamInst ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
227
228 hoistBinding :: Var -> CoreExpr -> VM ()
229 hoistBinding v e = updGEnv $ \env ->
230   env { global_bindings = (v,e) : global_bindings env }
231
232 hoistExpr :: FastString -> CoreExpr -> VM Var
233 hoistExpr fs expr
234   = do
235       var <- newLocalVar fs (exprType expr)
236       hoistBinding var expr
237       return var
238
239 hoistVExpr :: VExpr -> VM VVar
240 hoistVExpr (ve, le)
241   = do
242       fs <- getBindName
243       vv <- hoistExpr ('v' `consFS` fs) ve
244       lv <- hoistExpr ('l' `consFS` fs) le
245       return (vv, lv)
246
247 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
248 hoistPolyVExpr tvs p
249   = do
250       expr <- closedV . polyAbstract tvs $ \abstract ->
251               liftM (mapVect abstract) p
252       fn   <- hoistVExpr expr
253       polyVApply (vVar fn) (mkTyVarTys tvs)
254
255 takeHoisted :: VM [(Var, CoreExpr)]
256 takeHoisted
257   = do
258       env <- readGEnv id
259       setGEnv $ env { global_bindings = [] }
260       return $ global_bindings env
261
262 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
263 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
264   = do
265       dict <- paDictOfType env_ty
266       mkv  <- builtin mkClosureVar
267       mkl  <- builtin mkClosurePVar
268       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
269               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
270
271 mkClosureApp :: VExpr -> VExpr -> VM VExpr
272 mkClosureApp (vclo, lclo) (varg, larg)
273   = do
274       vapply <- builtin applyClosureVar
275       lapply <- builtin applyClosurePVar
276       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
277               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
278   where
279     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
280
281 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
282 buildClosures tvs vars [arg_ty] res_ty mk_body
283   = buildClosure tvs vars arg_ty res_ty mk_body
284 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
285   = do
286       res_ty' <- mkClosureTypes arg_tys res_ty
287       arg <- newLocalVVar FSLIT("x") arg_ty
288       buildClosure tvs vars arg_ty res_ty'
289         . hoistPolyVExpr tvs
290         $ do
291             lc <- builtin liftingContext
292             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
293             return $ vLams lc (vars ++ [arg]) clo
294
295 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
296 --   where
297 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
298 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
299 --
300 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
301 buildClosure tvs vars arg_ty res_ty mk_body
302   = do
303       (env_ty, env, bind) <- buildEnv vars
304       env_bndr <- newLocalVVar FSLIT("env") env_ty
305       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
306
307       fn <- hoistPolyVExpr tvs
308           $ do
309               lc    <- builtin liftingContext
310               body  <- mk_body
311               body' <- bind (vVar env_bndr)
312                             (vVarApps lc body (vars ++ [arg_bndr]))
313               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
314
315       mkClosure arg_ty res_ty env_ty fn env
316
317 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
318 buildEnv vvs
319   = do
320       lc <- builtin liftingContext
321       let (ty, venv, vbind) = mkVectEnv tys vs
322       (lenv, lbind) <- mkLiftEnv lc tys ls
323       return (ty, (venv, lenv),
324               \(venv,lenv) (vbody,lbody) ->
325               do
326                 let vbody' = vbind venv vbody
327                 lbody' <- lbind lenv lbody
328                 return (vbody', lbody'))
329   where
330     (vs,ls) = unzip vvs
331     tys     = map idType vs
332
333 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
334 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
335 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
336 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
337                         \env body -> Case env (mkWildId ty) (exprType body)
338                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
339   where
340     ty = mkCoreTupTy tys
341
342 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
343 mkLiftEnv lc [ty] [v]
344   = return (Var v, \env body ->
345                    do
346                      len <- lengthPA (Var v)
347                      return . Let (NonRec v env)
348                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
349
350 -- NOTE: this transparently deals with empty environments
351 mkLiftEnv lc tys vs
352   = do
353       (env_tc, env_tyargs) <- lookupPArrayFamInst vty
354       let [env_con] = tyConDataCons env_tc
355           
356           env = Var (dataConWrapId env_con)
357                 `mkTyApps`  env_tyargs
358                 `mkVarApps` (lc : vs)
359
360           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
361                           in
362                           return $ Case scrut (mkWildId (exprType scrut))
363                                         (exprType body)
364                                         [(DataAlt env_con, lc : bndrs, body)]
365       return (env, bind)
366   where
367     vty = mkCoreTupTy tys
368
369     bndrs | null vs   = [mkWildId unitTy]
370           | otherwise = vs
371