Fixed warnings in vectorise/VectUtils
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   dataConTagZ, mkDataConTag, mkDataConTagLit,
5
6   newLocalVVar,
7
8   mkBuiltinCo,
9   mkPADictType, mkPArrayType, mkPReprType,
10
11   parrayReprTyCon, parrayReprDataCon, mkVScrut,
12   prDFunOfTyCon,
13   paDictArgType, paDictOfType, paDFunType,
14   paMethod, mkPR, lengthPA, replicatePA, emptyPA, packPA, combinePA, liftPA,
15   polyAbstract, polyApply, polyVApply,
16   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
17   buildClosure, buildClosures,
18   mkClosureApp
19 ) where
20
21 #include "HsVersions.h"
22
23 import VectCore
24 import VectMonad
25
26 import DsUtils
27 import CoreSyn
28 import CoreUtils
29 import Coercion
30 import Type
31 import TypeRep
32 import TyCon
33 import DataCon
34 import Var
35 import Id                 ( mkWildId )
36 import MkId               ( unwrapFamInstScrut )
37 import TysWiredIn
38 import BasicTypes         ( Boxity(..) )
39 import Literal            ( Literal, mkMachInt )
40
41 import Outputable
42 import FastString
43
44 import Control.Monad
45
46
47 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
48 collectAnnTypeArgs expr = go expr []
49   where
50     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
51     go e                             tys = (e, tys)
52
53 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
54 collectAnnTypeBinders expr = go [] expr
55   where
56     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
57     go bs e                           = (reverse bs, e)
58
59 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
60 collectAnnValBinders expr = go [] expr
61   where
62     go bs (_, AnnLam b e) | isId b = go (b:bs) e
63     go bs e                        = (reverse bs, e)
64
65 isAnnTypeArg :: AnnExpr b ann -> Bool
66 isAnnTypeArg (_, AnnType _) = True
67 isAnnTypeArg _              = False
68
69 dataConTagZ :: DataCon -> Int
70 dataConTagZ con = dataConTag con - fIRST_TAG
71
72 mkDataConTagLit :: DataCon -> Literal
73 mkDataConTagLit = mkMachInt . toInteger . dataConTagZ
74
75 mkDataConTag :: DataCon -> CoreExpr
76 mkDataConTag = mkIntLitInt . dataConTagZ
77
78 splitPrimTyCon :: Type -> Maybe TyCon
79 splitPrimTyCon ty
80   | Just (tycon, []) <- splitTyConApp_maybe ty
81   , isPrimTyCon tycon
82   = Just tycon
83
84   | otherwise = Nothing
85
86 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
87 mkBuiltinTyConApp get_tc tys
88   = do
89       tc <- builtin get_tc
90       return $ mkTyConApp tc tys
91
92 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
93 mkBuiltinTyConApps get_tc tys ty
94   = do
95       tc <- builtin get_tc
96       return $ foldr (mk tc) ty tys
97   where
98     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
99
100 {-
101 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
102 mkBuiltinTyConApps1 _      dft [] = return dft
103 mkBuiltinTyConApps1 get_tc _   tys
104   = do
105       tc <- builtin get_tc
106       case tys of
107         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
108         _  -> return $ foldr1 (mk tc) tys
109   where
110     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
111
112 mkClosureType :: Type -> Type -> VM Type
113 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
114 -}
115
116 mkClosureTypes :: [Type] -> Type -> VM Type
117 mkClosureTypes = mkBuiltinTyConApps closureTyCon
118
119 mkPReprType :: Type -> VM Type
120 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
121
122 mkPADictType :: Type -> VM Type
123 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
124
125 mkPArrayType :: Type -> VM Type
126 mkPArrayType ty
127   | Just tycon <- splitPrimTyCon ty
128   = do
129       arr <- traceMaybeV "mkPArrayType" (ppr tycon)
130            $ lookupPrimPArray tycon
131       return $ mkTyConApp arr []
132 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
133
134 mkBuiltinCo :: (Builtins -> TyCon) -> VM Coercion
135 mkBuiltinCo get_tc
136   = do
137       tc <- builtin get_tc
138       return $ mkTyConApp tc []
139
140 parrayReprTyCon :: Type -> VM (TyCon, [Type])
141 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
142
143 parrayReprDataCon :: Type -> VM (DataCon, [Type])
144 parrayReprDataCon ty
145   = do
146       (tc, arg_tys) <- parrayReprTyCon ty
147       let [dc] = tyConDataCons tc
148       return (dc, arg_tys)
149
150 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
151 mkVScrut (ve, le)
152   = do
153       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
154       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
155
156 prDFunOfTyCon :: TyCon -> VM CoreExpr
157 prDFunOfTyCon tycon
158   = liftM Var (traceMaybeV "prDictOfTyCon" (ppr tycon) (lookupTyConPR tycon))
159
160 paDictArgType :: TyVar -> VM (Maybe Type)
161 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
162   where
163     go ty k | Just k' <- kindView k = go ty k'
164     go ty (FunTy k1 k2)
165       = do
166           tv   <- newTyVar FSLIT("a") k1
167           mty1 <- go (TyVarTy tv) k1
168           case mty1 of
169             Just ty1 -> do
170                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
171                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
172             Nothing  -> go ty k2
173
174     go ty k
175       | isLiftedTypeKind k
176       = liftM Just (mkPADictType ty)
177
178     go _ _ = return Nothing
179
180 paDictOfType :: Type -> VM CoreExpr
181 paDictOfType ty = paDictOfTyApp ty_fn ty_args
182   where
183     (ty_fn, ty_args) = splitAppTys ty
184
185 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
186 paDictOfTyApp ty_fn ty_args
187   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
188 paDictOfTyApp (TyVarTy tv) ty_args
189   = do
190       dfun <- maybeV (lookupTyVarPA tv)
191       paDFunApply dfun ty_args
192 paDictOfTyApp (TyConApp tc _) ty_args
193   = do
194       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
195       paDFunApply (Var dfun) ty_args
196 paDictOfTyApp ty _ = pprPanic "paDictOfTyApp" (ppr ty)
197
198 paDFunType :: TyCon -> VM Type
199 paDFunType tc
200   = do
201       margs <- mapM paDictArgType tvs
202       res   <- mkPADictType (mkTyConApp tc arg_tys)
203       return . mkForAllTys tvs
204              $ mkFunTys [arg | Just arg <- margs] res
205   where
206     tvs = tyConTyVars tc
207     arg_tys = mkTyVarTys tvs
208
209 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
210 paDFunApply dfun tys
211   = do
212       dicts <- mapM paDictOfType tys
213       return $ mkApps (mkTyApps dfun tys) dicts
214
215 type PAMethod = (Builtins -> Var, String)
216
217 pa_length, pa_replicate, pa_empty, pa_pack :: (Builtins -> Var, String)
218 pa_length    = (lengthPAVar,    "lengthPA")
219 pa_replicate = (replicatePAVar, "replicatePA")
220 pa_empty     = (emptyPAVar,     "emptyPA")
221 pa_pack      = (packPAVar,      "packPA")
222
223 paMethod :: PAMethod -> Type -> VM CoreExpr
224 paMethod (_method, name) ty
225   | Just tycon <- splitPrimTyCon ty
226   = do
227       fn <- traceMaybeV "paMethod" (ppr tycon <+> text name)
228           $ lookupPrimMethod tycon name
229       return (Var fn)
230
231 paMethod (method, _name) ty
232   = do
233       fn   <- builtin method
234       dict <- paDictOfType ty
235       return $ mkApps (Var fn) [Type ty, dict]
236
237 mkPR :: Type -> VM CoreExpr
238 mkPR ty
239   = do
240       fn   <- builtin mkPRVar
241       dict <- paDictOfType ty
242       return $ mkApps (Var fn) [Type ty, dict]
243
244 lengthPA :: Type -> CoreExpr -> VM CoreExpr
245 lengthPA ty x = liftM (`App` x) (paMethod pa_length ty)
246
247 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
248 replicatePA len x = liftM (`mkApps` [len,x])
249                           (paMethod pa_replicate (exprType x))
250
251 emptyPA :: Type -> VM CoreExpr
252 emptyPA = paMethod pa_empty
253
254 packPA :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> VM CoreExpr
255 packPA ty xs len sel = liftM (`mkApps` [xs, len, sel])
256                              (paMethod pa_pack ty)
257
258 combinePA :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> [CoreExpr]
259           -> VM CoreExpr
260 combinePA ty len sel is xs
261   = liftM (`mkApps` (len : sel : is : xs))
262           (paMethod (combinePAVar n, "combine" ++ show n ++ "PA") ty)
263   where
264     n = length xs
265
266 liftPA :: CoreExpr -> VM CoreExpr
267 liftPA x
268   = do
269       lc <- builtin liftingContext
270       replicatePA (Var lc) x
271
272 newLocalVVar :: FastString -> Type -> VM VVar
273 newLocalVVar fs vty
274   = do
275       lty <- mkPArrayType vty
276       vv  <- newLocalVar fs vty
277       lv  <- newLocalVar fs lty
278       return (vv,lv)
279
280 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
281 polyAbstract tvs p
282   = localV
283   $ do
284       mdicts <- mapM mk_dict_var tvs
285       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
286       p (mk_lams mdicts)
287   where
288     mk_dict_var tv = do
289                        r <- paDictArgType tv
290                        case r of
291                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
292                          Nothing -> return Nothing
293
294     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
295
296 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
297 polyApply expr tys
298   = do
299       dicts <- mapM paDictOfType tys
300       return $ expr `mkTyApps` tys `mkApps` dicts
301
302 polyVApply :: VExpr -> [Type] -> VM VExpr
303 polyVApply expr tys
304   = do
305       dicts <- mapM paDictOfType tys
306       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
307
308 hoistBinding :: Var -> CoreExpr -> VM ()
309 hoistBinding v e = updGEnv $ \env ->
310   env { global_bindings = (v,e) : global_bindings env }
311
312 hoistExpr :: FastString -> CoreExpr -> VM Var
313 hoistExpr fs expr
314   = do
315       var <- newLocalVar fs (exprType expr)
316       hoistBinding var expr
317       return var
318
319 hoistVExpr :: VExpr -> VM VVar
320 hoistVExpr (ve, le)
321   = do
322       fs <- getBindName
323       vv <- hoistExpr ('v' `consFS` fs) ve
324       lv <- hoistExpr ('l' `consFS` fs) le
325       return (vv, lv)
326
327 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
328 hoistPolyVExpr tvs p
329   = do
330       expr <- closedV . polyAbstract tvs $ \abstract ->
331               liftM (mapVect abstract) p
332       fn   <- hoistVExpr expr
333       polyVApply (vVar fn) (mkTyVarTys tvs)
334
335 takeHoisted :: VM [(Var, CoreExpr)]
336 takeHoisted
337   = do
338       env <- readGEnv id
339       setGEnv $ env { global_bindings = [] }
340       return $ global_bindings env
341
342 {-
343 boxExpr :: Type -> VExpr -> VM VExpr
344 boxExpr ty (vexpr, lexpr)
345   | Just (tycon, []) <- splitTyConApp_maybe ty
346   , isUnLiftedTyCon tycon
347   = do
348       r <- lookupBoxedTyCon tycon
349       case r of
350         Just tycon' -> let [dc] = tyConDataCons tycon'
351                        in
352                        return (mkConApp dc [vexpr], lexpr)
353         Nothing     -> return (vexpr, lexpr)
354 -}
355
356 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
357 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
358   = do
359       dict <- paDictOfType env_ty
360       mkv  <- builtin mkClosureVar
361       mkl  <- builtin mkClosurePVar
362       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
363               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
364
365 mkClosureApp :: Type -> Type -> VExpr -> VExpr -> VM VExpr
366 mkClosureApp arg_ty res_ty (vclo, lclo) (varg, larg)
367   = do
368       vapply <- builtin applyClosureVar
369       lapply <- builtin applyClosurePVar
370       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
371               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
372
373 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
374 buildClosures _   _    [] _ mk_body
375   = mk_body
376 buildClosures tvs vars [arg_ty] res_ty mk_body
377   = buildClosure tvs vars arg_ty res_ty mk_body
378 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
379   = do
380       res_ty' <- mkClosureTypes arg_tys res_ty
381       arg <- newLocalVVar FSLIT("x") arg_ty
382       buildClosure tvs vars arg_ty res_ty'
383         . hoistPolyVExpr tvs
384         $ do
385             lc <- builtin liftingContext
386             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
387             return $ vLams lc (vars ++ [arg]) clo
388
389 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
390 --   where
391 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
392 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
393 --
394 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
395 buildClosure tvs vars arg_ty res_ty mk_body
396   = do
397       (env_ty, env, bind) <- buildEnv vars
398       env_bndr <- newLocalVVar FSLIT("env") env_ty
399       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
400
401       fn <- hoistPolyVExpr tvs
402           $ do
403               lc    <- builtin liftingContext
404               body  <- mk_body
405               body' <- bind (vVar env_bndr)
406                             (vVarApps lc body (vars ++ [arg_bndr]))
407               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
408
409       mkClosure arg_ty res_ty env_ty fn env
410
411 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
412 buildEnv vvs
413   = do
414       lc <- builtin liftingContext
415       let (ty, venv, vbind) = mkVectEnv tys vs
416       (lenv, lbind) <- mkLiftEnv lc tys ls
417       return (ty, (venv, lenv),
418               \(venv,lenv) (vbody,lbody) ->
419               do
420                 let vbody' = vbind venv vbody
421                 lbody' <- lbind lenv lbody
422                 return (vbody', lbody'))
423   where
424     (vs,ls) = unzip vvs
425     tys     = map idType vs
426
427 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
428 mkVectEnv []   []  = (unitTy, Var unitDataConId, \_ body -> body)
429 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
430 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
431                         \env body -> Case env (mkWildId ty) (exprType body)
432                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
433   where
434     ty = mkCoreTupTy tys
435
436 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
437 mkLiftEnv lc [ty] [v]
438   = return (Var v, \env body ->
439                    do
440                      len <- lengthPA ty (Var v)
441                      return . Let (NonRec v env)
442                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
443
444 -- NOTE: this transparently deals with empty environments
445 mkLiftEnv lc tys vs
446   = do
447       (env_tc, env_tyargs) <- parrayReprTyCon vty
448
449       bndrs <- if null vs then do
450                                  v <- newDummyVar unitTy
451                                  return [v]
452                           else return vs
453       let [env_con] = tyConDataCons env_tc
454           
455           env = Var (dataConWrapId env_con)
456                 `mkTyApps`  env_tyargs
457                 `mkApps`    (Var lc : args)
458
459           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
460                           in
461                           return $ Case scrut (mkWildId (exprType scrut))
462                                         (exprType body)
463                                         [(DataAlt env_con, lc : bndrs, body)]
464       return (env, bind)
465   where
466     vty = mkCoreTupTy tys
467
468     args  | null vs   = [Var unitDataConId]
469           | otherwise = map Var vs
470