flatten, flattenExpr,
) where
--- standard
-import Monad (liftM, foldM)
+#include "HsVersions.h"
+
+-- friends
+import NDPCoreUtils (tupleTyArgs, funTyArgs, parrElemTy, isDefault,
+ isLit, mkPArrTy, mkTuple, isSimpleExpr, substIdEnv)
+import FlattenMonad (Flatten, runFlatten, mkBind, extendContext, packContext,
+ liftVar, liftConst, intersectWithContext, mk'fst,
+ mk'lengthP, mk'replicateP, mk'mapP, mk'bpermuteDftP,
+ mk'indexOfP,mk'eq,mk'neq)
-- GHC
-import CmdLineOpts (opt_Flatten)
+import TcType ( tcIsForAllTy, tcView )
+import TypeRep ( Type(..) )
+import StaticFlags (opt_Flatten)
import Panic (panic)
import ErrUtils (dumpIfSet_dyn)
-import UniqSupply (UniqSupply, mkSplitUniqSupply)
-import CmdLineOpts (DynFlag(..), DynFlags)
+import UniqSupply (mkSplitUniqSupply)
+import DynFlags (DynFlag(..))
import Literal (Literal, literalType)
-import Var (Var(..),TyVar)
+import Var (Var(..), idType, isTyVar)
+import Id (setIdType)
import DataCon (DataCon, dataConTag)
-import TypeRep (Type(..))
-import Type (isTypeKind)
-import HscTypes (HomeSymbolTable, PersistentCompilerState, ModDetails(..))
+import HscTypes ( ModGuts(..), ModGuts, HscEnv(..), hscEPS )
import CoreFVs (exprFreeVars)
import CoreSyn (Expr(..), Bind(..), Alt(..), AltCon(..), Note(..),
- CoreBndr, CoreExpr, CoreBind, CoreAlt, mkLams, mkLets,
+ CoreBndr, CoreExpr, CoreBind, mkLams, mkLets,
mkApps, mkIntLitInt)
import PprCore (pprCoreExpr)
import CoreLint (showPass, endPass)
import CoreUtils (exprType, applyTypeToArg, mkPiType)
-import VarEnv (IdEnv, mkVarEnv, zipVarEnv, extendVarEnv)
+import VarEnv (zipVarEnv)
import TysWiredIn (mkTupleTy)
import BasicTypes (Boxity(..))
-import Outputable (showSDoc, Outputable(..))
+import Outputable
import FastString
--- friends
-import NDPCoreUtils (tupleTyArgs, funTyArgs, parrElemTy, isDefault,
- isLit, mkPArrTy, mkTuple, isSimpleExpr, boolTy, substIdEnv)
-import FlattenMonad (Flatten, runFlatten, mkBind, extendContext, packContext,
- liftVar, liftConst, intersectWithContext, mk'fst,
- mk'lengthP, mk'replicateP, mk'mapP, mk'bpermuteDftP,
- mk'indexOfP,mk'eq,mk'neq)
-- FIXME: fro debugging - remove this
-import IOExts (trace)
-
-
-#include "HsVersions.h"
+import TRACE (trace)
+-- standard
+import Monad (liftM, foldM)
-- toplevel transformation
-- -----------------------
-- entry point to the flattening transformation for the compiler driver when
-- compiling a complete module (EXPORTED)
--
-flatten :: DynFlags
- -> PersistentCompilerState
- -> HomeSymbolTable
- -> ModDetails -- the module to be flattened
- -> IO ModDetails
-flatten dflags pcs hst modDetails@(ModDetails {md_binds = binds})
- | not opt_Flatten = return modDetails -- skip without -fflatten
+flatten :: HscEnv
+ -> ModGuts
+ -> IO ModGuts
+flatten hsc_env mod_impl@(ModGuts {mg_binds = binds})
+ | not opt_Flatten = return mod_impl -- skip without -fflatten
| otherwise =
do
+ let dflags = hsc_dflags hsc_env
+
+ eps <- hscEPS hsc_env
us <- mkSplitUniqSupply 'l' -- 'l' as in fLattening
--
-- announce vectorisation
--
-- vectorise all toplevel bindings
--
- let binds' = runFlatten pcs hst us $ vectoriseTopLevelBinds binds
+ let binds' = runFlatten hsc_env eps us $ vectoriseTopLevelBinds binds
--
-- and dump the result if requested
--
endPass dflags "Flattening [first phase: vectorisation]"
Opt_D_dump_vect binds'
- return $ modDetails {md_binds = binds'}
+ return $ mod_impl {mg_binds = binds'}
-- entry point to the flattening transformation for the compiler driver when
-- compiling a single expression in interactive mode (EXPORTED)
--
-flattenExpr :: DynFlags
- -> PersistentCompilerState
- -> HomeSymbolTable
+flattenExpr :: HscEnv
-> CoreExpr -- the expression to be flattened
-> IO CoreExpr
-flattenExpr dflags pcs hst expr
+flattenExpr hsc_env expr
| not opt_Flatten = return expr -- skip without -fflatten
| otherwise =
do
+ let dflags = hsc_dflags hsc_env
+ eps <- hscEPS hsc_env
+
us <- mkSplitUniqSupply 'l' -- 'l' as in fLattening
--
-- announce vectorisation
--
-- vectorise the expression
--
- let expr' = fst . runFlatten pcs hst us $ vectorise expr
+ let expr' = fst . runFlatten hsc_env eps us $ vectorise expr
--
-- and dump the result if requested
--
vectoriseOne (b, expr) =
do
(vexpr, ty) <- vectorise expr
- return (b{varType = ty}, vexpr)
+ return (setIdType b ty, vexpr)
-- Searches for function definitions and creates a lifted version for
vectorise:: CoreExpr -> Flatten (CoreExpr, Type)
vectorise (Var id) =
do
- let varTy = varType id
+ let varTy = idType id
let vecTy = vectoriseTy varTy
- return ((Var id{varType = vecTy}), vecTy)
+ return (Var (setIdType id vecTy), vecTy)
vectorise (Lit lit) =
return ((Lit lit), literalType lit)
do
(varg, argTy) <- vectorise arg
(vexpr, vexprTy) <- vectorise expr
- let vb = b{varType = argTy}
+ let vb = setIdType b argTy
return ((App (Lam vb vexpr) varg),
applyTypeToArg (mkPiType vb vexprTy) varg)
(vexpr, vexprTy) <- vectorise expr
(varg, vargTy) <- vectorise arg
- if (isPolyType vexprTy)
+ if (tcIsForAllTy vexprTy)
then do
let resTy = applyTypeToArg vexprTy varg
return (App vexpr varg, resTy)
let resTy = applyTypeToArg t1 varg
return ((App vexpr' varg), resTy) -- apply the first component of
-- the vectorized function
- where
- isPolyType t =
- (case t of
- (ForAllTy _ _) -> True
- (NoteTy _ nt) -> isPolyType nt
- _ -> False)
-
vectorise e@(Lam b expr)
- | isTypeKind (varType b) =
- do
+ | isTyVar b
+ = do
(vexpr, vexprTy) <- vectorise expr -- don't vectorise 'b'!
return ((Lam b vexpr), mkPiType b vexprTy)
| otherwise =
do
(vexpr, vexprTy) <- vectorise expr
- let vb = b{varType = vectoriseTy (varType b)}
+ let vb = setIdType b (vectoriseTy (idType b))
let ve = Lam vb vexpr
(lexpr, lexprTy) <- lift e
let veTy = mkPiType vb vexprTy
(vbody, vbodyTy) <- vectorise body
return ((Let vbind vbody), vbodyTy)
-vectorise (Case expr b alts) =
+vectorise (Case expr b ty alts) =
do
(vexpr, vexprTy) <- vectorise expr
valts <- mapM vectorise' alts
- return (Case vexpr b{varType = vexprTy} (map fst valts), snd (head valts))
+ let res_ty = snd (head valts)
+ return (Case vexpr (setIdType b vexprTy) res_ty (map fst valts), res_ty)
where vectorise' (con, bs, expr) =
do
(vexpr, vexprTy) <- vectorise expr
-}
vectoriseTy :: Type -> Type
+vectoriseTy ty | Just ty' <- tcView ty = vectoriseTy ty'
+ -- Look through notes and synonyms
+ -- NB: This will discard notes and synonyms, of course
+ -- ToDo: retain somehow?
vectoriseTy t@(TyVarTy v) = t
vectoriseTy t@(AppTy t1 t2) =
AppTy (vectoriseTy t1) (vectoriseTy t2)
(liftTy t)]
vectoriseTy t@(ForAllTy v ty) =
ForAllTy v (vectoriseTy ty)
-vectoriseTy t@(NoteTy note ty) = -- FIXME: is the note still valid after
- NoteTy note (vectoriseTy ty) -- this or should we just throw it away
vectoriseTy t = t
-- on the *top level* (is this sufficient???)
liftTy:: Type -> Type
+liftTy ty | Just ty' <- tcView ty = liftTy ty'
liftTy (FunTy t1 t2) = FunTy (liftTy t1) (liftTy t2)
liftTy (ForAllTy tv t) = ForAllTy tv (liftTy t)
-liftTy (NoteTy n t) = NoteTy n $ liftTy t
liftTy t = mkPArrTy t
-- lift type, don't change name (incl unique) nor IdInfo. IdInfo looks ok,
-- but I'm not entirely sure about some fields (e.g., strictness info)
liftBinderType:: CoreBndr -> Flatten CoreBndr
-liftBinderType bndr = return $ bndr {varType = liftTy (varType bndr)}
+liftBinderType bndr = return $ setIdType bndr (liftTy (idType bndr))
-- lift: lifts an expression (a -> [:a:])
-- If the expression is a simple expression, it is treated like a constant
lift cExpr@(Var id) =
do
lVar@(Var lId) <- liftVar id
- return (lVar, varType lId)
+ return (lVar, idType lId)
lift cExpr@(Lit lit) =
do
lift (Lam b expr)
| isSimpleExpr expr = liftSimpleFun b expr
- | isTypeKind (varType b) =
+ | isTyVar b =
do
(lexpr, lexprTy) <- lift expr -- don't lift b!
return (Lam b lexpr, mkPiType b lexprTy)
-- otherwise (a) compute index vector for simpleAlts (for def permute
-- later on
-- (b)
-lift cExpr@(Case expr b alts) =
+-- gaw 2004 FIX?
+lift cExpr@(Case expr b _ alts) =
do
(lExpr, _) <- lift expr
lb <- liftBinderType b -- lift alt-expression
liftSingleDataCon b dcon bnds expr =
do
let dconId = dataConTag dcon
- indexExpr <- mkIndexOfExprDCon (varType b) b dconId
+ indexExpr <- mkIndexOfExprDCon (idType b) b dconId
(bb, bbind) <- mkBind FSLIT("is") indexExpr
lbnds <- mapM liftBinderType bnds
((lExpr, _), bnds') <- packContext bb (extendContext lbnds (lift expr))
liftCaseDataConDefault b (_, _, def) alts =
do
let dconIds = map (\(DataAlt d, _, _) -> dataConTag d) alts
- indexExpr <- mkIndexOfExprDConDft (varType b) b dconIds
+ indexExpr <- mkIndexOfExprDConDft (idType b) b dconIds
(bb, bbind) <- mkBind FSLIT("is") indexExpr
((lDef, _), bnds) <- packContext bb (lift def)
(_, vbind) <- mkBind FSLIT("r") lDef
liftCaseLitDefault b (_, _, def) alts =
do
let lits = map (\(LitAlt l, _, _) -> l) alts
- indexExpr <- mkIndexOfExprDft (varType b) b lits
+ indexExpr <- mkIndexOfExprDft (idType b) b lits
(bb, bbind) <- mkBind FSLIT("is") indexExpr
((lDef, _), bnds) <- packContext bb (lift def)
(_, vbind) <- mkBind FSLIT("r") lDef
Flatten (CoreBind, CoreBind, [CoreBind])
liftSingleCaseLit b lit expr =
do
- indexExpr <- mkIndexOfExpr (varType b) b lit -- (a)
+ indexExpr <- mkIndexOfExpr (idType b) b lit -- (a)
(bb, bbind) <- mkBind FSLIT("is") indexExpr
((lExpr, t), bnds) <- packContext bb (lift expr) -- (b)
(_, vbind) <- mkBind FSLIT("r") lExpr
let iVar = getVarOfBind i
let eVar = getVarOfBind e
let cVar = getVarOfBind cBind
- let ty = varType eVar
+ let ty = idType eVar
newBnd <- mkDftBackpermute ty iVar eVar cVar
((fBnd, restBnds), _) <- dftbpBinders' is es newBnd
return ((fBnd, (newBnd:restBnds)), liftTy ty)
dftbpBinders' _ _ _ =
- panic "Flattening.dftbpBinders: index and expression binder lists \
- \have different length!"
+ panic "Flattening.dftbpBinders: index and expression binder lists have different length!"
getExprOfBind:: CoreBind -> CoreExpr
getExprOfBind (NonRec _ expr) = expr
do
bndVars <- collectBoundVars expr
let bndVars' = b:bndVars
- bndVarsTuple = mkTuple (map varType bndVars') (map Var bndVars')
+ bndVarsTuple = mkTuple (map idType bndVars') (map Var bndVars')
lamExpr = mkLams (b:bndVars) expr -- FIXME: should be tuple
-- here
let (t1, t2) = funTyArgs . exprType $ lamExpr
-- indexOf (mapP (\x -> x == lit) b) b
--
mkIndexOfExpr:: Type -> CoreBndr -> Literal -> Flatten CoreExpr
-mkIndexOfExpr varType b lit =
+mkIndexOfExpr idType b lit =
do
- eqExpr <- mk'eq varType (Var b) (Lit lit)
+ eqExpr <- mk'eq idType (Var b) (Lit lit)
let lambdaExpr = (Lam b eqExpr)
- mk'indexOfP varType lambdaExpr (Var b)
+ mk'indexOfP idType lambdaExpr (Var b)
-- there is FlattenMonad.mk'indexOfP as well as
-- CoreSyn.mkApps and CoreSyn.mkLam, all of which should help here
-- indexOfP (\x -> x == dconId) b)
--
mkIndexOfExprDCon::Type -> CoreBndr -> Int -> Flatten CoreExpr
-mkIndexOfExprDCon varType b dId =
+mkIndexOfExprDCon idType b dId =
do
let intExpr = mkIntLitInt dId
- eqExpr <- mk'eq varType (Var b) intExpr
+ eqExpr <- mk'eq idType (Var b) intExpr
let lambdaExpr = (Lam b intExpr)
- mk'indexOfP varType lambdaExpr (Var b)
+ mk'indexOfP idType lambdaExpr (Var b)
-- indexOfP (\x -> x != dconId_1 && ....) b)
--
mkIndexOfExprDConDft:: Type -> CoreBndr -> [Int] -> Flatten CoreExpr
-mkIndexOfExprDConDft varType b dId =
+mkIndexOfExprDConDft idType b dId =
do
let intExprs = map mkIntLitInt dId
- bExpr <- foldM (mk'neq varType) (head intExprs) (tail intExprs)
+ bExpr <- foldM (mk'neq idType) (head intExprs) (tail intExprs)
let lambdaExpr = (Lam b bExpr)
- mk'indexOfP varType (Var b) bExpr
+ mk'indexOfP idType (Var b) bExpr
-- mkIndexOfExprDef b [lit1, lit2,...] ->
-- indexOf (\x -> not (x == lit1 || x == lit2 ....) b
mkIndexOfExprDft:: Type -> CoreBndr -> [Literal] -> Flatten CoreExpr
-mkIndexOfExprDft varType b lits =
+mkIndexOfExprDft idType b lits =
do
let litExprs = map (\l-> Lit l) lits
- bExpr <- foldM (mk'neq varType) (head litExprs) (tail litExprs)
+ bExpr <- foldM (mk'neq idType) (head litExprs) (tail litExprs)
let lambdaExpr = (Lam b bExpr)
- mk'indexOfP varType bExpr (Var b)
+ mk'indexOfP idType bExpr (Var b)
-- create a back-permute binder
--
--- * `mkDftBackpermute ty indexArrayVar srcArrayVar dftArrayVar' creates a
+-- * `mkDftBackpermute ty indexArrayVar srcArrayVar dftArrayVar' creates a
-- Core binding of the form
--
-- x = bpermuteDftP indexArrayVar srcArrayVar dftArrayVar
where showBinds (NonRec b e) = showBind (b,e)
showBinds (Rec bnds) = concat (map showBind bnds)
showBind (b,e) = " b = " ++ (showCoreExpr e)++ "\n"
-showCoreExpr (Case ex b alts) =
+-- gaw 2004 FIX?
+showCoreExpr (Case ex b ty alts) =
"Case b = " ++ (showCoreExpr ex) ++ " of \n" ++ (showAlts alts)
where showAlts _ = ""
showCoreExpr (Note _ ex) = "Note n " ++ (showCoreExpr ex)