Vectorisation of enumeration types
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   mkDataConTag, mkDataConTagLit,
5   splitClosureTy,
6
7   mkBuiltinCo,
8   mkPADictType, mkPArrayType, mkPReprType,
9
10   parrayReprTyCon, parrayReprDataCon, mkVScrut,
11   prDFunOfTyCon,
12   paDictArgType, paDictOfType, paDFunType,
13   paMethod, mkPR, lengthPA, replicatePA, emptyPA, liftPA,
14   polyAbstract, polyApply, polyVApply,
15   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
16   buildClosure, buildClosures,
17   mkClosureApp
18 ) where
19
20 #include "HsVersions.h"
21
22 import VectCore
23 import VectMonad
24
25 import DsUtils
26 import CoreSyn
27 import CoreUtils
28 import Coercion
29 import Type
30 import TypeRep
31 import TyCon
32 import DataCon
33 import Var
34 import Id                 ( mkWildId )
35 import MkId               ( unwrapFamInstScrut )
36 import Name               ( Name )
37 import PrelNames
38 import TysWiredIn
39 import TysPrim            ( intPrimTy )
40 import BasicTypes         ( Boxity(..) )
41 import Literal            ( Literal, mkMachInt )
42
43 import Outputable
44 import FastString
45
46 import Data.List             ( zipWith4 )
47 import Control.Monad         ( liftM, liftM2, zipWithM_ )
48
49 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
50 collectAnnTypeArgs expr = go expr []
51   where
52     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
53     go e                             tys = (e, tys)
54
55 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
56 collectAnnTypeBinders expr = go [] expr
57   where
58     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
59     go bs e                           = (reverse bs, e)
60
61 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
62 collectAnnValBinders expr = go [] expr
63   where
64     go bs (_, AnnLam b e) | isId b = go (b:bs) e
65     go bs e                        = (reverse bs, e)
66
67 isAnnTypeArg :: AnnExpr b ann -> Bool
68 isAnnTypeArg (_, AnnType t) = True
69 isAnnTypeArg _              = False
70
71 mkDataConTagLit :: DataCon -> Literal
72 mkDataConTagLit con
73   = mkMachInt . toInteger $ dataConTag con - fIRST_TAG
74
75 mkDataConTag :: DataCon -> CoreExpr
76 mkDataConTag con = mkIntLitInt (dataConTag con - fIRST_TAG)
77
78 splitUnTy :: String -> Name -> Type -> Type
79 splitUnTy s name ty
80   | Just (tc, [ty']) <- splitTyConApp_maybe ty
81   , tyConName tc == name
82   = ty'
83
84   | otherwise = pprPanic s (ppr ty)
85
86 splitBinTy :: String -> Name -> Type -> (Type, Type)
87 splitBinTy s name ty
88   | Just (tc, [ty1, ty2]) <- splitTyConApp_maybe ty
89   , tyConName tc == name
90   = (ty1, ty2)
91
92   | otherwise = pprPanic s (ppr ty)
93
94 splitFixedTyConApp :: TyCon -> Type -> [Type]
95 splitFixedTyConApp tc ty
96   | Just (tc', tys) <- splitTyConApp_maybe ty
97   , tc == tc'
98   = tys
99
100   | otherwise = pprPanic "splitFixedTyConApp" (ppr tc <+> ppr ty)
101
102 splitClosureTy :: Type -> (Type, Type)
103 splitClosureTy = splitBinTy "splitClosureTy" closureTyConName
104
105 splitPArrayTy :: Type -> Type
106 splitPArrayTy = splitUnTy "splitPArrayTy" parrayTyConName
107
108 splitPrimTyCon :: Type -> Maybe TyCon
109 splitPrimTyCon ty
110   | Just (tycon, []) <- splitTyConApp_maybe ty
111   , isPrimTyCon tycon
112   = Just tycon
113
114   | otherwise = Nothing
115
116 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
117 mkBuiltinTyConApp get_tc tys
118   = do
119       tc <- builtin get_tc
120       return $ mkTyConApp tc tys
121
122 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
123 mkBuiltinTyConApps get_tc tys ty
124   = do
125       tc <- builtin get_tc
126       return $ foldr (mk tc) ty tys
127   where
128     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
129
130 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
131 mkBuiltinTyConApps1 get_tc dft [] = return dft
132 mkBuiltinTyConApps1 get_tc dft tys
133   = do
134       tc <- builtin get_tc
135       case tys of
136         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
137         _  -> return $ foldr1 (mk tc) tys
138   where
139     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
140
141 mkClosureType :: Type -> Type -> VM Type
142 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
143
144 mkClosureTypes :: [Type] -> Type -> VM Type
145 mkClosureTypes = mkBuiltinTyConApps closureTyCon
146
147 mkPReprType :: Type -> VM Type
148 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
149
150 mkPADictType :: Type -> VM Type
151 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
152
153 mkPArrayType :: Type -> VM Type
154 mkPArrayType ty
155   | Just tycon <- splitPrimTyCon ty
156   = do
157       arr <- traceMaybeV "mkPArrayType" (ppr tycon)
158            $ lookupPrimPArray tycon
159       return $ mkTyConApp arr []
160 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
161
162 mkBuiltinCo :: (Builtins -> TyCon) -> VM Coercion
163 mkBuiltinCo get_tc
164   = do
165       tc <- builtin get_tc
166       return $ mkTyConApp tc []
167
168 parrayReprTyCon :: Type -> VM (TyCon, [Type])
169 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
170
171 parrayReprDataCon :: Type -> VM (DataCon, [Type])
172 parrayReprDataCon ty
173   = do
174       (tc, arg_tys) <- parrayReprTyCon ty
175       let [dc] = tyConDataCons tc
176       return (dc, arg_tys)
177
178 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
179 mkVScrut (ve, le)
180   = do
181       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
182       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
183
184 prDFunOfTyCon :: TyCon -> VM CoreExpr
185 prDFunOfTyCon tycon
186   = liftM Var (traceMaybeV "prDictOfTyCon" (ppr tycon) (lookupTyConPR tycon))
187
188 paDictArgType :: TyVar -> VM (Maybe Type)
189 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
190   where
191     go ty k | Just k' <- kindView k = go ty k'
192     go ty (FunTy k1 k2)
193       = do
194           tv   <- newTyVar FSLIT("a") k1
195           mty1 <- go (TyVarTy tv) k1
196           case mty1 of
197             Just ty1 -> do
198                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
199                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
200             Nothing  -> go ty k2
201
202     go ty k
203       | isLiftedTypeKind k
204       = liftM Just (mkPADictType ty)
205
206     go ty k = return Nothing
207
208 paDictOfType :: Type -> VM CoreExpr
209 paDictOfType ty = paDictOfTyApp ty_fn ty_args
210   where
211     (ty_fn, ty_args) = splitAppTys ty
212
213 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
214 paDictOfTyApp ty_fn ty_args
215   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
216 paDictOfTyApp (TyVarTy tv) ty_args
217   = do
218       dfun <- maybeV (lookupTyVarPA tv)
219       paDFunApply dfun ty_args
220 paDictOfTyApp (TyConApp tc _) ty_args
221   = do
222       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
223       paDFunApply (Var dfun) ty_args
224 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
225
226 paDFunType :: TyCon -> VM Type
227 paDFunType tc
228   = do
229       margs <- mapM paDictArgType tvs
230       res   <- mkPADictType (mkTyConApp tc arg_tys)
231       return . mkForAllTys tvs
232              $ mkFunTys [arg | Just arg <- margs] res
233   where
234     tvs = tyConTyVars tc
235     arg_tys = mkTyVarTys tvs
236
237 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
238 paDFunApply dfun tys
239   = do
240       dicts <- mapM paDictOfType tys
241       return $ mkApps (mkTyApps dfun tys) dicts
242
243 type PAMethod = (Builtins -> Var, String)
244
245 pa_length    = (lengthPAVar,    "lengthPA")
246 pa_replicate = (replicatePAVar, "replicatePA")
247 pa_empty     = (emptyPAVar,     "emptyPA")
248
249 paMethod :: PAMethod -> Type -> VM CoreExpr
250 paMethod (method, name) ty
251   | Just tycon <- splitPrimTyCon ty
252   = do
253       fn <- traceMaybeV "paMethod" (ppr tycon <+> text name)
254           $ lookupPrimMethod tycon name
255       return (Var fn)
256
257 paMethod (method, name) ty
258   = do
259       fn   <- builtin method
260       dict <- paDictOfType ty
261       return $ mkApps (Var fn) [Type ty, dict]
262
263 mkPR :: Type -> VM CoreExpr
264 mkPR ty
265   = do
266       fn   <- builtin mkPRVar
267       dict <- paDictOfType ty
268       return $ mkApps (Var fn) [Type ty, dict]
269
270 lengthPA :: CoreExpr -> VM CoreExpr
271 lengthPA x = liftM (`App` x) (paMethod pa_length ty)
272   where
273     ty = splitPArrayTy (exprType x)
274
275 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
276 replicatePA len x = liftM (`mkApps` [len,x])
277                           (paMethod pa_replicate (exprType x))
278
279 emptyPA :: Type -> VM CoreExpr
280 emptyPA = paMethod pa_empty
281
282 liftPA :: CoreExpr -> VM CoreExpr
283 liftPA x
284   = do
285       lc <- builtin liftingContext
286       replicatePA (Var lc) x
287
288 newLocalVVar :: FastString -> Type -> VM VVar
289 newLocalVVar fs vty
290   = do
291       lty <- mkPArrayType vty
292       vv  <- newLocalVar fs vty
293       lv  <- newLocalVar fs lty
294       return (vv,lv)
295
296 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
297 polyAbstract tvs p
298   = localV
299   $ do
300       mdicts <- mapM mk_dict_var tvs
301       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
302       p (mk_lams mdicts)
303   where
304     mk_dict_var tv = do
305                        r <- paDictArgType tv
306                        case r of
307                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
308                          Nothing -> return Nothing
309
310     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
311
312 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
313 polyApply expr tys
314   = do
315       dicts <- mapM paDictOfType tys
316       return $ expr `mkTyApps` tys `mkApps` dicts
317
318 polyVApply :: VExpr -> [Type] -> VM VExpr
319 polyVApply expr tys
320   = do
321       dicts <- mapM paDictOfType tys
322       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
323
324 hoistBinding :: Var -> CoreExpr -> VM ()
325 hoistBinding v e = updGEnv $ \env ->
326   env { global_bindings = (v,e) : global_bindings env }
327
328 hoistExpr :: FastString -> CoreExpr -> VM Var
329 hoistExpr fs expr
330   = do
331       var <- newLocalVar fs (exprType expr)
332       hoistBinding var expr
333       return var
334
335 hoistVExpr :: VExpr -> VM VVar
336 hoistVExpr (ve, le)
337   = do
338       fs <- getBindName
339       vv <- hoistExpr ('v' `consFS` fs) ve
340       lv <- hoistExpr ('l' `consFS` fs) le
341       return (vv, lv)
342
343 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
344 hoistPolyVExpr tvs p
345   = do
346       expr <- closedV . polyAbstract tvs $ \abstract ->
347               liftM (mapVect abstract) p
348       fn   <- hoistVExpr expr
349       polyVApply (vVar fn) (mkTyVarTys tvs)
350
351 takeHoisted :: VM [(Var, CoreExpr)]
352 takeHoisted
353   = do
354       env <- readGEnv id
355       setGEnv $ env { global_bindings = [] }
356       return $ global_bindings env
357
358 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
359 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
360   = do
361       dict <- paDictOfType env_ty
362       mkv  <- builtin mkClosureVar
363       mkl  <- builtin mkClosurePVar
364       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
365               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
366
367 mkClosureApp :: VExpr -> VExpr -> VM VExpr
368 mkClosureApp (vclo, lclo) (varg, larg)
369   = do
370       vapply <- builtin applyClosureVar
371       lapply <- builtin applyClosurePVar
372       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
373               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
374   where
375     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
376
377 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
378 buildClosures tvs vars [] res_ty mk_body
379   = mk_body
380 buildClosures tvs vars [arg_ty] res_ty mk_body
381   = buildClosure tvs vars arg_ty res_ty mk_body
382 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
383   = do
384       res_ty' <- mkClosureTypes arg_tys res_ty
385       arg <- newLocalVVar FSLIT("x") arg_ty
386       buildClosure tvs vars arg_ty res_ty'
387         . hoistPolyVExpr tvs
388         $ do
389             lc <- builtin liftingContext
390             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
391             return $ vLams lc (vars ++ [arg]) clo
392
393 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
394 --   where
395 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
396 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
397 --
398 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
399 buildClosure tvs vars arg_ty res_ty mk_body
400   = do
401       (env_ty, env, bind) <- buildEnv vars
402       env_bndr <- newLocalVVar FSLIT("env") env_ty
403       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
404
405       fn <- hoistPolyVExpr tvs
406           $ do
407               lc    <- builtin liftingContext
408               body  <- mk_body
409               body' <- bind (vVar env_bndr)
410                             (vVarApps lc body (vars ++ [arg_bndr]))
411               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
412
413       mkClosure arg_ty res_ty env_ty fn env
414
415 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
416 buildEnv vvs
417   = do
418       lc <- builtin liftingContext
419       let (ty, venv, vbind) = mkVectEnv tys vs
420       (lenv, lbind) <- mkLiftEnv lc tys ls
421       return (ty, (venv, lenv),
422               \(venv,lenv) (vbody,lbody) ->
423               do
424                 let vbody' = vbind venv vbody
425                 lbody' <- lbind lenv lbody
426                 return (vbody', lbody'))
427   where
428     (vs,ls) = unzip vvs
429     tys     = map idType vs
430
431 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
432 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
433 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
434 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
435                         \env body -> Case env (mkWildId ty) (exprType body)
436                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
437   where
438     ty = mkCoreTupTy tys
439
440 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
441 mkLiftEnv lc [ty] [v]
442   = return (Var v, \env body ->
443                    do
444                      len <- lengthPA (Var v)
445                      return . Let (NonRec v env)
446                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
447
448 -- NOTE: this transparently deals with empty environments
449 mkLiftEnv lc tys vs
450   = do
451       (env_tc, env_tyargs) <- parrayReprTyCon vty
452       let [env_con] = tyConDataCons env_tc
453           
454           env = Var (dataConWrapId env_con)
455                 `mkTyApps`  env_tyargs
456                 `mkVarApps` (lc : vs)
457
458           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
459                           in
460                           return $ Case scrut (mkWildId (exprType scrut))
461                                         (exprType body)
462                                         [(DataAlt env_con, lc : bndrs, body)]
463       return (env, bind)
464   where
465     vty = mkCoreTupTy tys
466
467     bndrs | null vs   = [mkWildId unitTy]
468           | otherwise = vs
469