Find the correct array type for primitive tycons
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   mkDataConTag,
5   splitClosureTy,
6
7   mkBuiltinCo,
8   mkPADictType, mkPArrayType, mkPReprType,
9
10   parrayReprTyCon, parrayReprDataCon, mkVScrut,
11   prDFunOfTyCon,
12   paDictArgType, paDictOfType, paDFunType,
13   paMethod, mkPR, lengthPA, replicatePA, emptyPA, liftPA,
14   polyAbstract, polyApply, polyVApply,
15   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
16   buildClosure, buildClosures,
17   mkClosureApp
18 ) where
19
20 #include "HsVersions.h"
21
22 import VectCore
23 import VectMonad
24
25 import DsUtils
26 import CoreSyn
27 import CoreUtils
28 import Coercion
29 import Type
30 import TypeRep
31 import TyCon
32 import DataCon
33 import Var
34 import Id                 ( mkWildId )
35 import MkId               ( unwrapFamInstScrut )
36 import Name               ( Name )
37 import PrelNames
38 import TysWiredIn
39 import TysPrim            ( intPrimTy )
40 import BasicTypes         ( Boxity(..) )
41
42 import Outputable
43 import FastString
44
45 import Data.List             ( zipWith4 )
46 import Control.Monad         ( liftM, liftM2, zipWithM_ )
47
48 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
49 collectAnnTypeArgs expr = go expr []
50   where
51     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
52     go e                             tys = (e, tys)
53
54 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
55 collectAnnTypeBinders expr = go [] expr
56   where
57     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
58     go bs e                           = (reverse bs, e)
59
60 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
61 collectAnnValBinders expr = go [] expr
62   where
63     go bs (_, AnnLam b e) | isId b = go (b:bs) e
64     go bs e                        = (reverse bs, e)
65
66 isAnnTypeArg :: AnnExpr b ann -> Bool
67 isAnnTypeArg (_, AnnType t) = True
68 isAnnTypeArg _              = False
69
70 mkDataConTag :: DataCon -> CoreExpr
71 mkDataConTag dc = mkConApp intDataCon [mkIntLitInt $ dataConTag dc]
72
73 splitUnTy :: String -> Name -> Type -> Type
74 splitUnTy s name ty
75   | Just (tc, [ty']) <- splitTyConApp_maybe ty
76   , tyConName tc == name
77   = ty'
78
79   | otherwise = pprPanic s (ppr ty)
80
81 splitBinTy :: String -> Name -> Type -> (Type, Type)
82 splitBinTy s name ty
83   | Just (tc, [ty1, ty2]) <- splitTyConApp_maybe ty
84   , tyConName tc == name
85   = (ty1, ty2)
86
87   | otherwise = pprPanic s (ppr ty)
88
89 splitFixedTyConApp :: TyCon -> Type -> [Type]
90 splitFixedTyConApp tc ty
91   | Just (tc', tys) <- splitTyConApp_maybe ty
92   , tc == tc'
93   = tys
94
95   | otherwise = pprPanic "splitFixedTyConApp" (ppr tc <+> ppr ty)
96
97 splitClosureTy :: Type -> (Type, Type)
98 splitClosureTy = splitBinTy "splitClosureTy" closureTyConName
99
100 splitPArrayTy :: Type -> Type
101 splitPArrayTy = splitUnTy "splitPArrayTy" parrayTyConName
102
103 splitPrimTyCon :: Type -> Maybe TyCon
104 splitPrimTyCon ty
105   | Just (tycon, []) <- splitTyConApp_maybe ty
106   , isPrimTyCon tycon
107   = Just tycon
108
109   | otherwise = Nothing
110
111 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
112 mkBuiltinTyConApp get_tc tys
113   = do
114       tc <- builtin get_tc
115       return $ mkTyConApp tc tys
116
117 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
118 mkBuiltinTyConApps get_tc tys ty
119   = do
120       tc <- builtin get_tc
121       return $ foldr (mk tc) ty tys
122   where
123     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
124
125 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
126 mkBuiltinTyConApps1 get_tc dft [] = return dft
127 mkBuiltinTyConApps1 get_tc dft tys
128   = do
129       tc <- builtin get_tc
130       case tys of
131         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
132         _  -> return $ foldr1 (mk tc) tys
133   where
134     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
135
136 mkClosureType :: Type -> Type -> VM Type
137 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
138
139 mkClosureTypes :: [Type] -> Type -> VM Type
140 mkClosureTypes = mkBuiltinTyConApps closureTyCon
141
142 mkPReprType :: Type -> VM Type
143 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
144
145 mkPADictType :: Type -> VM Type
146 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
147
148 mkPArrayType :: Type -> VM Type
149 mkPArrayType ty
150   | Just tycon <- splitPrimTyCon ty
151   = do
152       arr <- traceMaybeV "mkPArrayType" (ppr tycon)
153            $ lookupPrimPArray tycon
154       return $ mkTyConApp arr []
155 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
156
157 mkBuiltinCo :: (Builtins -> TyCon) -> VM Coercion
158 mkBuiltinCo get_tc
159   = do
160       tc <- builtin get_tc
161       return $ mkTyConApp tc []
162
163 parrayReprTyCon :: Type -> VM (TyCon, [Type])
164 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
165
166 parrayReprDataCon :: Type -> VM (DataCon, [Type])
167 parrayReprDataCon ty
168   = do
169       (tc, arg_tys) <- parrayReprTyCon ty
170       let [dc] = tyConDataCons tc
171       return (dc, arg_tys)
172
173 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
174 mkVScrut (ve, le)
175   = do
176       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
177       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
178
179 prDFunOfTyCon :: TyCon -> VM CoreExpr
180 prDFunOfTyCon tycon
181   = liftM Var (traceMaybeV "prDictOfTyCon" (ppr tycon) (lookupTyConPR tycon))
182
183 paDictArgType :: TyVar -> VM (Maybe Type)
184 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
185   where
186     go ty k | Just k' <- kindView k = go ty k'
187     go ty (FunTy k1 k2)
188       = do
189           tv   <- newTyVar FSLIT("a") k1
190           mty1 <- go (TyVarTy tv) k1
191           case mty1 of
192             Just ty1 -> do
193                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
194                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
195             Nothing  -> go ty k2
196
197     go ty k
198       | isLiftedTypeKind k
199       = liftM Just (mkPADictType ty)
200
201     go ty k = return Nothing
202
203 paDictOfType :: Type -> VM CoreExpr
204 paDictOfType ty = paDictOfTyApp ty_fn ty_args
205   where
206     (ty_fn, ty_args) = splitAppTys ty
207
208 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
209 paDictOfTyApp ty_fn ty_args
210   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
211 paDictOfTyApp (TyVarTy tv) ty_args
212   = do
213       dfun <- maybeV (lookupTyVarPA tv)
214       paDFunApply dfun ty_args
215 paDictOfTyApp (TyConApp tc _) ty_args
216   = do
217       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
218       paDFunApply (Var dfun) ty_args
219 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
220
221 paDFunType :: TyCon -> VM Type
222 paDFunType tc
223   = do
224       margs <- mapM paDictArgType tvs
225       res   <- mkPADictType (mkTyConApp tc arg_tys)
226       return . mkForAllTys tvs
227              $ mkFunTys [arg | Just arg <- margs] res
228   where
229     tvs = tyConTyVars tc
230     arg_tys = mkTyVarTys tvs
231
232 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
233 paDFunApply dfun tys
234   = do
235       dicts <- mapM paDictOfType tys
236       return $ mkApps (mkTyApps dfun tys) dicts
237
238 type PAMethod = (Builtins -> Var, String)
239
240 pa_length    = (lengthPAVar,    "lengthPA")
241 pa_replicate = (replicatePAVar, "replicatePA")
242 pa_empty     = (emptyPAVar,     "emptyPA")
243
244 paMethod :: PAMethod -> Type -> VM CoreExpr
245 paMethod (method, name) ty
246   | Just tycon <- splitPrimTyCon ty
247   = do
248       fn <- traceMaybeV "paMethod" (ppr tycon <+> text name)
249           $ lookupPrimMethod tycon name
250       return (Var fn)
251
252 paMethod (method, name) ty
253   = do
254       fn   <- builtin method
255       dict <- paDictOfType ty
256       return $ mkApps (Var fn) [Type ty, dict]
257
258 mkPR :: Type -> VM CoreExpr
259 mkPR ty
260   = do
261       fn   <- builtin mkPRVar
262       dict <- paDictOfType ty
263       return $ mkApps (Var fn) [Type ty, dict]
264
265 lengthPA :: CoreExpr -> VM CoreExpr
266 lengthPA x = liftM (`App` x) (paMethod pa_length ty)
267   where
268     ty = splitPArrayTy (exprType x)
269
270 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
271 replicatePA len x = liftM (`mkApps` [len,x])
272                           (paMethod pa_replicate (exprType x))
273
274 emptyPA :: Type -> VM CoreExpr
275 emptyPA = paMethod pa_empty
276
277 liftPA :: CoreExpr -> VM CoreExpr
278 liftPA x
279   = do
280       lc <- builtin liftingContext
281       replicatePA (Var lc) x
282
283 newLocalVVar :: FastString -> Type -> VM VVar
284 newLocalVVar fs vty
285   = do
286       lty <- mkPArrayType vty
287       vv  <- newLocalVar fs vty
288       lv  <- newLocalVar fs lty
289       return (vv,lv)
290
291 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
292 polyAbstract tvs p
293   = localV
294   $ do
295       mdicts <- mapM mk_dict_var tvs
296       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
297       p (mk_lams mdicts)
298   where
299     mk_dict_var tv = do
300                        r <- paDictArgType tv
301                        case r of
302                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
303                          Nothing -> return Nothing
304
305     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
306
307 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
308 polyApply expr tys
309   = do
310       dicts <- mapM paDictOfType tys
311       return $ expr `mkTyApps` tys `mkApps` dicts
312
313 polyVApply :: VExpr -> [Type] -> VM VExpr
314 polyVApply expr tys
315   = do
316       dicts <- mapM paDictOfType tys
317       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
318
319 hoistBinding :: Var -> CoreExpr -> VM ()
320 hoistBinding v e = updGEnv $ \env ->
321   env { global_bindings = (v,e) : global_bindings env }
322
323 hoistExpr :: FastString -> CoreExpr -> VM Var
324 hoistExpr fs expr
325   = do
326       var <- newLocalVar fs (exprType expr)
327       hoistBinding var expr
328       return var
329
330 hoistVExpr :: VExpr -> VM VVar
331 hoistVExpr (ve, le)
332   = do
333       fs <- getBindName
334       vv <- hoistExpr ('v' `consFS` fs) ve
335       lv <- hoistExpr ('l' `consFS` fs) le
336       return (vv, lv)
337
338 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
339 hoistPolyVExpr tvs p
340   = do
341       expr <- closedV . polyAbstract tvs $ \abstract ->
342               liftM (mapVect abstract) p
343       fn   <- hoistVExpr expr
344       polyVApply (vVar fn) (mkTyVarTys tvs)
345
346 takeHoisted :: VM [(Var, CoreExpr)]
347 takeHoisted
348   = do
349       env <- readGEnv id
350       setGEnv $ env { global_bindings = [] }
351       return $ global_bindings env
352
353 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
354 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
355   = do
356       dict <- paDictOfType env_ty
357       mkv  <- builtin mkClosureVar
358       mkl  <- builtin mkClosurePVar
359       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
360               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
361
362 mkClosureApp :: VExpr -> VExpr -> VM VExpr
363 mkClosureApp (vclo, lclo) (varg, larg)
364   = do
365       vapply <- builtin applyClosureVar
366       lapply <- builtin applyClosurePVar
367       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
368               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
369   where
370     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
371
372 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
373 buildClosures tvs vars [] res_ty mk_body
374   = mk_body
375 buildClosures tvs vars [arg_ty] res_ty mk_body
376   = buildClosure tvs vars arg_ty res_ty mk_body
377 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
378   = do
379       res_ty' <- mkClosureTypes arg_tys res_ty
380       arg <- newLocalVVar FSLIT("x") arg_ty
381       buildClosure tvs vars arg_ty res_ty'
382         . hoistPolyVExpr tvs
383         $ do
384             lc <- builtin liftingContext
385             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
386             return $ vLams lc (vars ++ [arg]) clo
387
388 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
389 --   where
390 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
391 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
392 --
393 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
394 buildClosure tvs vars arg_ty res_ty mk_body
395   = do
396       (env_ty, env, bind) <- buildEnv vars
397       env_bndr <- newLocalVVar FSLIT("env") env_ty
398       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
399
400       fn <- hoistPolyVExpr tvs
401           $ do
402               lc    <- builtin liftingContext
403               body  <- mk_body
404               body' <- bind (vVar env_bndr)
405                             (vVarApps lc body (vars ++ [arg_bndr]))
406               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
407
408       mkClosure arg_ty res_ty env_ty fn env
409
410 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
411 buildEnv vvs
412   = do
413       lc <- builtin liftingContext
414       let (ty, venv, vbind) = mkVectEnv tys vs
415       (lenv, lbind) <- mkLiftEnv lc tys ls
416       return (ty, (venv, lenv),
417               \(venv,lenv) (vbody,lbody) ->
418               do
419                 let vbody' = vbind venv vbody
420                 lbody' <- lbind lenv lbody
421                 return (vbody', lbody'))
422   where
423     (vs,ls) = unzip vvs
424     tys     = map idType vs
425
426 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
427 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
428 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
429 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
430                         \env body -> Case env (mkWildId ty) (exprType body)
431                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
432   where
433     ty = mkCoreTupTy tys
434
435 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
436 mkLiftEnv lc [ty] [v]
437   = return (Var v, \env body ->
438                    do
439                      len <- lengthPA (Var v)
440                      return . Let (NonRec v env)
441                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
442
443 -- NOTE: this transparently deals with empty environments
444 mkLiftEnv lc tys vs
445   = do
446       (env_tc, env_tyargs) <- parrayReprTyCon vty
447       let [env_con] = tyConDataCons env_tc
448           
449           env = Var (dataConWrapId env_con)
450                 `mkTyApps`  env_tyargs
451                 `mkVarApps` (lc : vs)
452
453           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
454                           in
455                           return $ Case scrut (mkWildId (exprType scrut))
456                                         (exprType body)
457                                         [(DataAlt env_con, lc : bndrs, body)]
458       return (env, bind)
459   where
460     vty = mkCoreTupTy tys
461
462     bndrs | null vs   = [mkWildId unitTy]
463           | otherwise = vs
464