[project @ 2003-07-02 14:59:00 by simonpj]
[ghc-hetmet.git] / ghc / compiler / ndpFlatten / Flattening.hs
index 4733bc4..4f0f86b 100644 (file)
@@ -52,52 +52,49 @@ module Flattening (
   flatten, flattenExpr, 
 ) where 
 
--- standard
-import Monad        (liftM, foldM)
+#include "HsVersions.h"
+
+-- friends
+import NDPCoreUtils (tupleTyArgs, funTyArgs, parrElemTy, isDefault,
+                    isLit, mkPArrTy, mkTuple, isSimpleExpr, substIdEnv)
+import FlattenMonad (Flatten, runFlatten, mkBind, extendContext, packContext,
+                    liftVar, liftConst, intersectWithContext, mk'fst,
+                    mk'lengthP, mk'replicateP, mk'mapP, mk'bpermuteDftP,
+                    mk'indexOfP,mk'eq,mk'neq) 
 
 -- GHC
 import CmdLineOpts  (opt_Flatten)
 import Panic        (panic)
 import ErrUtils     (dumpIfSet_dyn)
-import UniqSupply   (UniqSupply, mkSplitUniqSupply)
-import CmdLineOpts  (DynFlag(..), DynFlags)
+import UniqSupply   (mkSplitUniqSupply)
+import CmdLineOpts  (DynFlag(..))
 import Literal      (Literal, literalType)
-import Var         (Var(..),TyVar)
+import Var         (Var(..))
 import DataCon     (DataCon, dataConTag)
 import TypeRep      (Type(..))
 import Type         (isTypeKind)
-import HscTypes            (HomeSymbolTable, PersistentCompilerState, ModDetails(..))
+import HscTypes            (PersistentCompilerState, ModGuts(..), 
+                    ModGuts, HscEnv(..) )
 import CoreFVs     (exprFreeVars)
 import CoreSyn     (Expr(..), Bind(..), Alt(..), AltCon(..), Note(..),
-                    CoreBndr, CoreExpr, CoreBind, CoreAlt, mkLams, mkLets,
+                    CoreBndr, CoreExpr, CoreBind, mkLams, mkLets,
                     mkApps, mkIntLitInt)  
 import PprCore      (pprCoreExpr)
 import CoreLint            (showPass, endPass)
 
 import CoreUtils    (exprType, applyTypeToArg, mkPiType)
-import VarEnv       (IdEnv, mkVarEnv, zipVarEnv, extendVarEnv)
+import VarEnv       (zipVarEnv)
 import TysWiredIn   (mkTupleTy)
 import BasicTypes   (Boxity(..))
-import Outputable   (showSDoc, Outputable(..))
-
+import Outputable
+import FastString
 
--- friends
-import NDPCoreUtils (tupleTyArgs, funTyArgs, parrElemTy, isDefault,
-                    isLit, mkPArrTy, mkTuple, isSimpleExpr, boolTy, substIdEnv)
-import FlattenMonad (Flatten, runFlatten, mkBind, extendContext, packContext,
-                    liftVar, liftConst, intersectWithContext, mk'fst,
-                    mk'lengthP, mk'replicateP, mk'mapP, mk'bpermuteDftP,
-                    mk'indexOfP,mk'eq,mk'neq) 
 
 -- FIXME: fro debugging - remove this
-import IOExts    (trace)
-
-
-#include "HsVersions.h"
-{-# INLINE slit #-}
-slit x = FastString.mkFastCharString# x
--- FIXME: SLIT() doesn't work for some strange reason
+import TRACE    (trace)
 
+-- standard
+import Monad        (liftM, foldM)
 
 -- toplevel transformation
 -- -----------------------
@@ -105,15 +102,16 @@ slit x = FastString.mkFastCharString# x
 -- entry point to the flattening transformation for the compiler driver when
 -- compiling a complete module (EXPORTED) 
 --
-flatten :: DynFlags 
+flatten :: HscEnv
        -> PersistentCompilerState 
-       -> HomeSymbolTable
-       -> ModDetails                   -- the module to be flattened
-       -> IO ModDetails
-flatten dflags pcs hst modDetails@(ModDetails {md_binds = binds}) 
-  | not opt_Flatten = return modDetails -- skip without -fflatten
+       -> ModGuts
+       -> IO ModGuts
+flatten hsc_env pcs mod_impl@(ModGuts {mg_binds = binds}) 
+  | not opt_Flatten = return mod_impl -- skip without -fflatten
   | otherwise       =
   do
+    let dflags = hsc_dflags hsc_env
+
     us <- mkSplitUniqSupply 'l'                -- 'l' as in fLattening
     --
     -- announce vectorisation
@@ -122,26 +120,27 @@ flatten dflags pcs hst modDetails@(ModDetails {md_binds = binds})
     --
     -- vectorise all toplevel bindings
     --
-    let binds' = runFlatten pcs hst us $ vectoriseTopLevelBinds binds
+    let binds' = runFlatten hsc_env pcs us $ vectoriseTopLevelBinds binds
     --
     -- and dump the result if requested
     --
     endPass dflags "Flattening [first phase: vectorisation]" 
            Opt_D_dump_vect binds'
-    return $ modDetails {md_binds = binds'}
+    return $ mod_impl {mg_binds = binds'}
 
 -- entry point to the flattening transformation for the compiler driver when
 -- compiling a single expression in interactive mode (EXPORTED) 
 --
-flattenExpr :: DynFlags 
+flattenExpr :: HscEnv
            -> PersistentCompilerState 
-           -> HomeSymbolTable 
            -> CoreExpr                 -- the expression to be flattened
            -> IO CoreExpr
-flattenExpr dflags pcs hst expr
+flattenExpr hsc_env pcs expr
   | not opt_Flatten = return expr       -- skip without -fflatten
   | otherwise       =
   do
+    let dflags = hsc_dflags hsc_env
+
     us <- mkSplitUniqSupply 'l'                -- 'l' as in fLattening
     --
     -- announce vectorisation
@@ -150,7 +149,7 @@ flattenExpr dflags pcs hst expr
     --
     -- vectorise the expression
     --
-    let expr' = fst . runFlatten pcs hst us $ vectorise expr
+    let expr' = fst . runFlatten hsc_env pcs us $ vectorise expr
     --
     -- and dump the result if requested
     --
@@ -505,10 +504,10 @@ liftSingleDataCon b dcon bnds expr =
   do 
     let dconId           = dataConTag dcon
     indexExpr           <- mkIndexOfExprDCon (varType b)  b dconId
-    (b', bbind)         <- mkBind (slit "is"#) indexExpr
+    (bb, bbind)         <- mkBind FSLIT("is") indexExpr
     lbnds               <- mapM liftBinderType bnds
-    ((lExpr, _), bnds') <- packContext  b' (extendContext lbnds (lift expr))
-    (_, vbind)          <- mkBind (slit "r"#) lExpr
+    ((lExpr, _), bnds') <- packContext  bb (extendContext lbnds (lift expr))
+    (_, vbind)          <- mkBind FSLIT("r") lExpr
     return (bbind, vbind, bnds')
 
 -- FIXME: clean this up. the datacon and the literal case are so
@@ -521,9 +520,9 @@ liftCaseDataConDefault b (_, _, def) alts =
   do
     let dconIds        = map (\(DataAlt d, _, _) -> dataConTag d) alts
     indexExpr         <- mkIndexOfExprDConDft (varType b) b dconIds
-    (b', bbind)       <- mkBind (slit "is"#) indexExpr
-    ((lDef, _), bnds) <- packContext  b' (lift def)     
-    (_, vbind)        <- mkBind (slit "r"#) lDef
+    (bb, bbind)       <- mkBind FSLIT("is") indexExpr
+    ((lDef, _), bnds) <- packContext  bb (lift def)     
+    (_, vbind)        <- mkBind FSLIT("r") lDef
     return (bbind, vbind, bnds)
 
 -- liftCaseLit: checks if we have a default case and handles it 
@@ -552,9 +551,9 @@ liftCaseLitDefault b (_, _, def) alts =
   do
     let lits           = map (\(LitAlt l, _, _) -> l) alts
     indexExpr         <- mkIndexOfExprDft (varType b) b lits
-    (b', bbind)       <- mkBind (slit "is"#) indexExpr
-    ((lDef, _), bnds) <- packContext  b' (lift def)     
-    (_, vbind)        <- mkBind (slit "r"#) lDef
+    (bb, bbind)       <- mkBind FSLIT("is") indexExpr
+    ((lDef, _), bnds) <- packContext  bb (lift def)     
+    (_, vbind)        <- mkBind FSLIT("r") lDef
     return (bbind, vbind, bnds)
 
 -- FIXME: 
@@ -591,9 +590,9 @@ liftSingleCaseLit:: CoreBndr -> Literal -> CoreExpr  ->
 liftSingleCaseLit b lit expr =
  do 
    indexExpr          <- mkIndexOfExpr (varType b) b lit -- (a)
-   (b', bbind)        <- mkBind (slit "is"#) indexExpr
-   ((lExpr, t), bnds) <- packContext  b' (lift expr)     -- (b)         
-   (_, vbind)         <- mkBind (slit "r"#) lExpr
+   (bb, bbind)        <- mkBind FSLIT("is") indexExpr
+   ((lExpr, t), bnds) <- packContext  bb (lift expr)     -- (b)         
+   (_, vbind)         <- mkBind FSLIT("r") lExpr
    return (bbind, vbind, bnds)
 
 -- letWrapper lExpr b ([indexbnd_i], [exprbnd_i], [pckbnd_ij])
@@ -767,7 +766,7 @@ mkDftBackpermute :: Type -> Var -> Var -> Var -> Flatten CoreBind
 mkDftBackpermute ty idx src dft = 
   do
     rhs <- mk'bpermuteDftP ty (Var idx) (Var src) (Var dft)
-    liftM snd $ mkBind (slit "dbp"#) rhs
+    liftM snd $ mkBind FSLIT("dbp") rhs
 
 -- create a dummy array with elements of the given type, which can be used as
 -- default array for the combination of the subresults of the lifted case
@@ -781,7 +780,7 @@ createDftArrayBind e  =
     let ty = parrElemTy . exprType $ expr
     len <- mk'lengthP e
     rhs <- mk'replicateP ty len err??
-    lift snd $ mkBind (slit "dft"#) rhs
+    lift snd $ mkBind FSLIT("dft") rhs
 FIXME: nicht so einfach; man kann kein "error"-Wert nehmen, denn der w"urde
   beim bpermuteDftP sofort evaluiert, aber es ist auch schwer m"oglich einen
   generischen Wert f"ur jeden beliebigen Typ zu erfinden.
@@ -809,4 +808,4 @@ showCoreExpr (Case ex b alts) =
   "Case b = " ++ (showCoreExpr ex) ++ " of \n" ++ (showAlts alts)
   where showAlts _ = ""  
 showCoreExpr (Note _ ex) = "Note n " ++ (showCoreExpr ex)
-showCoreExpr (Type t) = "Type"
\ No newline at end of file
+showCoreExpr (Type t) = "Type"