8d9e904498a10a62b4d834c0d9ee4d74d3185095
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 {-# OPTIONS -w #-}
2 -- The above warning supression flag is a temporary kludge.
3 -- While working on this module you are encouraged to remove it and fix
4 -- any warnings in the module. See
5 --     http://hackage.haskell.org/trac/ghc/wiki/CodingStyle#Warnings
6 -- for details
7
8 module VectUtils (
9   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
10   collectAnnValBinders,
11   mkDataConTag, mkDataConTagLit,
12
13   mkBuiltinCo,
14   mkPADictType, mkPArrayType, mkPReprType,
15
16   parrayReprTyCon, parrayReprDataCon, mkVScrut,
17   prDFunOfTyCon,
18   paDictArgType, paDictOfType, paDFunType,
19   paMethod, mkPR, lengthPA, replicatePA, emptyPA, liftPA,
20   polyAbstract, polyApply, polyVApply,
21   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
22   buildClosure, buildClosures,
23   mkClosureApp
24 ) where
25
26 #include "HsVersions.h"
27
28 import VectCore
29 import VectMonad
30
31 import DsUtils
32 import CoreSyn
33 import CoreUtils
34 import Coercion
35 import Type
36 import TypeRep
37 import TyCon
38 import DataCon
39 import Var
40 import Id                 ( mkWildId )
41 import MkId               ( unwrapFamInstScrut )
42 import Name               ( Name )
43 import PrelNames
44 import TysWiredIn
45 import TysPrim            ( intPrimTy )
46 import BasicTypes         ( Boxity(..) )
47 import Literal            ( Literal, mkMachInt )
48
49 import Outputable
50 import FastString
51
52 import Data.List             ( zipWith4 )
53 import Control.Monad         ( liftM, liftM2, zipWithM_ )
54
55 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
56 collectAnnTypeArgs expr = go expr []
57   where
58     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
59     go e                             tys = (e, tys)
60
61 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
62 collectAnnTypeBinders expr = go [] expr
63   where
64     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
65     go bs e                           = (reverse bs, e)
66
67 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
68 collectAnnValBinders expr = go [] expr
69   where
70     go bs (_, AnnLam b e) | isId b = go (b:bs) e
71     go bs e                        = (reverse bs, e)
72
73 isAnnTypeArg :: AnnExpr b ann -> Bool
74 isAnnTypeArg (_, AnnType t) = True
75 isAnnTypeArg _              = False
76
77 mkDataConTagLit :: DataCon -> Literal
78 mkDataConTagLit con
79   = mkMachInt . toInteger $ dataConTag con - fIRST_TAG
80
81 mkDataConTag :: DataCon -> CoreExpr
82 mkDataConTag con = mkIntLitInt (dataConTag con - fIRST_TAG)
83
84 splitPrimTyCon :: Type -> Maybe TyCon
85 splitPrimTyCon ty
86   | Just (tycon, []) <- splitTyConApp_maybe ty
87   , isPrimTyCon tycon
88   = Just tycon
89
90   | otherwise = Nothing
91
92 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
93 mkBuiltinTyConApp get_tc tys
94   = do
95       tc <- builtin get_tc
96       return $ mkTyConApp tc tys
97
98 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
99 mkBuiltinTyConApps get_tc tys ty
100   = do
101       tc <- builtin get_tc
102       return $ foldr (mk tc) ty tys
103   where
104     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
105
106 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
107 mkBuiltinTyConApps1 get_tc dft [] = return dft
108 mkBuiltinTyConApps1 get_tc dft tys
109   = do
110       tc <- builtin get_tc
111       case tys of
112         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
113         _  -> return $ foldr1 (mk tc) tys
114   where
115     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
116
117 mkClosureType :: Type -> Type -> VM Type
118 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
119
120 mkClosureTypes :: [Type] -> Type -> VM Type
121 mkClosureTypes = mkBuiltinTyConApps closureTyCon
122
123 mkPReprType :: Type -> VM Type
124 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
125
126 mkPADictType :: Type -> VM Type
127 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
128
129 mkPArrayType :: Type -> VM Type
130 mkPArrayType ty
131   | Just tycon <- splitPrimTyCon ty
132   = do
133       arr <- traceMaybeV "mkPArrayType" (ppr tycon)
134            $ lookupPrimPArray tycon
135       return $ mkTyConApp arr []
136 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
137
138 mkBuiltinCo :: (Builtins -> TyCon) -> VM Coercion
139 mkBuiltinCo get_tc
140   = do
141       tc <- builtin get_tc
142       return $ mkTyConApp tc []
143
144 parrayReprTyCon :: Type -> VM (TyCon, [Type])
145 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
146
147 parrayReprDataCon :: Type -> VM (DataCon, [Type])
148 parrayReprDataCon ty
149   = do
150       (tc, arg_tys) <- parrayReprTyCon ty
151       let [dc] = tyConDataCons tc
152       return (dc, arg_tys)
153
154 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
155 mkVScrut (ve, le)
156   = do
157       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
158       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
159
160 prDFunOfTyCon :: TyCon -> VM CoreExpr
161 prDFunOfTyCon tycon
162   = liftM Var (traceMaybeV "prDictOfTyCon" (ppr tycon) (lookupTyConPR tycon))
163
164 paDictArgType :: TyVar -> VM (Maybe Type)
165 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
166   where
167     go ty k | Just k' <- kindView k = go ty k'
168     go ty (FunTy k1 k2)
169       = do
170           tv   <- newTyVar FSLIT("a") k1
171           mty1 <- go (TyVarTy tv) k1
172           case mty1 of
173             Just ty1 -> do
174                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
175                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
176             Nothing  -> go ty k2
177
178     go ty k
179       | isLiftedTypeKind k
180       = liftM Just (mkPADictType ty)
181
182     go ty k = return Nothing
183
184 paDictOfType :: Type -> VM CoreExpr
185 paDictOfType ty = paDictOfTyApp ty_fn ty_args
186   where
187     (ty_fn, ty_args) = splitAppTys ty
188
189 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
190 paDictOfTyApp ty_fn ty_args
191   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
192 paDictOfTyApp (TyVarTy tv) ty_args
193   = do
194       dfun <- maybeV (lookupTyVarPA tv)
195       paDFunApply dfun ty_args
196 paDictOfTyApp (TyConApp tc _) ty_args
197   = do
198       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
199       paDFunApply (Var dfun) ty_args
200 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
201
202 paDFunType :: TyCon -> VM Type
203 paDFunType tc
204   = do
205       margs <- mapM paDictArgType tvs
206       res   <- mkPADictType (mkTyConApp tc arg_tys)
207       return . mkForAllTys tvs
208              $ mkFunTys [arg | Just arg <- margs] res
209   where
210     tvs = tyConTyVars tc
211     arg_tys = mkTyVarTys tvs
212
213 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
214 paDFunApply dfun tys
215   = do
216       dicts <- mapM paDictOfType tys
217       return $ mkApps (mkTyApps dfun tys) dicts
218
219 type PAMethod = (Builtins -> Var, String)
220
221 pa_length    = (lengthPAVar,    "lengthPA")
222 pa_replicate = (replicatePAVar, "replicatePA")
223 pa_empty     = (emptyPAVar,     "emptyPA")
224
225 paMethod :: PAMethod -> Type -> VM CoreExpr
226 paMethod (method, name) ty
227   | Just tycon <- splitPrimTyCon ty
228   = do
229       fn <- traceMaybeV "paMethod" (ppr tycon <+> text name)
230           $ lookupPrimMethod tycon name
231       return (Var fn)
232
233 paMethod (method, name) ty
234   = do
235       fn   <- builtin method
236       dict <- paDictOfType ty
237       return $ mkApps (Var fn) [Type ty, dict]
238
239 mkPR :: Type -> VM CoreExpr
240 mkPR ty
241   = do
242       fn   <- builtin mkPRVar
243       dict <- paDictOfType ty
244       return $ mkApps (Var fn) [Type ty, dict]
245
246 lengthPA :: Type -> CoreExpr -> VM CoreExpr
247 lengthPA ty x = liftM (`App` x) (paMethod pa_length ty)
248
249 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
250 replicatePA len x = liftM (`mkApps` [len,x])
251                           (paMethod pa_replicate (exprType x))
252
253 emptyPA :: Type -> VM CoreExpr
254 emptyPA = paMethod pa_empty
255
256 liftPA :: CoreExpr -> VM CoreExpr
257 liftPA x
258   = do
259       lc <- builtin liftingContext
260       replicatePA (Var lc) x
261
262 newLocalVVar :: FastString -> Type -> VM VVar
263 newLocalVVar fs vty
264   = do
265       lty <- mkPArrayType vty
266       vv  <- newLocalVar fs vty
267       lv  <- newLocalVar fs lty
268       return (vv,lv)
269
270 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
271 polyAbstract tvs p
272   = localV
273   $ do
274       mdicts <- mapM mk_dict_var tvs
275       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
276       p (mk_lams mdicts)
277   where
278     mk_dict_var tv = do
279                        r <- paDictArgType tv
280                        case r of
281                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
282                          Nothing -> return Nothing
283
284     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
285
286 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
287 polyApply expr tys
288   = do
289       dicts <- mapM paDictOfType tys
290       return $ expr `mkTyApps` tys `mkApps` dicts
291
292 polyVApply :: VExpr -> [Type] -> VM VExpr
293 polyVApply expr tys
294   = do
295       dicts <- mapM paDictOfType tys
296       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
297
298 hoistBinding :: Var -> CoreExpr -> VM ()
299 hoistBinding v e = updGEnv $ \env ->
300   env { global_bindings = (v,e) : global_bindings env }
301
302 hoistExpr :: FastString -> CoreExpr -> VM Var
303 hoistExpr fs expr
304   = do
305       var <- newLocalVar fs (exprType expr)
306       hoistBinding var expr
307       return var
308
309 hoistVExpr :: VExpr -> VM VVar
310 hoistVExpr (ve, le)
311   = do
312       fs <- getBindName
313       vv <- hoistExpr ('v' `consFS` fs) ve
314       lv <- hoistExpr ('l' `consFS` fs) le
315       return (vv, lv)
316
317 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
318 hoistPolyVExpr tvs p
319   = do
320       expr <- closedV . polyAbstract tvs $ \abstract ->
321               liftM (mapVect abstract) p
322       fn   <- hoistVExpr expr
323       polyVApply (vVar fn) (mkTyVarTys tvs)
324
325 takeHoisted :: VM [(Var, CoreExpr)]
326 takeHoisted
327   = do
328       env <- readGEnv id
329       setGEnv $ env { global_bindings = [] }
330       return $ global_bindings env
331
332 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
333 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
334   = do
335       dict <- paDictOfType env_ty
336       mkv  <- builtin mkClosureVar
337       mkl  <- builtin mkClosurePVar
338       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
339               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
340
341 mkClosureApp :: Type -> Type -> VExpr -> VExpr -> VM VExpr
342 mkClosureApp arg_ty res_ty (vclo, lclo) (varg, larg)
343   = do
344       vapply <- builtin applyClosureVar
345       lapply <- builtin applyClosurePVar
346       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
347               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
348
349 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
350 buildClosures tvs vars [] res_ty mk_body
351   = mk_body
352 buildClosures tvs vars [arg_ty] res_ty mk_body
353   = buildClosure tvs vars arg_ty res_ty mk_body
354 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
355   = do
356       res_ty' <- mkClosureTypes arg_tys res_ty
357       arg <- newLocalVVar FSLIT("x") arg_ty
358       buildClosure tvs vars arg_ty res_ty'
359         . hoistPolyVExpr tvs
360         $ do
361             lc <- builtin liftingContext
362             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
363             return $ vLams lc (vars ++ [arg]) clo
364
365 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
366 --   where
367 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
368 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
369 --
370 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
371 buildClosure tvs vars arg_ty res_ty mk_body
372   = do
373       (env_ty, env, bind) <- buildEnv vars
374       env_bndr <- newLocalVVar FSLIT("env") env_ty
375       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
376
377       fn <- hoistPolyVExpr tvs
378           $ do
379               lc    <- builtin liftingContext
380               body  <- mk_body
381               body' <- bind (vVar env_bndr)
382                             (vVarApps lc body (vars ++ [arg_bndr]))
383               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
384
385       mkClosure arg_ty res_ty env_ty fn env
386
387 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
388 buildEnv vvs
389   = do
390       lc <- builtin liftingContext
391       let (ty, venv, vbind) = mkVectEnv tys vs
392       (lenv, lbind) <- mkLiftEnv lc tys ls
393       return (ty, (venv, lenv),
394               \(venv,lenv) (vbody,lbody) ->
395               do
396                 let vbody' = vbind venv vbody
397                 lbody' <- lbind lenv lbody
398                 return (vbody', lbody'))
399   where
400     (vs,ls) = unzip vvs
401     tys     = map idType vs
402
403 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
404 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
405 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
406 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
407                         \env body -> Case env (mkWildId ty) (exprType body)
408                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
409   where
410     ty = mkCoreTupTy tys
411
412 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
413 mkLiftEnv lc [ty] [v]
414   = return (Var v, \env body ->
415                    do
416                      len <- lengthPA ty (Var v)
417                      return . Let (NonRec v env)
418                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
419
420 -- NOTE: this transparently deals with empty environments
421 mkLiftEnv lc tys vs
422   = do
423       (env_tc, env_tyargs) <- parrayReprTyCon vty
424       let [env_con] = tyConDataCons env_tc
425           
426           env = Var (dataConWrapId env_con)
427                 `mkTyApps`  env_tyargs
428                 `mkVarApps` (lc : vs)
429
430           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
431                           in
432                           return $ Case scrut (mkWildId (exprType scrut))
433                                         (exprType body)
434                                         [(DataAlt env_con, lc : bndrs, body)]
435       return (env, bind)
436   where
437     vty = mkCoreTupTy tys
438
439     bndrs | null vs   = [mkWildId unitTy]
440           | otherwise = vs
441