Encode generic representation of vectorised TyCons by a data type
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   mkDataConTag,
5   splitClosureTy,
6
7   TyConRepr(..), mkTyConRepr,
8   mkToPRepr, mkToArrPRepr, mkFromPRepr, mkFromArrPRepr,
9   mkPADictType, mkPArrayType, mkPReprType,
10
11   parrayCoerce, parrayReprTyCon, parrayReprDataCon, mkVScrut,
12   prDictOfType, prCoerce,
13   paDictArgType, paDictOfType, paDFunType,
14   paMethod, lengthPA, replicatePA, emptyPA, liftPA,
15   polyAbstract, polyApply, polyVApply,
16   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
17   buildClosure, buildClosures,
18   mkClosureApp
19 ) where
20
21 #include "HsVersions.h"
22
23 import VectCore
24 import VectMonad
25
26 import DsUtils
27 import CoreSyn
28 import CoreUtils
29 import Coercion
30 import Type
31 import TypeRep
32 import TyCon
33 import DataCon
34 import Var
35 import Id                 ( mkWildId )
36 import MkId               ( unwrapFamInstScrut )
37 import Name               ( Name )
38 import PrelNames
39 import TysWiredIn
40 import BasicTypes         ( Boxity(..) )
41
42 import Outputable
43 import FastString
44
45 import Data.List             ( zipWith4 )
46 import Control.Monad         ( liftM, liftM2, zipWithM_ )
47
48 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
49 collectAnnTypeArgs expr = go expr []
50   where
51     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
52     go e                             tys = (e, tys)
53
54 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
55 collectAnnTypeBinders expr = go [] expr
56   where
57     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
58     go bs e                           = (reverse bs, e)
59
60 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
61 collectAnnValBinders expr = go [] expr
62   where
63     go bs (_, AnnLam b e) | isId b = go (b:bs) e
64     go bs e                        = (reverse bs, e)
65
66 isAnnTypeArg :: AnnExpr b ann -> Bool
67 isAnnTypeArg (_, AnnType t) = True
68 isAnnTypeArg _              = False
69
70 mkDataConTag :: DataCon -> CoreExpr
71 mkDataConTag dc = mkConApp intDataCon [mkIntLitInt $ dataConTag dc]
72
73 splitUnTy :: String -> Name -> Type -> Type
74 splitUnTy s name ty
75   | Just (tc, [ty']) <- splitTyConApp_maybe ty
76   , tyConName tc == name
77   = ty'
78
79   | otherwise = pprPanic s (ppr ty)
80
81 splitBinTy :: String -> Name -> Type -> (Type, Type)
82 splitBinTy s name ty
83   | Just (tc, [ty1, ty2]) <- splitTyConApp_maybe ty
84   , tyConName tc == name
85   = (ty1, ty2)
86
87   | otherwise = pprPanic s (ppr ty)
88
89 splitFixedTyConApp :: TyCon -> Type -> [Type]
90 splitFixedTyConApp tc ty
91   | Just (tc', tys) <- splitTyConApp_maybe ty
92   , tc == tc'
93   = tys
94
95   | otherwise = pprPanic "splitFixedTyConApp" (ppr tc <+> ppr ty)
96
97 splitEmbedTy :: Type -> Type
98 splitEmbedTy = splitUnTy "splitEmbedTy" embedTyConName
99
100 splitClosureTy :: Type -> (Type, Type)
101 splitClosureTy = splitBinTy "splitClosureTy" closureTyConName
102
103 splitPArrayTy :: Type -> Type
104 splitPArrayTy = splitUnTy "splitPArrayTy" parrayTyConName
105
106 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
107 mkBuiltinTyConApp get_tc tys
108   = do
109       tc <- builtin get_tc
110       return $ mkTyConApp tc tys
111
112 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
113 mkBuiltinTyConApps get_tc tys ty
114   = do
115       tc <- builtin get_tc
116       return $ foldr (mk tc) ty tys
117   where
118     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
119
120 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
121 mkBuiltinTyConApps1 get_tc dft [] = return dft
122 mkBuiltinTyConApps1 get_tc dft tys
123   = do
124       tc <- builtin get_tc
125       case tys of
126         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
127         _  -> return $ foldr1 (mk tc) tys
128   where
129     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
130
131 data TyConRepr = TyConRepr {
132                    repr_tyvars      :: [TyVar]
133                  , repr_tys         :: [[Type]]
134
135                  , repr_embed_tys   :: [[Type]]
136                  , repr_prod_tycons :: [Maybe TyCon]
137                  , repr_prod_tys    :: [Type]
138                  , repr_sum_tycon   :: Maybe TyCon
139                  , repr_type        :: Type
140                  }
141
142 mkTyConRepr :: TyCon -> VM TyConRepr
143 mkTyConRepr vect_tc
144   = do
145       embed_tys <- mapM (mapM mkEmbedType) rep_tys
146       prod_tycons <- mapM (mk_tycon prodTyCon) rep_tys
147       sum_tycon   <- mk_tycon sumTyCon rep_tys
148
149       let prod_tys = zipWith mk_tc_app_maybe prod_tycons embed_tys
150
151       return $ TyConRepr {
152                  repr_tyvars      = tyvars
153                , repr_tys         = rep_tys
154
155                , repr_embed_tys   = embed_tys
156                , repr_prod_tycons = prod_tycons
157                , repr_prod_tys    = prod_tys
158                , repr_sum_tycon   = sum_tycon
159                , repr_type        = mk_tc_app_maybe sum_tycon prod_tys
160                }
161   where
162     tyvars = tyConTyVars vect_tc
163     data_cons = tyConDataCons vect_tc
164     rep_tys   = map dataConRepArgTys data_cons
165
166     mk_tycon get_tc tys
167       | n > 1     = builtin (Just . get_tc n)
168       | otherwise = return Nothing
169       where n = length tys
170
171     mk_tc_app_maybe Nothing   []   = unitTy
172     mk_tc_app_maybe Nothing   [ty] = ty
173     mk_tc_app_maybe (Just tc) tys  = mkTyConApp tc tys
174
175 {-
176 mkPRepr :: [[Type]] -> VM Type
177 mkPRepr tys
178   = do
179       embed_tc <- builtin embedTyCon
180       sum_tcs  <- builtins sumTyCon
181       prod_tcs <- builtins prodTyCon
182
183       let mk_sum []   = unitTy
184           mk_sum [ty] = ty
185           mk_sum tys  = mkTyConApp (sum_tcs $ length tys) tys
186
187           mk_prod []   = unitTy
188           mk_prod [ty] = ty
189           mk_prod tys  = mkTyConApp (prod_tcs $ length tys) tys
190
191           mk_embed ty = mkTyConApp embed_tc [ty]
192
193       return . mk_sum
194              . map (mk_prod . map mk_embed)
195              $ tys
196 -}
197
198 mkToPRepr :: [[CoreExpr]] -> VM ([CoreExpr], Type)
199 mkToPRepr ess
200   = do
201       embed_tc <- builtin embedTyCon
202       embed_dc <- builtin embedDataCon
203       sum_tcs  <- builtins sumTyCon
204       prod_tcs <- builtins prodTyCon
205
206       let mk_sum [] = ([Var unitDataConId], unitTy)
207           mk_sum [(expr, ty)] = ([expr], ty)
208           mk_sum es = (zipWith mk_alt (tyConDataCons sum_tc) exprs,
209                        mkTyConApp sum_tc tys)
210             where
211               (exprs, tys)   = unzip es
212               sum_tc         = sum_tcs (length es)
213               mk_alt dc expr = mkConApp dc (map Type tys ++ [expr])
214
215           mk_prod [] = (Var unitDataConId, unitTy)
216           mk_prod [(expr, ty)] = (expr, ty)
217           mk_prod es = (mkConApp prod_dc (map Type tys ++ exprs),
218                         mkTyConApp prod_tc tys)
219             where
220               (exprs, tys) = unzip es
221               prod_tc      = prod_tcs (length es)
222               [prod_dc]    = tyConDataCons prod_tc
223
224           mk_embed expr = (mkConApp embed_dc [Type ty, expr],
225                            mkTyConApp embed_tc [ty])
226             where ty = exprType expr
227
228       return . mk_sum $ map (mk_prod . map mk_embed) ess
229
230 mkToArrPRepr :: CoreExpr -> CoreExpr -> [[CoreExpr]] -> VM CoreExpr
231 mkToArrPRepr len sel ess
232   = do
233       embed_tc <- builtin embedTyCon
234       (embed_rtc, _) <- parrayReprTyCon (mkTyConApp embed_tc [unitTy])
235       let [embed_rdc] = tyConDataCons embed_rtc
236
237       let mk_sum [(expr, ty)] = return (expr, ty)
238           mk_sum es
239             = do
240                 sum_tc <- builtin . sumTyCon $ length es
241                 (sum_rtc, _) <- parrayReprTyCon (mkTyConApp sum_tc tys)
242                 let [sum_rdc] = tyConDataCons sum_rtc
243
244                 return (mkConApp sum_rdc (map Type tys ++ (len : sel : exprs)),
245                         mkTyConApp sum_tc tys)
246             where
247               (exprs, tys) = unzip es
248
249           mk_prod [(expr, ty)] = return (expr, ty)
250           mk_prod es
251             = do
252                 prod_tc <- builtin . prodTyCon $ length es
253                 (prod_rtc, _) <- parrayReprTyCon (mkTyConApp prod_tc tys)
254                 let [prod_rdc] = tyConDataCons prod_rtc
255
256                 return (mkConApp prod_rdc (map Type tys ++ (len : exprs)),
257                         mkTyConApp prod_tc tys)
258             where
259               (exprs, tys) = unzip es
260
261           mk_embed expr = (mkConApp embed_rdc [Type ty, expr],
262                            mkTyConApp embed_tc [ty])
263             where ty = splitPArrayTy (exprType expr)
264
265       liftM fst (mk_sum =<< mapM (mk_prod . map mk_embed) ess)
266
267 mkFromPRepr :: CoreExpr -> Type -> [([Var], CoreExpr)] -> VM CoreExpr
268 mkFromPRepr scrut res_ty alts
269   = do
270       embed_dc <- builtin embedDataCon
271       sum_tcs  <- builtins sumTyCon
272       prod_tcs <- builtins prodTyCon
273
274       let un_sum expr ty [(vars, res)] = un_prod expr ty vars res
275           un_sum expr ty bs
276             = do
277                 ps     <- mapM (newLocalVar FSLIT("p")) tys
278                 bodies <- sequence
279                         $ zipWith4 un_prod (map Var ps) tys vars rs
280                 return . Case expr (mkWildId ty) res_ty
281                        $ zipWith3 mk_alt sum_dcs ps bodies
282             where
283               (vars, rs) = unzip bs
284               tys        = splitFixedTyConApp sum_tc ty
285               sum_tc     = sum_tcs $ length bs
286               sum_dcs    = tyConDataCons sum_tc
287
288               mk_alt dc p body = (DataAlt dc, [p], body)
289
290           un_prod expr ty []    r = return r
291           un_prod expr ty [var] r = return $ un_embed expr ty var r
292           un_prod expr ty vars  r
293             = do
294                 xs <- mapM (newLocalVar FSLIT("x")) tys
295                 let body = foldr (\(e,t,v) r -> un_embed e t v r) r
296                          $ zip3 (map Var xs) tys vars
297                 return $ Case expr (mkWildId ty) res_ty
298                          [(DataAlt prod_dc, xs, body)]
299             where
300               tys       = splitFixedTyConApp prod_tc ty
301               prod_tc   = prod_tcs $ length vars
302               [prod_dc] = tyConDataCons prod_tc
303
304           un_embed expr ty var r
305             = Case expr (mkWildId ty) res_ty
306                 [(DataAlt embed_dc, [var], r)]
307
308       un_sum scrut (exprType scrut) alts
309
310 mkFromArrPRepr :: CoreExpr -> Type -> Var -> Var -> [[Var]] -> CoreExpr
311                -> VM CoreExpr
312 mkFromArrPRepr scrut res_ty len sel vars res
313   = return (Var unitDataConId)
314
315 mkEmbedType :: Type -> VM Type
316 mkEmbedType ty = mkBuiltinTyConApp embedTyCon [ty]
317
318 mkClosureType :: Type -> Type -> VM Type
319 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
320
321 mkClosureTypes :: [Type] -> Type -> VM Type
322 mkClosureTypes = mkBuiltinTyConApps closureTyCon
323
324 mkPReprType :: Type -> VM Type
325 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
326
327 mkPADictType :: Type -> VM Type
328 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
329
330 mkPArrayType :: Type -> VM Type
331 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
332
333 parrayCoerce :: TyCon -> [Type] -> CoreExpr -> VM CoreExpr
334 parrayCoerce repr_tc args expr
335   | Just arg_co <- tyConFamilyCoercion_maybe repr_tc
336   = do
337       parray <- builtin parrayTyCon
338
339       let co = mkAppCoercion (mkTyConApp parray [])
340                              (mkSymCoercion (mkTyConApp arg_co args))
341
342       return $ mkCoerce co expr
343
344 parrayReprTyCon :: Type -> VM (TyCon, [Type])
345 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
346
347 parrayReprDataCon :: Type -> VM (DataCon, [Type])
348 parrayReprDataCon ty
349   = do
350       (tc, arg_tys) <- parrayReprTyCon ty
351       let [dc] = tyConDataCons tc
352       return (dc, arg_tys)
353
354 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
355 mkVScrut (ve, le)
356   = do
357       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
358       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
359
360 prDictOfType :: Type -> VM CoreExpr
361 prDictOfType orig_ty
362   | Just (tycon, ty_args) <- splitTyConApp_maybe orig_ty
363   = do
364       dfun <- traceMaybeV "prDictOfType" (ppr tycon) (lookupTyConPR tycon)
365       prDFunApply (Var dfun) ty_args
366
367 prDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
368 prDFunApply dfun tys
369   = do
370       args <- mapM mkDFunArg arg_tys
371       return $ mkApps mono_dfun args
372   where
373     mono_dfun    = mkTyApps dfun tys
374     (arg_tys, _) = splitFunTys (exprType mono_dfun)
375
376 mkDFunArg :: Type -> VM CoreExpr
377 mkDFunArg ty
378   | Just (tycon, [arg]) <- splitTyConApp_maybe ty
379
380   = let name = tyConName tycon
381
382         get_dict | name == paTyConName = paDictOfType
383                  | name == prTyConName = prDictOfType
384                  | otherwise           = pprPanic "mkDFunArg" (ppr ty)
385
386     in get_dict arg
387
388 mkDFunArg ty = pprPanic "mkDFunArg" (ppr ty)
389
390 prCoerce :: TyCon -> [Type] -> CoreExpr -> VM CoreExpr
391 prCoerce repr_tc args expr
392   | Just arg_co <- tyConFamilyCoercion_maybe repr_tc
393   = do
394       pr_tc <- builtin prTyCon
395
396       let co = mkAppCoercion (mkTyConApp pr_tc [])
397                              (mkSymCoercion (mkTyConApp arg_co args))
398
399       return $ mkCoerce co expr
400
401 paDictArgType :: TyVar -> VM (Maybe Type)
402 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
403   where
404     go ty k | Just k' <- kindView k = go ty k'
405     go ty (FunTy k1 k2)
406       = do
407           tv   <- newTyVar FSLIT("a") k1
408           mty1 <- go (TyVarTy tv) k1
409           case mty1 of
410             Just ty1 -> do
411                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
412                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
413             Nothing  -> go ty k2
414
415     go ty k
416       | isLiftedTypeKind k
417       = liftM Just (mkPADictType ty)
418
419     go ty k = return Nothing
420
421 paDictOfType :: Type -> VM CoreExpr
422 paDictOfType ty = paDictOfTyApp ty_fn ty_args
423   where
424     (ty_fn, ty_args) = splitAppTys ty
425
426 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
427 paDictOfTyApp ty_fn ty_args
428   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
429 paDictOfTyApp (TyVarTy tv) ty_args
430   = do
431       dfun <- maybeV (lookupTyVarPA tv)
432       paDFunApply dfun ty_args
433 paDictOfTyApp (TyConApp tc _) ty_args
434   = do
435       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
436       paDFunApply (Var dfun) ty_args
437 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
438
439 paDFunType :: TyCon -> VM Type
440 paDFunType tc
441   = do
442       margs <- mapM paDictArgType tvs
443       res   <- mkPADictType (mkTyConApp tc arg_tys)
444       return . mkForAllTys tvs
445              $ mkFunTys [arg | Just arg <- margs] res
446   where
447     tvs = tyConTyVars tc
448     arg_tys = mkTyVarTys tvs
449
450 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
451 paDFunApply dfun tys
452   = do
453       dicts <- mapM paDictOfType tys
454       return $ mkApps (mkTyApps dfun tys) dicts
455
456 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
457 paMethod method ty
458   = do
459       fn   <- builtin method
460       dict <- paDictOfType ty
461       return $ mkApps (Var fn) [Type ty, dict]
462
463 lengthPA :: CoreExpr -> VM CoreExpr
464 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
465   where
466     ty = splitPArrayTy (exprType x)
467
468 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
469 replicatePA len x = liftM (`mkApps` [len,x])
470                           (paMethod replicatePAVar (exprType x))
471
472 emptyPA :: Type -> VM CoreExpr
473 emptyPA = paMethod emptyPAVar
474
475 liftPA :: CoreExpr -> VM CoreExpr
476 liftPA x
477   = do
478       lc <- builtin liftingContext
479       replicatePA (Var lc) x
480
481 newLocalVVar :: FastString -> Type -> VM VVar
482 newLocalVVar fs vty
483   = do
484       lty <- mkPArrayType vty
485       vv  <- newLocalVar fs vty
486       lv  <- newLocalVar fs lty
487       return (vv,lv)
488
489 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
490 polyAbstract tvs p
491   = localV
492   $ do
493       mdicts <- mapM mk_dict_var tvs
494       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
495       p (mk_lams mdicts)
496   where
497     mk_dict_var tv = do
498                        r <- paDictArgType tv
499                        case r of
500                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
501                          Nothing -> return Nothing
502
503     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
504
505 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
506 polyApply expr tys
507   = do
508       dicts <- mapM paDictOfType tys
509       return $ expr `mkTyApps` tys `mkApps` dicts
510
511 polyVApply :: VExpr -> [Type] -> VM VExpr
512 polyVApply expr tys
513   = do
514       dicts <- mapM paDictOfType tys
515       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
516
517 hoistBinding :: Var -> CoreExpr -> VM ()
518 hoistBinding v e = updGEnv $ \env ->
519   env { global_bindings = (v,e) : global_bindings env }
520
521 hoistExpr :: FastString -> CoreExpr -> VM Var
522 hoistExpr fs expr
523   = do
524       var <- newLocalVar fs (exprType expr)
525       hoistBinding var expr
526       return var
527
528 hoistVExpr :: VExpr -> VM VVar
529 hoistVExpr (ve, le)
530   = do
531       fs <- getBindName
532       vv <- hoistExpr ('v' `consFS` fs) ve
533       lv <- hoistExpr ('l' `consFS` fs) le
534       return (vv, lv)
535
536 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
537 hoistPolyVExpr tvs p
538   = do
539       expr <- closedV . polyAbstract tvs $ \abstract ->
540               liftM (mapVect abstract) p
541       fn   <- hoistVExpr expr
542       polyVApply (vVar fn) (mkTyVarTys tvs)
543
544 takeHoisted :: VM [(Var, CoreExpr)]
545 takeHoisted
546   = do
547       env <- readGEnv id
548       setGEnv $ env { global_bindings = [] }
549       return $ global_bindings env
550
551 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
552 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
553   = do
554       dict <- paDictOfType env_ty
555       mkv  <- builtin mkClosureVar
556       mkl  <- builtin mkClosurePVar
557       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
558               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
559
560 mkClosureApp :: VExpr -> VExpr -> VM VExpr
561 mkClosureApp (vclo, lclo) (varg, larg)
562   = do
563       vapply <- builtin applyClosureVar
564       lapply <- builtin applyClosurePVar
565       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
566               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
567   where
568     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
569
570 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
571 buildClosures tvs vars [] res_ty mk_body
572   = mk_body
573 buildClosures tvs vars [arg_ty] res_ty mk_body
574   = buildClosure tvs vars arg_ty res_ty mk_body
575 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
576   = do
577       res_ty' <- mkClosureTypes arg_tys res_ty
578       arg <- newLocalVVar FSLIT("x") arg_ty
579       buildClosure tvs vars arg_ty res_ty'
580         . hoistPolyVExpr tvs
581         $ do
582             lc <- builtin liftingContext
583             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
584             return $ vLams lc (vars ++ [arg]) clo
585
586 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
587 --   where
588 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
589 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
590 --
591 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
592 buildClosure tvs vars arg_ty res_ty mk_body
593   = do
594       (env_ty, env, bind) <- buildEnv vars
595       env_bndr <- newLocalVVar FSLIT("env") env_ty
596       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
597
598       fn <- hoistPolyVExpr tvs
599           $ do
600               lc    <- builtin liftingContext
601               body  <- mk_body
602               body' <- bind (vVar env_bndr)
603                             (vVarApps lc body (vars ++ [arg_bndr]))
604               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
605
606       mkClosure arg_ty res_ty env_ty fn env
607
608 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
609 buildEnv vvs
610   = do
611       lc <- builtin liftingContext
612       let (ty, venv, vbind) = mkVectEnv tys vs
613       (lenv, lbind) <- mkLiftEnv lc tys ls
614       return (ty, (venv, lenv),
615               \(venv,lenv) (vbody,lbody) ->
616               do
617                 let vbody' = vbind venv vbody
618                 lbody' <- lbind lenv lbody
619                 return (vbody', lbody'))
620   where
621     (vs,ls) = unzip vvs
622     tys     = map idType vs
623
624 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
625 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
626 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
627 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
628                         \env body -> Case env (mkWildId ty) (exprType body)
629                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
630   where
631     ty = mkCoreTupTy tys
632
633 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
634 mkLiftEnv lc [ty] [v]
635   = return (Var v, \env body ->
636                    do
637                      len <- lengthPA (Var v)
638                      return . Let (NonRec v env)
639                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
640
641 -- NOTE: this transparently deals with empty environments
642 mkLiftEnv lc tys vs
643   = do
644       (env_tc, env_tyargs) <- parrayReprTyCon vty
645       let [env_con] = tyConDataCons env_tc
646           
647           env = Var (dataConWrapId env_con)
648                 `mkTyApps`  env_tyargs
649                 `mkVarApps` (lc : vs)
650
651           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
652                           in
653                           return $ Case scrut (mkWildId (exprType scrut))
654                                         (exprType body)
655                                         [(DataAlt env_con, lc : bndrs, body)]
656       return (env, bind)
657   where
658     vty = mkCoreTupTy tys
659
660     bndrs | null vs   = [mkWildId unitTy]
661           | otherwise = vs
662