[project @ 1998-04-30 18:47:08 by sof]
[ghc-hetmet.git] / ghc / compiler / specialise / Specialise.lhs
index c7d2ff4..a1c602d 100644 (file)
@@ -4,15 +4,12 @@
 \section[Specialise]{Stamping out overloading, and (optionally) polymorphism}
 
 \begin{code}
-module Specialise (
-       specProgram, 
-       idSpecVars
-    ) where
+module Specialise ( specProgram ) where
 
 #include "HsVersions.h"
 
 import MkId            ( mkUserLocal )
-import Id              ( Id, DictVar, idType, 
+import Id              ( Id, DictVar, idType, mkTemplateLocals,
 
                          getIdSpecialisation, setIdSpecialisation, isSpecPragmaId,
 
@@ -26,22 +23,26 @@ import Type         ( Type, mkTyVarTy, splitSigmaTy, instantiateTy, isDictTy,
                          tyVarsOfType, tyVarsOfTypes, applyTys, mkForAllTys
                        )
 import TyCon           ( TyCon )
-import TyVar           ( TyVar,
+import TyVar           ( TyVar, mkTyVar, mkSysTyVar,
                          TyVarSet, mkTyVarSet, isEmptyTyVarSet, intersectTyVarSets,
                                    elementOfTyVarSet, unionTyVarSets, emptyTyVarSet,
+                                   minusTyVarSet,
                          TyVarEnv, mkTyVarEnv, delFromTyVarEnv
                        )
+import Kind            ( mkBoxedTypeKind )
 import CoreSyn
+import FreeVars                ( exprFreeVars )
 import PprCore         ()      -- Instances 
-import Name            ( NamedThing(..), getSrcLoc )
+import Name            ( NamedThing(..), getSrcLoc, mkSysLocalName, isLocallyDefined )
+import SrcLoc          ( noSrcLoc )
 import SpecEnv         ( addToSpecEnv, lookupSpecEnv, specEnvValues )
 
 import UniqSupply      ( UniqSupply,
                          UniqSM, initUs, thenUs, returnUs, getUnique, mapUs
                        )
-
+import Unique          ( mkAlphaTyVarUnique )
 import FiniteMap
-import Maybes          ( MaybeErr(..), maybeToBool )
+import Maybes          ( MaybeErr(..), maybeToBool, catMaybes )
 import Bag
 import List            ( partition )
 import Util            ( zipEqual )
@@ -710,7 +711,7 @@ specBind (NonRec bndr rhs) body_uds
 
   | isSpecPragmaId bndr
   = specExpr rhs                               `thenSM` \ (rhs', rhs_uds) ->
-    returnSM ([], rhs_uds)
+    returnSM ([], rhs_uds `plusUDs` body_uds)
 
   | otherwise
   =   -- Deal with the RHS, specialising it according
@@ -718,14 +719,20 @@ specBind (NonRec bndr rhs) body_uds
     specDefn (calls body_uds) (bndr,rhs)       `thenSM` \ ((bndr',rhs'), spec_defns, spec_uds) ->
     let
        (all_uds, (dict_binds, dump_calls)) 
-               = splitUDs [ValBinder bndr] (spec_uds `plusUDs` body_uds)
+               = splitUDs [ValBinder bndr]
+                          (body_uds `plusUDs` spec_uds)
+                       -- It's important that the `plusUDs` is this way round,
+                       -- because body_uds may bind dictionaries that are
+                       -- used in the calls passed to specDefn.  So the
+                       -- dictionary bindings in spec_uds may mention 
+                       -- dictionaries bound in body_uds.
 
         -- If we make specialisations then we Rec the whole lot together
         -- If not, leave it as a NonRec
         new_bind | null spec_defns = NonRec bndr' rhs'
                  | otherwise       = Rec ((bndr',rhs'):spec_defns)
     in
-    returnSM ( new_bind : dict_binds, all_uds )
+    returnSM ( new_bind : mkDictBinds dict_binds, all_uds )
 
 specBind (Rec pairs) body_uds
   = mapSM (specDefn (calls body_uds)) pairs    `thenSM` \ stuff ->
@@ -733,11 +740,15 @@ specBind (Rec pairs) body_uds
        (pairs', spec_defns_s, spec_uds_s) = unzip3 stuff
        spec_defns = concat spec_defns_s
        spec_uds   = plusUDList spec_uds_s
+
        (all_uds, (dict_binds, dump_calls)) 
-               = splitUDs (map (ValBinder . fst) pairs) (spec_uds `plusUDs` body_uds)
+               = splitUDs (map (ValBinder . fst) pairs)
+                          (body_uds `plusUDs` spec_uds)
+                       -- See notes for non-rec case
+
         new_bind = Rec (spec_defns ++ pairs')
     in
-    returnSM ( new_bind : dict_binds, all_uds )
+    returnSM ( new_bind : mkDictBinds dict_binds, all_uds )
     
 specDefn :: CallDetails                        -- Info on how it is used in its scope
         -> (Id, CoreExpr)              -- The thing being bound and its un-processed RHS
@@ -764,7 +775,7 @@ specDefn calls (fn, rhs)
        (spec_defns, spec_uds, spec_env_stuff) = unzip3 stuff
 
        fn'  = addIdSpecialisations fn spec_env_stuff
-       rhs' = foldr Lam (foldr Let body' dict_binds) rhs_bndrs 
+       rhs' = foldr Lam (mkDictLets dict_binds body') rhs_bndrs 
     in
     returnSM ((fn',rhs'), 
              spec_defns, 
@@ -779,10 +790,6 @@ specDefn calls (fn, rhs)
     (tyvars, theta, tau) = splitSigmaTy fn_type
     n_tyvars            = length tyvars
     n_dicts             = length theta
-    mk_spec_tys call_ts  = zipWith mk_spec_ty call_ts tyvars
-                         where
-                           mk_spec_ty (Just ty) _     = ty
-                           mk_spec_ty Nothing   tyvar = mkTyVarTy tyvar
 
     (rhs_tyvars, rhs_ids, rhs_body) = collectBinders rhs
     rhs_dicts = take n_dicts rhs_ids
@@ -794,11 +801,6 @@ specDefn calls (fn, rhs)
                        Nothing -> []
                        Just cs -> fmToList cs
 
-    -- Filter out calls for which we already have a specialisation
-    calls_to_spec        = filter spec_me calls_for_me
-    spec_me (call_ts, _) = not (maybeToBool (lookupSpecEnv id_spec_env (mk_spec_tys call_ts)))
-    id_spec_env          = getIdSpecialisation fn
-
     ----------------------------------------------------------
        -- Specialise to one particular call pattern
     spec_call :: ProtoUsageDetails          -- From the original body, captured by
@@ -817,13 +819,20 @@ specDefn calls (fn, rhs)
                --      f1 = /\ b d -> (..rhs of f..) t1 b t3 d d1 d2
                -- and the type of this binder
         let
-           spec_tyvars = [tyvar | (tyvar, Nothing) <- tyvars `zip` call_ts]
-          spec_tys    = mk_spec_tys call_ts
+         mk_spec_ty Nothing   = newTyVarSM   `thenSM` \ tyvar ->
+                                returnSM (Just tyvar, mkTyVarTy tyvar)
+         mk_spec_ty (Just ty) = returnSM (Nothing,    ty)
+        in
+        mapSM mk_spec_ty call_ts   `thenSM` \ stuff ->
+        let
+          (maybe_spec_tyvars, spec_tys) = unzip stuff
+           spec_tyvars = catMaybes maybe_spec_tyvars
           spec_rhs    = mkTyLam spec_tyvars $
                          mkGenApp rhs (map TyArg spec_tys ++ map VarArg call_ds)
           spec_id_ty  = mkForAllTys spec_tyvars (instantiateTy ty_env tau)
           ty_env      = mkTyVarEnv (zipEqual "spec_call" tyvars spec_tys)
        in
+
        newIdSM fn spec_id_ty           `thenSM` \ spec_f ->
 
 
@@ -833,8 +842,11 @@ specDefn calls (fn, rhs)
                -- dictionaries, so it's tidier to make new local variables
                -- for the lambdas in the RHS, rather than lambda-bind the
                -- dictionaries themselves.
-       mapSM (\d -> newIdSM d (idType d)) call_ds      `thenSM` \ arg_ds ->
+               --
+               -- In fact we use the standard template locals, so that the
+               -- they don't need to be "tidied" before putting in interface files
        let
+          arg_ds        = mkTemplateLocals (map idType call_ds)
           spec_env_rhs  = mkValLam arg_ds $
                           mkTyApp (Var spec_f) $
                           map mkTyVarTy spec_tyvars
@@ -868,7 +880,7 @@ type FreeDicts = IdSet
 
 data UsageDetails 
   = MkUD {
-       dict_binds :: !(Bag (DictVar, CoreExpr, TyVarSet, FreeDicts)),
+       dict_binds :: !(Bag DictBind),
                        -- Floated dictionary bindings
                        -- The order is important; 
                        -- in ds1 `union` ds2, bindings in ds2 can depend on those in ds1
@@ -878,9 +890,14 @@ data UsageDetails
        calls     :: !CallDetails
     }
 
+type DictBind = (DictVar, CoreExpr, TyVarSet, FreeDicts)
+                       -- The FreeDicts are the free dictionaries (only)
+                       -- of the RHS of the dictionary bindings
+                       -- Similarly the TyVarSet
+
 emptyUDs = MkUD { dict_binds = emptyBag, calls = emptyFM }
 
-type ProtoUsageDetails = ([CoreBinding],               -- Dict bindings
+type ProtoUsageDetails = ([DictBind],
                          [(Id, [Maybe Type], [DictVar])]
                         )
 
@@ -951,11 +968,19 @@ dumpAllDictBinds (MkUD {dict_binds = dbs}) binds
   where
     add (dict,rhs,_,_) binds = NonRec dict rhs : binds
 
+mkDictBinds :: [DictBind] -> [CoreBinding]
+mkDictBinds = map (\(d,r,_,_) -> NonRec d r)
+
+mkDictLets :: [DictBind] -> CoreExpr -> CoreExpr
+mkDictLets dbs body = foldr mk body dbs
+                   where
+                     mk (d,r,_,_) e = Let (NonRec d r) e 
+
 dumpUDs :: [CoreBinder]
        -> UsageDetails -> CoreExpr
        -> (UsageDetails, CoreExpr)
 dumpUDs bndrs uds body
-  = (free_uds, foldr Let body dict_binds)
+  = (free_uds, mkDictLets dict_binds body)
   where
     (free_uds, (dict_binds, _)) = splitUDs bndrs uds
 
@@ -1001,7 +1026,7 @@ splitUDs bndrs uds@(MkUD {dict_binds = orig_dbs,
        = (free_dbs `snocBag` db, dump_dbs, dump_idset)
 
        | otherwise     -- Dump it
-       = (free_dbs, dump_dbs `snocBag` NonRec dict rhs, 
+       = (free_dbs, dump_dbs `snocBag` db,
           dump_idset `addOneToIdSet` dict)
 \end{code}
 
@@ -1011,13 +1036,16 @@ the given UDs
 \begin{code}
 specUDs :: [(TyVar,Type)] -> [(DictVar,DictVar)] -> ProtoUsageDetails -> SpecM UsageDetails
 specUDs tv_env_list dict_env_list (dbs, calls)
-  = specDBs dict_env dbs               `thenSM` \ (dict_env', dbs') ->
+  = specDBs dict_env_list dbs          `thenSM` \ (dict_env_list', dbs') ->
+    let
+       dict_env = mkIdEnv dict_env_list'
+    in
     returnSM (MkUD { dict_binds = dbs',
-                    calls      = listToCallDetails (map (inst_call dict_env') calls)
+                    calls      = listToCallDetails (map (inst_call dict_env) calls)
     })
   where
-    tv_env   = mkTyVarEnv tv_env_list
-    dict_env = mkIdEnv dict_env_list
+    bound_tyvars = mkTyVarSet (map fst tv_env_list)
+    tv_env   = mkTyVarEnv tv_env_list  -- Doesn't change
 
     inst_call dict_env (id, tys, dicts) = (id, map inst_maybe_ty tys, 
                                               map (lookupId dict_env) dicts)
@@ -1027,14 +1055,22 @@ specUDs tv_env_list dict_env_list (dbs, calls)
 
     specDBs dict_env []
        = returnSM (dict_env, emptyBag)
-    specDBs dict_env (NonRec dict rhs : dbs)
+    specDBs dict_env ((dict, rhs, ftvs, fvs) : dbs)
        = newIdSM dict (instantiateTy tv_env (idType dict))     `thenSM` \ dict' ->
          let
-           dict_env' = addOneToIdEnv dict_env dict dict'
-           rhs'      = instantiateDictRhs tv_env dict_env rhs
+           rhs'      = foldl App (foldr Lam rhs (t_bndrs ++ d_bndrs)) (t_args ++ d_args)
+           (t_bndrs, t_args) = unzip [(TyBinder tv, TyArg ty)  | (tv,ty) <- tv_env_list,
+                                                                  tv `elementOfTyVarSet` ftvs]
+           (d_bndrs, d_args) = unzip [(ValBinder d, VarArg d') | (d,d')  <- dict_env,
+                                                                  d `elementOfIdSet` fvs]
+           dict_env' = (dict,dict') : dict_env
+           ftvs' = tyVarsOfTypes [ty | TyArg ty <- t_args] `unionTyVarSets`
+                   (ftvs `minusTyVarSet` bound_tyvars)
+           fvs'  = mkIdSet [d | VarArg d <- d_args] `unionIdSets`
+                   (fvs `minusIdSet` mkIdSet [d | ValBinder d <- d_bndrs])
          in
          specDBs dict_env' dbs         `thenSM` \ (dict_env'', dbs') ->
-         returnSM ( dict_env'', mkDB dict' rhs' `consBag` dbs' )
+         returnSM ( dict_env'', (dict', rhs', ftvs', fvs') `consBag` dbs' )
 \end{code}
 
 %************************************************************************
@@ -1049,43 +1085,8 @@ lookupId env id = case lookupIdEnv env id of
                        Nothing  -> id
                        Just id' -> id'
 
-instantiateDictRhs :: TyVarEnv Type -> IdEnv Id -> CoreExpr -> CoreExpr
-       -- Cheapo function for simple RHSs
-instantiateDictRhs ty_env id_env rhs
-  = go rhs
-  where
-    go_arg (VarArg a) = VarArg (lookupId id_env a)
-    go_arg (TyArg t)  = TyArg (instantiateTy ty_env t)
-
-    go (App e1 arg)   = App (go e1) (go_arg arg)
-    go (Var v)       = Var (lookupId id_env v)
-    go (Lit l)       = Lit l
-    go (Con con args) = Con con (map go_arg args)
-    go (Note n e)     = Note (go_note n) (go e)
-    go (Case e alts)  = Case (go e) alts               -- See comment below re alts
-    go other         = pprPanic "instantiateDictRhs" (ppr rhs)
-
-    go_note (Coerce t1 t2) = Coerce (instantiateTy ty_env t1) (instantiateTy ty_env t2)
-    go_note note          = note
-
 dictRhsFVs :: CoreExpr -> IdSet
-       -- Cheapo function for simple RHSs
-dictRhsFVs e
-  = go e
-  where
-    go (App e1 (VarArg a)) = go e1 `addOneToIdSet` a
-    go (App e1 (TyArg t))  = go e1
-    go (Var v)            = unitIdSet v
-    go (Lit l)            = emptyIdSet
-    go (Con _ args)        = mkIdSet [id | VarArg id <- args]
-    go (Note _ e)         = go e
-
-    go (Case e _)         = go e       -- Claim: no free dictionaries in the alternatives
-                                       -- These case expressions are of the form
-                                       --   case d of { D a b c -> b }
-
-    go other              = pprPanic "dictRhsFVs" (ppr e)
-
+dictRhsFVs e = exprFreeVars isLocallyDefined e
 
 addIdSpecialisations id spec_stuff
   = (if not (null errs) then
@@ -1101,22 +1102,6 @@ addIdSpecialisations id spec_stuff
                Succeeded spec_env' -> (spec_env', errs)
                Failed err          -> (spec_env, err:errs)
 
--- Given an Id, isSpecVars returns all its specialisations.
--- We extract these from its SpecEnv.
--- This is used by the occurrence analyser and free-var finder;
--- we regard an Id's specialisations as free in the Id's definition.
-
-idSpecVars :: Id -> [Id]
-idSpecVars id 
-  = map get_spec (specEnvValues (getIdSpecialisation id))
-  where
-    -- get_spec is another cheapo function like dictRhsFVs
-    -- It knows what these specialisation temlates look like,
-    -- and just goes for the jugular
-    get_spec (App f _) = get_spec f
-    get_spec (Lam _ b) = get_spec b
-    get_spec (Var v)   = v
-
 ----------------------------------------
 type SpecM a = UniqSM a
 
@@ -1138,6 +1123,10 @@ newIdSM old_id new_ty
                          new_ty
                          (getSrcLoc old_id)
     )
+
+newTyVarSM
+  = getUnique          `thenSM` \ uniq ->
+    returnSM (mkSysTyVar uniq mkBoxedTypeKind)
 \end{code}