remove empty dir
[ghc-hetmet.git] / ghc / compiler / ndpFlatten / Flattening.hs
index 4733bc4..18daaa6 100644 (file)
@@ -52,52 +52,49 @@ module Flattening (
   flatten, flattenExpr, 
 ) where 
 
--- standard
-import Monad        (liftM, foldM)
+#include "HsVersions.h"
+
+-- friends
+import NDPCoreUtils (tupleTyArgs, funTyArgs, parrElemTy, isDefault,
+                    isLit, mkPArrTy, mkTuple, isSimpleExpr, substIdEnv)
+import FlattenMonad (Flatten, runFlatten, mkBind, extendContext, packContext,
+                    liftVar, liftConst, intersectWithContext, mk'fst,
+                    mk'lengthP, mk'replicateP, mk'mapP, mk'bpermuteDftP,
+                    mk'indexOfP,mk'eq,mk'neq) 
 
 -- GHC
-import CmdLineOpts  (opt_Flatten)
+import TcType      ( tcIsForAllTy, tcView )
+import TypeRep     ( Type(..) )
+import StaticFlags  (opt_Flatten)
 import Panic        (panic)
 import ErrUtils     (dumpIfSet_dyn)
-import UniqSupply   (UniqSupply, mkSplitUniqSupply)
-import CmdLineOpts  (DynFlag(..), DynFlags)
+import UniqSupply   (mkSplitUniqSupply)
+import DynFlags  (DynFlag(..))
 import Literal      (Literal, literalType)
-import Var         (Var(..),TyVar)
+import Var         (Var(..), idType, isTyVar)
+import Id          (setIdType)
 import DataCon     (DataCon, dataConTag)
-import TypeRep      (Type(..))
-import Type         (isTypeKind)
-import HscTypes            (HomeSymbolTable, PersistentCompilerState, ModDetails(..))
+import HscTypes            ( ModGuts(..), ModGuts, HscEnv(..), hscEPS )
 import CoreFVs     (exprFreeVars)
 import CoreSyn     (Expr(..), Bind(..), Alt(..), AltCon(..), Note(..),
-                    CoreBndr, CoreExpr, CoreBind, CoreAlt, mkLams, mkLets,
+                    CoreBndr, CoreExpr, CoreBind, mkLams, mkLets,
                     mkApps, mkIntLitInt)  
 import PprCore      (pprCoreExpr)
 import CoreLint            (showPass, endPass)
 
 import CoreUtils    (exprType, applyTypeToArg, mkPiType)
-import VarEnv       (IdEnv, mkVarEnv, zipVarEnv, extendVarEnv)
+import VarEnv       (zipVarEnv)
 import TysWiredIn   (mkTupleTy)
 import BasicTypes   (Boxity(..))
-import Outputable   (showSDoc, Outputable(..))
+import Outputable
+import FastString
 
 
--- friends
-import NDPCoreUtils (tupleTyArgs, funTyArgs, parrElemTy, isDefault,
-                    isLit, mkPArrTy, mkTuple, isSimpleExpr, boolTy, substIdEnv)
-import FlattenMonad (Flatten, runFlatten, mkBind, extendContext, packContext,
-                    liftVar, liftConst, intersectWithContext, mk'fst,
-                    mk'lengthP, mk'replicateP, mk'mapP, mk'bpermuteDftP,
-                    mk'indexOfP,mk'eq,mk'neq) 
-
 -- FIXME: fro debugging - remove this
-import IOExts    (trace)
-
-
-#include "HsVersions.h"
-{-# INLINE slit #-}
-slit x = FastString.mkFastCharString# x
--- FIXME: SLIT() doesn't work for some strange reason
+import TRACE    (trace)
 
+-- standard
+import Monad        (liftM, foldM)
 
 -- toplevel transformation
 -- -----------------------
@@ -105,15 +102,16 @@ slit x = FastString.mkFastCharString# x
 -- entry point to the flattening transformation for the compiler driver when
 -- compiling a complete module (EXPORTED) 
 --
-flatten :: DynFlags 
-       -> PersistentCompilerState 
-       -> HomeSymbolTable
-       -> ModDetails                   -- the module to be flattened
-       -> IO ModDetails
-flatten dflags pcs hst modDetails@(ModDetails {md_binds = binds}) 
-  | not opt_Flatten = return modDetails -- skip without -fflatten
+flatten :: HscEnv
+       -> ModGuts
+       -> IO ModGuts
+flatten hsc_env mod_impl@(ModGuts {mg_binds = binds}) 
+  | not opt_Flatten = return mod_impl -- skip without -fflatten
   | otherwise       =
   do
+    let dflags = hsc_dflags hsc_env
+
+    eps <- hscEPS hsc_env
     us <- mkSplitUniqSupply 'l'                -- 'l' as in fLattening
     --
     -- announce vectorisation
@@ -122,26 +120,27 @@ flatten dflags pcs hst modDetails@(ModDetails {md_binds = binds})
     --
     -- vectorise all toplevel bindings
     --
-    let binds' = runFlatten pcs hst us $ vectoriseTopLevelBinds binds
+    let binds' = runFlatten hsc_env eps us $ vectoriseTopLevelBinds binds
     --
     -- and dump the result if requested
     --
     endPass dflags "Flattening [first phase: vectorisation]" 
            Opt_D_dump_vect binds'
-    return $ modDetails {md_binds = binds'}
+    return $ mod_impl {mg_binds = binds'}
 
 -- entry point to the flattening transformation for the compiler driver when
 -- compiling a single expression in interactive mode (EXPORTED) 
 --
-flattenExpr :: DynFlags 
-           -> PersistentCompilerState 
-           -> HomeSymbolTable 
+flattenExpr :: HscEnv
            -> CoreExpr                 -- the expression to be flattened
            -> IO CoreExpr
-flattenExpr dflags pcs hst expr
+flattenExpr hsc_env expr
   | not opt_Flatten = return expr       -- skip without -fflatten
   | otherwise       =
   do
+    let dflags = hsc_dflags hsc_env
+    eps <- hscEPS hsc_env
+
     us <- mkSplitUniqSupply 'l'                -- 'l' as in fLattening
     --
     -- announce vectorisation
@@ -150,7 +149,7 @@ flattenExpr dflags pcs hst expr
     --
     -- vectorise the expression
     --
-    let expr' = fst . runFlatten pcs hst us $ vectorise expr
+    let expr' = fst . runFlatten hsc_env eps us $ vectorise expr
     --
     -- and dump the result if requested
     --
@@ -194,7 +193,7 @@ vectoriseBind (Rec bindings)   =
     vectoriseOne (b, expr) = 
       do
        (vexpr, ty) <- vectorise expr
-       return (b{varType = ty}, vexpr)
+       return (setIdType b ty, vexpr)
 
 
 -- Searches for function definitions and creates a lifted version for 
@@ -219,9 +218,9 @@ vectoriseBind (Rec bindings)   =
 vectorise:: CoreExpr -> Flatten (CoreExpr, Type)
 vectorise (Var id)  =  
   do 
-    let varTy  = varType id
+    let varTy  = idType id
     let vecTy  = vectoriseTy varTy
-    return ((Var id{varType = vecTy}), vecTy)
+    return (Var (setIdType id vecTy), vecTy)
 
 vectorise (Lit lit) =  
   return ((Lit lit), literalType lit) 
@@ -236,7 +235,7 @@ vectorise  (App (Lam b expr) arg) =
   do
     (varg, argTy)    <- vectorise arg
     (vexpr, vexprTy) <- vectorise expr
-    let vb            = b{varType = argTy} 
+    let vb            = setIdType b argTy
     return ((App (Lam vb  vexpr) varg), 
             applyTypeToArg (mkPiType vb vexprTy) varg)
 
@@ -248,7 +247,7 @@ vectorise (App expr arg) =
     (vexpr, vexprTy) <-  vectorise expr
     (varg,  vargTy)  <-  vectorise arg
 
-    if (isPolyType vexprTy)
+    if (tcIsForAllTy vexprTy)
       then do
         let resTy =  applyTypeToArg vexprTy varg
         return (App vexpr varg, resTy)
@@ -258,23 +257,16 @@ vectorise (App expr arg) =
         let resTy    = applyTypeToArg t1 varg   
         return  ((App vexpr' varg), resTy)  -- apply the first component of
                                             -- the vectorized function
-  where
-    isPolyType t =  
-        (case t  of
-           (ForAllTy _ _)  -> True
-           (NoteTy _ nt)   -> isPolyType nt
-           _               -> False)
-    
 
 vectorise  e@(Lam b expr)
-  | isTypeKind (varType b) = 
-      do
+  | isTyVar b
+  =  do
         (vexpr, vexprTy) <- vectorise expr          -- don't vectorise 'b'!
         return ((Lam b vexpr), mkPiType b vexprTy)
   | otherwise =
      do          
        (vexpr, vexprTy)  <- vectorise expr
-       let vb             = b{varType = vectoriseTy (varType b)}
+       let vb             = setIdType b (vectoriseTy (idType b))
        let ve             =  Lam  vb  vexpr 
        (lexpr, lexprTy)  <- lift e
        let veTy = mkPiType vb vexprTy  
@@ -287,11 +279,12 @@ vectorise (Let bind body) =
     (vbody, vbodyTy) <- vectorise body
     return ((Let vbind vbody), vbodyTy)
 
-vectorise (Case expr b alts) =
+vectorise (Case expr b ty alts) =
   do 
     (vexpr, vexprTy) <- vectorise expr
     valts <- mapM vectorise' alts
-    return (Case vexpr b{varType = vexprTy} (map fst valts), snd (head valts))
+    let res_ty = snd (head valts)
+    return (Case vexpr (setIdType b vexprTy) res_ty (map fst valts), res_ty)
   where vectorise' (con, bs, expr) = 
           do 
             (vexpr, vexprTy) <- vectorise expr
@@ -318,6 +311,10 @@ myShowTy (TyConApp _ t) =
 -}
 
 vectoriseTy :: Type -> Type 
+vectoriseTy ty | Just ty' <- tcView ty = vectoriseTy ty'
+       -- Look through notes and synonyms
+       -- NB: This will discard notes and synonyms, of course
+       -- ToDo: retain somehow?
 vectoriseTy t@(TyVarTy v)      =  t
 vectoriseTy t@(AppTy t1 t2)    = 
   AppTy (vectoriseTy t1) (vectoriseTy t2)
@@ -328,8 +325,6 @@ vectoriseTy t@(FunTy t1 t2)    =
                      (liftTy t)]
 vectoriseTy  t@(ForAllTy v ty)  = 
   ForAllTy v (vectoriseTy  ty)
-vectoriseTy t@(NoteTy note ty) =  -- FIXME: is the note still valid after
-  NoteTy note  (vectoriseTy ty)   --   this or should we just throw it away
 vectoriseTy  t =  t
 
 
@@ -337,9 +332,9 @@ vectoriseTy  t =  t
 --    on the *top level* (is this sufficient???)
 
 liftTy:: Type -> Type
+liftTy ty | Just ty' <- tcView ty = liftTy ty'
 liftTy (FunTy t1 t2)   = FunTy (liftTy t1) (liftTy t2)
 liftTy (ForAllTy tv t) = ForAllTy tv (liftTy t)
-liftTy (NoteTy n t)    = NoteTy n $ liftTy t
 liftTy  t              = mkPArrTy t
 
 
@@ -355,7 +350,7 @@ liftTy  t              = mkPArrTy t
 --  lift type, don't change name (incl unique) nor IdInfo. IdInfo looks ok,
 --  but I'm not entirely sure about some fields (e.g., strictness info)
 liftBinderType:: CoreBndr ->  Flatten CoreBndr
-liftBinderType bndr = return $  bndr {varType = liftTy (varType bndr)}
+liftBinderType bndr = return $  setIdType bndr (liftTy (idType bndr))
 
 -- lift: lifts an expression (a -> [:a:])
 -- If the expression is a simple expression, it is treated like a constant
@@ -366,7 +361,7 @@ lift:: CoreExpr -> Flatten (CoreExpr, Type)
 lift cExpr@(Var id)    = 
   do
     lVar@(Var lId) <- liftVar id
-    return (lVar, varType lId)
+    return (lVar, idType lId)
 
 lift cExpr@(Lit lit)   = 
   do
@@ -376,7 +371,7 @@ lift cExpr@(Lit lit)   =
 
 lift (Lam b expr)
   | isSimpleExpr expr      =  liftSimpleFun b expr
-  | isTypeKind (varType b) = 
+  | isTyVar b = 
     do
       (lexpr, lexprTy) <- lift expr  -- don't lift b!
       return (Lam b lexpr, mkPiType b lexprTy)
@@ -443,7 +438,8 @@ lift (Let (Rec binds) expr2) =
 --        otherwise (a) compute index vector for simpleAlts (for def permute
 --                      later on
 --                  (b) 
-lift cExpr@(Case expr b alts)  =
+-- gaw 2004 FIX? 
+lift cExpr@(Case expr b _ alts)  =
   do  
     (lExpr, _) <- lift expr
     lb    <- liftBinderType  b     -- lift alt-expression
@@ -504,11 +500,11 @@ liftSingleDataCon:: CoreBndr -> DataCon -> [CoreBndr] -> CoreExpr ->
 liftSingleDataCon b dcon bnds expr =
   do 
     let dconId           = dataConTag dcon
-    indexExpr           <- mkIndexOfExprDCon (varType b)  b dconId
-    (b', bbind)         <- mkBind (slit "is"#) indexExpr
+    indexExpr           <- mkIndexOfExprDCon (idType b)  b dconId
+    (bb, bbind)         <- mkBind FSLIT("is") indexExpr
     lbnds               <- mapM liftBinderType bnds
-    ((lExpr, _), bnds') <- packContext  b' (extendContext lbnds (lift expr))
-    (_, vbind)          <- mkBind (slit "r"#) lExpr
+    ((lExpr, _), bnds') <- packContext  bb (extendContext lbnds (lift expr))
+    (_, vbind)          <- mkBind FSLIT("r") lExpr
     return (bbind, vbind, bnds')
 
 -- FIXME: clean this up. the datacon and the literal case are so
@@ -520,10 +516,10 @@ liftCaseDataConDefault:: CoreBndr -> (Alt CoreBndr) ->  [Alt CoreBndr]
 liftCaseDataConDefault b (_, _, def) alts =
   do
     let dconIds        = map (\(DataAlt d, _, _) -> dataConTag d) alts
-    indexExpr         <- mkIndexOfExprDConDft (varType b) b dconIds
-    (b', bbind)       <- mkBind (slit "is"#) indexExpr
-    ((lDef, _), bnds) <- packContext  b' (lift def)     
-    (_, vbind)        <- mkBind (slit "r"#) lDef
+    indexExpr         <- mkIndexOfExprDConDft (idType b) b dconIds
+    (bb, bbind)       <- mkBind FSLIT("is") indexExpr
+    ((lDef, _), bnds) <- packContext  bb (lift def)     
+    (_, vbind)        <- mkBind FSLIT("r") lDef
     return (bbind, vbind, bnds)
 
 -- liftCaseLit: checks if we have a default case and handles it 
@@ -551,10 +547,10 @@ liftCaseLitDefault:: CoreBndr -> (Alt CoreBndr) ->  [Alt CoreBndr]
 liftCaseLitDefault b (_, _, def) alts =
   do
     let lits           = map (\(LitAlt l, _, _) -> l) alts
-    indexExpr         <- mkIndexOfExprDft (varType b) b lits
-    (b', bbind)       <- mkBind (slit "is"#) indexExpr
-    ((lDef, _), bnds) <- packContext  b' (lift def)     
-    (_, vbind)        <- mkBind (slit "r"#) lDef
+    indexExpr         <- mkIndexOfExprDft (idType b) b lits
+    (bb, bbind)       <- mkBind FSLIT("is") indexExpr
+    ((lDef, _), bnds) <- packContext  bb (lift def)     
+    (_, vbind)        <- mkBind FSLIT("r") lDef
     return (bbind, vbind, bnds)
 
 -- FIXME: 
@@ -590,10 +586,10 @@ liftSingleCaseLit:: CoreBndr -> Literal -> CoreExpr  ->
   Flatten (CoreBind, CoreBind, [CoreBind])
 liftSingleCaseLit b lit expr =
  do 
-   indexExpr          <- mkIndexOfExpr (varType b) b lit -- (a)
-   (b', bbind)        <- mkBind (slit "is"#) indexExpr
-   ((lExpr, t), bnds) <- packContext  b' (lift expr)     -- (b)         
-   (_, vbind)         <- mkBind (slit "r"#) lExpr
+   indexExpr          <- mkIndexOfExpr (idType b) b lit -- (a)
+   (bb, bbind)        <- mkBind FSLIT("is") indexExpr
+   ((lExpr, t), bnds) <- packContext  bb (lift expr)     -- (b)         
+   (_, vbind)         <- mkBind FSLIT("r") lExpr
    return (bbind, vbind, bnds)
 
 -- letWrapper lExpr b ([indexbnd_i], [exprbnd_i], [pckbnd_ij])
@@ -647,14 +643,13 @@ dftbpBinders indexBnds exprBnds =
        let iVar = getVarOfBind i
        let eVar = getVarOfBind e
        let cVar = getVarOfBind cBind
-        let ty   = varType eVar
+        let ty   = idType eVar
        newBnd  <- mkDftBackpermute ty iVar eVar cVar
        ((fBnd, restBnds), _) <- dftbpBinders' is es newBnd
        return ((fBnd, (newBnd:restBnds)), liftTy ty)
 
     dftbpBinders'  _ _ _ = 
-      panic "Flattening.dftbpBinders: index and expression binder lists \ 
-           \have different length!"
+      panic "Flattening.dftbpBinders: index and expression binder lists have different length!"
 
 getExprOfBind:: CoreBind -> CoreExpr
 getExprOfBind (NonRec _ expr) = expr
@@ -678,7 +673,7 @@ liftSimpleFun b expr =
   do
     bndVars <- collectBoundVars expr
     let bndVars'     = b:bndVars
-        bndVarsTuple = mkTuple (map varType bndVars') (map Var bndVars')
+        bndVarsTuple = mkTuple (map idType bndVars') (map Var bndVars')
        lamExpr      = mkLams (b:bndVars) expr     -- FIXME: should be tuple
                                                    -- here 
     let (t1, t2)     = funTyArgs . exprType $ lamExpr
@@ -699,11 +694,11 @@ collectBoundVars  expr =
 --   indexOf (mapP (\x -> x == lit) b) b
 --
 mkIndexOfExpr:: Type -> CoreBndr -> Literal -> Flatten CoreExpr
-mkIndexOfExpr  varType b lit =
+mkIndexOfExpr  idType b lit =
   do 
-    eqExpr        <- mk'eq  varType (Var b) (Lit lit)
+    eqExpr        <- mk'eq idType (Var b) (Lit lit)
     let lambdaExpr = (Lam b eqExpr)
-    mk'indexOfP varType  lambdaExpr (Var b)
+    mk'indexOfP idType  lambdaExpr (Var b)
 
 -- there is FlattenMonad.mk'indexOfP as well as
 -- CoreSyn.mkApps and CoreSyn.mkLam, all of which should help here
@@ -717,12 +712,12 @@ mkIndexOfExpr  varType b lit =
 -- indexOfP (\x -> x == dconId) b)
 --
 mkIndexOfExprDCon::Type -> CoreBndr -> Int -> Flatten CoreExpr
-mkIndexOfExprDCon  varType b dId = 
+mkIndexOfExprDCon  idType b dId = 
   do 
     let intExpr    = mkIntLitInt dId
-    eqExpr        <- mk'eq  varType (Var b) intExpr
+    eqExpr        <- mk'eq  idType (Var b) intExpr
     let lambdaExpr = (Lam b intExpr)
-    mk'indexOfP varType lambdaExpr (Var b) 
+    mk'indexOfP idType lambdaExpr (Var b) 
 
   
 
@@ -735,28 +730,28 @@ mkIndexOfExprDCon  varType b dId =
 -- indexOfP (\x -> x != dconId_1 && ....) b)
 --
 mkIndexOfExprDConDft:: Type -> CoreBndr -> [Int] -> Flatten CoreExpr
-mkIndexOfExprDConDft varType b dId  = 
+mkIndexOfExprDConDft idType b dId  = 
   do 
     let intExprs   = map mkIntLitInt dId
-    bExpr         <- foldM (mk'neq varType) (head intExprs) (tail intExprs)
+    bExpr         <- foldM (mk'neq idType) (head intExprs) (tail intExprs)
     let lambdaExpr = (Lam b bExpr)
-    mk'indexOfP varType (Var b) bExpr
+    mk'indexOfP idType (Var b) bExpr
   
 
 -- mkIndexOfExprDef b [lit1, lit2,...] ->
 --   indexOf (\x -> not (x == lit1 || x == lit2 ....) b
 mkIndexOfExprDft:: Type -> CoreBndr -> [Literal] -> Flatten CoreExpr
-mkIndexOfExprDft varType b lits = 
+mkIndexOfExprDft idType b lits = 
   do 
     let litExprs   = map (\l-> Lit l)  lits
-    bExpr         <- foldM (mk'neq varType) (head litExprs) (tail litExprs)
+    bExpr         <- foldM (mk'neq idType) (head litExprs) (tail litExprs)
     let lambdaExpr = (Lam b bExpr)
-    mk'indexOfP varType bExpr (Var b) 
+    mk'indexOfP idType bExpr (Var b) 
 
 
 -- create a back-permute binder
 --
--- * `mkDftBackpermute ty indexArrayVar srcArrayVar dftArrayVar' creates a
+--  * `mkDftBackpermute ty indexArrayVar srcArrayVar dftArrayVar' creates a
 --   Core binding of the form
 --
 --     x = bpermuteDftP indexArrayVar srcArrayVar dftArrayVar
@@ -767,7 +762,7 @@ mkDftBackpermute :: Type -> Var -> Var -> Var -> Flatten CoreBind
 mkDftBackpermute ty idx src dft = 
   do
     rhs <- mk'bpermuteDftP ty (Var idx) (Var src) (Var dft)
-    liftM snd $ mkBind (slit "dbp"#) rhs
+    liftM snd $ mkBind FSLIT("dbp") rhs
 
 -- create a dummy array with elements of the given type, which can be used as
 -- default array for the combination of the subresults of the lifted case
@@ -781,7 +776,7 @@ createDftArrayBind e  =
     let ty = parrElemTy . exprType $ expr
     len <- mk'lengthP e
     rhs <- mk'replicateP ty len err??
-    lift snd $ mkBind (slit "dft"#) rhs
+    lift snd $ mkBind FSLIT("dft") rhs
 FIXME: nicht so einfach; man kann kein "error"-Wert nehmen, denn der w"urde
   beim bpermuteDftP sofort evaluiert, aber es ist auch schwer m"oglich einen
   generischen Wert f"ur jeden beliebigen Typ zu erfinden.
@@ -805,8 +800,9 @@ showCoreExpr (Let bnds expr) =
   where showBinds (NonRec b e) = showBind (b,e)
         showBinds (Rec bnds)   = concat (map showBind bnds)
         showBind (b,e) = "  b = " ++ (showCoreExpr e)++ "\n"
-showCoreExpr (Case ex b alts) =
+-- gaw 2004 FIX?
+showCoreExpr (Case ex b ty alts) =
   "Case b = " ++ (showCoreExpr ex) ++ " of \n" ++ (showAlts alts)
   where showAlts _ = ""  
 showCoreExpr (Note _ ex) = "Note n " ++ (showCoreExpr ex)
-showCoreExpr (Type t) = "Type"
\ No newline at end of file
+showCoreExpr (Type t) = "Type"