Follow introduction of MkCore in VectUtils
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   dataConTagZ, mkDataConTag, mkDataConTagLit,
5
6   newLocalVVar,
7
8   mkBuiltinCo,
9   mkPADictType, mkPArrayType, mkPReprType,
10
11   parrayReprTyCon, parrayReprDataCon, mkVScrut,
12   prDFunOfTyCon,
13   paDictArgType, paDictOfType, paDFunType,
14   paMethod, mkPR, lengthPA, replicatePA, emptyPA, packPA, combinePA, liftPA,
15   polyAbstract, polyApply, polyVApply,
16   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
17   buildClosure, buildClosures,
18   mkClosureApp
19 ) where
20
21 import VectCore
22 import VectMonad
23
24 import MkCore
25 import CoreSyn
26 import CoreUtils
27 import Coercion
28 import Type
29 import TypeRep
30 import TyCon
31 import DataCon
32 import Var
33 import Id                 ( mkWildId )
34 import MkId               ( unwrapFamInstScrut )
35 import TysWiredIn
36 import BasicTypes         ( Boxity(..) )
37 import Literal            ( Literal, mkMachInt )
38
39 import Outputable
40 import FastString
41
42 import Control.Monad
43
44
45 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
46 collectAnnTypeArgs expr = go expr []
47   where
48     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
49     go e                             tys = (e, tys)
50
51 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
52 collectAnnTypeBinders expr = go [] expr
53   where
54     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
55     go bs e                           = (reverse bs, e)
56
57 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
58 collectAnnValBinders expr = go [] expr
59   where
60     go bs (_, AnnLam b e) | isId b = go (b:bs) e
61     go bs e                        = (reverse bs, e)
62
63 isAnnTypeArg :: AnnExpr b ann -> Bool
64 isAnnTypeArg (_, AnnType _) = True
65 isAnnTypeArg _              = False
66
67 dataConTagZ :: DataCon -> Int
68 dataConTagZ con = dataConTag con - fIRST_TAG
69
70 mkDataConTagLit :: DataCon -> Literal
71 mkDataConTagLit = mkMachInt . toInteger . dataConTagZ
72
73 mkDataConTag :: DataCon -> CoreExpr
74 mkDataConTag = mkIntLitInt . dataConTagZ
75
76 splitPrimTyCon :: Type -> Maybe TyCon
77 splitPrimTyCon ty
78   | Just (tycon, []) <- splitTyConApp_maybe ty
79   , isPrimTyCon tycon
80   = Just tycon
81
82   | otherwise = Nothing
83
84 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
85 mkBuiltinTyConApp get_tc tys
86   = do
87       tc <- builtin get_tc
88       return $ mkTyConApp tc tys
89
90 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
91 mkBuiltinTyConApps get_tc tys ty
92   = do
93       tc <- builtin get_tc
94       return $ foldr (mk tc) ty tys
95   where
96     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
97
98 {-
99 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
100 mkBuiltinTyConApps1 _      dft [] = return dft
101 mkBuiltinTyConApps1 get_tc _   tys
102   = do
103       tc <- builtin get_tc
104       case tys of
105         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
106         _  -> return $ foldr1 (mk tc) tys
107   where
108     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
109
110 mkClosureType :: Type -> Type -> VM Type
111 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
112 -}
113
114 mkClosureTypes :: [Type] -> Type -> VM Type
115 mkClosureTypes = mkBuiltinTyConApps closureTyCon
116
117 mkPReprType :: Type -> VM Type
118 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
119
120 mkPADictType :: Type -> VM Type
121 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
122
123 mkPArrayType :: Type -> VM Type
124 mkPArrayType ty
125   | Just tycon <- splitPrimTyCon ty
126   = do
127       arr <- traceMaybeV "mkPArrayType" (ppr tycon)
128            $ lookupPrimPArray tycon
129       return $ mkTyConApp arr []
130 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
131
132 mkBuiltinCo :: (Builtins -> TyCon) -> VM Coercion
133 mkBuiltinCo get_tc
134   = do
135       tc <- builtin get_tc
136       return $ mkTyConApp tc []
137
138 parrayReprTyCon :: Type -> VM (TyCon, [Type])
139 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
140
141 parrayReprDataCon :: Type -> VM (DataCon, [Type])
142 parrayReprDataCon ty
143   = do
144       (tc, arg_tys) <- parrayReprTyCon ty
145       let [dc] = tyConDataCons tc
146       return (dc, arg_tys)
147
148 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
149 mkVScrut (ve, le)
150   = do
151       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
152       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
153
154 prDFunOfTyCon :: TyCon -> VM CoreExpr
155 prDFunOfTyCon tycon
156   = liftM Var (traceMaybeV "prDictOfTyCon" (ppr tycon) (lookupTyConPR tycon))
157
158 paDictArgType :: TyVar -> VM (Maybe Type)
159 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
160   where
161     go ty k | Just k' <- kindView k = go ty k'
162     go ty (FunTy k1 k2)
163       = do
164           tv   <- newTyVar (fsLit "a") k1
165           mty1 <- go (TyVarTy tv) k1
166           case mty1 of
167             Just ty1 -> do
168                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
169                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
170             Nothing  -> go ty k2
171
172     go ty k
173       | isLiftedTypeKind k
174       = liftM Just (mkPADictType ty)
175
176     go _ _ = return Nothing
177
178 paDictOfType :: Type -> VM CoreExpr
179 paDictOfType ty = paDictOfTyApp ty_fn ty_args
180   where
181     (ty_fn, ty_args) = splitAppTys ty
182
183 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
184 paDictOfTyApp ty_fn ty_args
185   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
186 paDictOfTyApp (TyVarTy tv) ty_args
187   = do
188       dfun <- maybeV (lookupTyVarPA tv)
189       paDFunApply dfun ty_args
190 paDictOfTyApp (TyConApp tc _) ty_args
191   = do
192       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
193       paDFunApply (Var dfun) ty_args
194 paDictOfTyApp ty _ = pprPanic "paDictOfTyApp" (ppr ty)
195
196 paDFunType :: TyCon -> VM Type
197 paDFunType tc
198   = do
199       margs <- mapM paDictArgType tvs
200       res   <- mkPADictType (mkTyConApp tc arg_tys)
201       return . mkForAllTys tvs
202              $ mkFunTys [arg | Just arg <- margs] res
203   where
204     tvs = tyConTyVars tc
205     arg_tys = mkTyVarTys tvs
206
207 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
208 paDFunApply dfun tys
209   = do
210       dicts <- mapM paDictOfType tys
211       return $ mkApps (mkTyApps dfun tys) dicts
212
213 type PAMethod = (Builtins -> Var, String)
214
215 pa_length, pa_replicate, pa_empty, pa_pack :: (Builtins -> Var, String)
216 pa_length    = (lengthPAVar,    "lengthPA")
217 pa_replicate = (replicatePAVar, "replicatePA")
218 pa_empty     = (emptyPAVar,     "emptyPA")
219 pa_pack      = (packPAVar,      "packPA")
220
221 paMethod :: PAMethod -> Type -> VM CoreExpr
222 paMethod (_method, name) ty
223   | Just tycon <- splitPrimTyCon ty
224   = do
225       fn <- traceMaybeV "paMethod" (ppr tycon <+> text name)
226           $ lookupPrimMethod tycon name
227       return (Var fn)
228
229 paMethod (method, _name) ty
230   = do
231       fn   <- builtin method
232       dict <- paDictOfType ty
233       return $ mkApps (Var fn) [Type ty, dict]
234
235 mkPR :: Type -> VM CoreExpr
236 mkPR ty
237   = do
238       fn   <- builtin mkPRVar
239       dict <- paDictOfType ty
240       return $ mkApps (Var fn) [Type ty, dict]
241
242 lengthPA :: Type -> CoreExpr -> VM CoreExpr
243 lengthPA ty x = liftM (`App` x) (paMethod pa_length ty)
244
245 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
246 replicatePA len x = liftM (`mkApps` [len,x])
247                           (paMethod pa_replicate (exprType x))
248
249 emptyPA :: Type -> VM CoreExpr
250 emptyPA = paMethod pa_empty
251
252 packPA :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> VM CoreExpr
253 packPA ty xs len sel = liftM (`mkApps` [xs, len, sel])
254                              (paMethod pa_pack ty)
255
256 combinePA :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> [CoreExpr]
257           -> VM CoreExpr
258 combinePA ty len sel is xs
259   = liftM (`mkApps` (len : sel : is : xs))
260           (paMethod (combinePAVar n, "combine" ++ show n ++ "PA") ty)
261   where
262     n = length xs
263
264 liftPA :: CoreExpr -> VM CoreExpr
265 liftPA x
266   = do
267       lc <- builtin liftingContext
268       replicatePA (Var lc) x
269
270 newLocalVVar :: FastString -> Type -> VM VVar
271 newLocalVVar fs vty
272   = do
273       lty <- mkPArrayType vty
274       vv  <- newLocalVar fs vty
275       lv  <- newLocalVar fs lty
276       return (vv,lv)
277
278 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
279 polyAbstract tvs p
280   = localV
281   $ do
282       mdicts <- mapM mk_dict_var tvs
283       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
284       p (mk_lams mdicts)
285   where
286     mk_dict_var tv = do
287                        r <- paDictArgType tv
288                        case r of
289                          Just ty -> liftM Just (newLocalVar (fsLit "dPA") ty)
290                          Nothing -> return Nothing
291
292     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
293
294 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
295 polyApply expr tys
296   = do
297       dicts <- mapM paDictOfType tys
298       return $ expr `mkTyApps` tys `mkApps` dicts
299
300 polyVApply :: VExpr -> [Type] -> VM VExpr
301 polyVApply expr tys
302   = do
303       dicts <- mapM paDictOfType tys
304       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
305
306 hoistBinding :: Var -> CoreExpr -> VM ()
307 hoistBinding v e = updGEnv $ \env ->
308   env { global_bindings = (v,e) : global_bindings env }
309
310 hoistExpr :: FastString -> CoreExpr -> VM Var
311 hoistExpr fs expr
312   = do
313       var <- newLocalVar fs (exprType expr)
314       hoistBinding var expr
315       return var
316
317 hoistVExpr :: VExpr -> VM VVar
318 hoistVExpr (ve, le)
319   = do
320       fs <- getBindName
321       vv <- hoistExpr ('v' `consFS` fs) ve
322       lv <- hoistExpr ('l' `consFS` fs) le
323       return (vv, lv)
324
325 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
326 hoistPolyVExpr tvs p
327   = do
328       expr <- closedV . polyAbstract tvs $ \abstract ->
329               liftM (mapVect abstract) p
330       fn   <- hoistVExpr expr
331       polyVApply (vVar fn) (mkTyVarTys tvs)
332
333 takeHoisted :: VM [(Var, CoreExpr)]
334 takeHoisted
335   = do
336       env <- readGEnv id
337       setGEnv $ env { global_bindings = [] }
338       return $ global_bindings env
339
340 {-
341 boxExpr :: Type -> VExpr -> VM VExpr
342 boxExpr ty (vexpr, lexpr)
343   | Just (tycon, []) <- splitTyConApp_maybe ty
344   , isUnLiftedTyCon tycon
345   = do
346       r <- lookupBoxedTyCon tycon
347       case r of
348         Just tycon' -> let [dc] = tyConDataCons tycon'
349                        in
350                        return (mkConApp dc [vexpr], lexpr)
351         Nothing     -> return (vexpr, lexpr)
352 -}
353
354 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
355 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
356   = do
357       dict <- paDictOfType env_ty
358       mkv  <- builtin mkClosureVar
359       mkl  <- builtin mkClosurePVar
360       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
361               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
362
363 mkClosureApp :: Type -> Type -> VExpr -> VExpr -> VM VExpr
364 mkClosureApp arg_ty res_ty (vclo, lclo) (varg, larg)
365   = do
366       vapply <- builtin applyClosureVar
367       lapply <- builtin applyClosurePVar
368       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
369               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
370
371 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
372 buildClosures _   _    [] _ mk_body
373   = mk_body
374 buildClosures tvs vars [arg_ty] res_ty mk_body
375   = buildClosure tvs vars arg_ty res_ty mk_body
376 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
377   = do
378       res_ty' <- mkClosureTypes arg_tys res_ty
379       arg <- newLocalVVar (fsLit "x") arg_ty
380       buildClosure tvs vars arg_ty res_ty'
381         . hoistPolyVExpr tvs
382         $ do
383             lc <- builtin liftingContext
384             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
385             return $ vLams lc (vars ++ [arg]) clo
386
387 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
388 --   where
389 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
390 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
391 --
392 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
393 buildClosure tvs vars arg_ty res_ty mk_body
394   = do
395       (env_ty, env, bind) <- buildEnv vars
396       env_bndr <- newLocalVVar (fsLit "env") env_ty
397       arg_bndr <- newLocalVVar (fsLit "arg") arg_ty
398
399       fn <- hoistPolyVExpr tvs
400           $ do
401               lc    <- builtin liftingContext
402               body  <- mk_body
403               body' <- bind (vVar env_bndr)
404                             (vVarApps lc body (vars ++ [arg_bndr]))
405               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
406
407       mkClosure arg_ty res_ty env_ty fn env
408
409 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
410 buildEnv vvs
411   = do
412       lc <- builtin liftingContext
413       let (ty, venv, vbind) = mkVectEnv tys vs
414       (lenv, lbind) <- mkLiftEnv lc tys ls
415       return (ty, (venv, lenv),
416               \(venv,lenv) (vbody,lbody) ->
417               do
418                 let vbody' = vbind venv vbody
419                 lbody' <- lbind lenv lbody
420                 return (vbody', lbody'))
421   where
422     (vs,ls) = unzip vvs
423     tys     = map varType vs
424
425 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
426 mkVectEnv []   []  = (unitTy, Var unitDataConId, \_ body -> body)
427 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
428 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
429                         \env body -> Case env (mkWildId ty) (exprType body)
430                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
431   where
432     ty = mkCoreTupTy tys
433
434 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
435 mkLiftEnv lc [ty] [v]
436   = return (Var v, \env body ->
437                    do
438                      len <- lengthPA ty (Var v)
439                      return . Let (NonRec v env)
440                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
441
442 -- NOTE: this transparently deals with empty environments
443 mkLiftEnv lc tys vs
444   = do
445       (env_tc, env_tyargs) <- parrayReprTyCon vty
446
447       bndrs <- if null vs then do
448                                  v <- newDummyVar unitTy
449                                  return [v]
450                           else return vs
451       let [env_con] = tyConDataCons env_tc
452           
453           env = Var (dataConWrapId env_con)
454                 `mkTyApps`  env_tyargs
455                 `mkApps`    (Var lc : args)
456
457           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
458                           in
459                           return $ Case scrut (mkWildId (exprType scrut))
460                                         (exprType body)
461                                         [(DataAlt env_con, lc : bndrs, body)]
462       return (env, bind)
463   where
464     vty = mkCoreTupTy tys
465
466     args  | null vs   = [Var unitDataConId]
467           | otherwise = map Var vs
468