[project @ 2005-02-04 17:24:01 by simonpj]
[ghc-hetmet.git] / ghc / compiler / specialise / Specialise.lhs
index 006e06d..980db08 100644 (file)
@@ -12,28 +12,28 @@ import CmdLineOpts  ( DynFlags, DynFlag(..) )
 import Id              ( Id, idName, idType, mkUserLocal ) 
 import TcType          ( Type, mkTyVarTy, tcSplitSigmaTy, 
                          tyVarsOfTypes, tyVarsOfTheta, isClassPred,
-                         mkForAllTys, tcCmpType
+                         tcCmpType, isUnLiftedType
                        )
-import Subst           ( Subst, mkSubst, substTy, mkSubst, extendSubstList, mkInScopeSet,
-                         simplBndr, simplBndrs, 
-                         substAndCloneId, substAndCloneIds, substAndCloneRecIds,
-                         lookupIdSubst, substInScope
+import CoreSubst       ( Subst, mkEmptySubst, extendTvSubstList, lookupIdSubst,
+                         substBndr, substBndrs, substTy, substInScope,
+                         cloneIdBndr, cloneIdBndrs, cloneRecIdBndrs
                        ) 
 import Var             ( zapSpecPragmaId )
 import VarSet
 import VarEnv
 import CoreSyn
-import CoreUtils       ( applyTypeToArgs )
+import CoreUtils       ( applyTypeToArgs, mkPiTypes )
 import CoreFVs         ( exprFreeVars, exprsFreeVars )
 import CoreTidy                ( pprTidyIdRules )
 import CoreLint                ( showPass, endPass )
-import Rules           ( addIdSpecialisations, lookupRule )
+import Rules           ( addIdSpecialisations, lookupRule, emptyRuleBase )
 
 import UniqSupply      ( UniqSupply,
                          UniqSM, initUs_, thenUs, returnUs, getUniqueUs, 
                          getUs, mapUs
                        )
 import Name            ( nameOccName, mkSpecOcc, getSrcLoc )
+import MkId            ( voidArgId, realWorldPrimId )
 import FiniteMap
 import Maybes          ( catMaybes, maybeToBool )
 import ErrUtils                ( dumpIfSet_dyn )
@@ -595,7 +595,7 @@ specProgram dflags us binds
        -- accidentally re-use a unique that's already in use
        -- Easiest thing is to do it all at once, as if all the top-level
        -- decls were mutually recursive
-    top_subst      = mkSubst (mkInScopeSet (mkVarSet (bindersOfBinds binds))) emptySubstEnv
+    top_subst      = mkEmptySubst (mkInScopeSet (mkVarSet (bindersOfBinds binds)))
 
     go []          = returnSM ([], emptyUDs)
     go (bind:binds) = go binds                                 `thenSM` \ (binds', uds) ->
@@ -611,9 +611,7 @@ specProgram dflags us binds
 
 \begin{code}
 specVar :: Subst -> Id -> CoreExpr
-specVar subst v = case lookupIdSubst subst v of
-                       DoneEx e   -> e
-                       DoneId v _ -> Var v
+specVar subst v = lookupIdSubst subst v
 
 specExpr :: Subst -> CoreExpr -> SpecM (CoreExpr, UsageDetails)
 -- We carry a substitution down:
@@ -654,16 +652,16 @@ specExpr subst e@(Lam _ _)
     returnSM (mkLams bndrs' body'', filtered_uds)
   where
     (bndrs, body) = collectBinders e
-    (subst', bndrs') = simplBndrs subst bndrs
+    (subst', bndrs') = substBndrs subst bndrs
        -- More efficient to collect a group of binders together all at once
        -- and we don't want to split a lambda group with dumped bindings
 
-specExpr subst (Case scrut case_bndr alts)
-  = specExpr subst scrut                       `thenSM` \ (scrut', uds_scrut) ->
+specExpr subst (Case scrut case_bndr ty alts)
+  = specExpr subst scrut               `thenSM` \ (scrut', uds_scrut) ->
     mapAndCombineSM spec_alt alts      `thenSM` \ (alts', uds_alts) ->
-    returnSM (Case scrut' case_bndr' alts', uds_scrut `plusUDs` uds_alts)
+    returnSM (Case scrut' case_bndr' (substTy subst ty) alts', uds_scrut `plusUDs` uds_alts)
   where
-    (subst_alt, case_bndr') = simplBndr subst case_bndr
+    (subst_alt, case_bndr') = substBndr subst case_bndr
        -- No need to clone case binder; it can't float like a let(rec)
 
     spec_alt (con, args, rhs)
@@ -673,7 +671,7 @@ specExpr subst (Case scrut case_bndr alts)
          in
          returnSM ((con, args', rhs''), uds')
        where
-         (subst_rhs, args') = simplBndrs subst_alt args
+         (subst_rhs, args') = substBndrs subst_alt args
 
 ---------------- Finally, let is the interesting case --------------------
 specExpr subst (Let bind body)
@@ -824,10 +822,10 @@ specDefn subst calls (fn, rhs)
     n_tyvars          = length tyvars
     n_dicts           = length theta
 
+    (rhs_tyvars, rhs_ids, rhs_body) 
+       = collectTyAndValBinders (dropInline rhs)
        -- It's important that we "see past" any INLINE pragma
        -- else we'll fail to specialise an INLINE thing
-    (inline_me, rhs')              = dropInline rhs
-    (rhs_tyvars, rhs_ids, rhs_body) = collectTyAndValBinders rhs'
 
     rhs_dicts = take n_dicts rhs_ids
     rhs_bndrs = rhs_tyvars ++ rhs_dicts
@@ -871,17 +869,22 @@ specDefn subst calls (fn, rhs)
                       where
                         mk_ty_arg rhs_tyvar Nothing   = Type (mkTyVarTy rhs_tyvar)
                         mk_ty_arg rhs_tyvar (Just ty) = Type ty
-          rhs_subst  = extendSubstList subst spec_tyvars [DoneTy ty | Just ty <- call_ts]
+          rhs_subst  = extendTvSubstList subst (spec_tyvars `zip` [ty | Just ty <- call_ts])
        in
        cloneBinders rhs_subst rhs_dicts                `thenSM` \ (rhs_subst', rhs_dicts') ->
        let
           inst_args = ty_args ++ map Var rhs_dicts'
 
                -- Figure out the type of the specialised function
-          spec_id_ty = mkForAllTys poly_tyvars (applyTypeToArgs rhs fn_type inst_args)
+          body_ty = applyTypeToArgs rhs fn_type inst_args
+          (lam_args, app_args)                 -- Add a dummy argument if body_ty is unlifted
+               | isUnLiftedType body_ty        -- C.f. WwLib.mkWorkerArgs
+               = (poly_tyvars ++ [voidArgId], poly_tyvars ++ [realWorldPrimId])
+               | otherwise = (poly_tyvars, poly_tyvars)
+          spec_id_ty = mkPiTypes lam_args body_ty
        in
        newIdSM fn spec_id_ty                           `thenSM` \ spec_f ->
-       specExpr rhs_subst' (mkLams poly_tyvars body)   `thenSM` \ (spec_rhs, rhs_uds) ->       
+       specExpr rhs_subst' (mkLams lam_args body)      `thenSM` \ (spec_rhs, rhs_uds) ->       
        let
                -- The rule to put in the function's specialisation is:
                --      forall b,d, d1',d2'.  f t1 b t3 d d1' d2' = f1 b d  
@@ -889,7 +892,7 @@ specDefn subst calls (fn, rhs)
                                AlwaysActive
                                (poly_tyvars ++ rhs_dicts')
                                inst_args 
-                               (mkTyApps (Var spec_f) (map mkTyVarTy poly_tyvars))
+                               (mkVarApps (Var spec_f) app_args)
 
                -- Add the { d1' = dx1; d2' = dx2 } usage stuff
           final_uds = foldr addDictBind rhs_uds (my_zipEqual "spec_call" rhs_dicts' call_ds)
@@ -910,9 +913,9 @@ specDefn subst calls (fn, rhs)
         | not (equalLength xs ys) = pprPanic "my_zipEqual" (ppr xs $$ ppr ys $$ (ppr fn <+> ppr call_ts) $$ ppr rhs)
         | otherwise               = zipEqual doc xs ys
 
-dropInline :: CoreExpr -> (Bool, CoreExpr) 
-dropInline (Note InlineMe rhs) = (True, rhs)
-dropInline rhs                = (False, rhs)
+dropInline :: CoreExpr -> CoreExpr
+dropInline (Note InlineMe rhs) = rhs
+dropInline rhs                = rhs
 \end{code}
 
 %************************************************************************
@@ -1007,7 +1010,7 @@ mkCallUDs subst f args
        -- *don't* say what the value of the implicit param is!
   || not (spec_tys `lengthIs` n_tyvars)
   || not ( dicts   `lengthIs` n_dicts)
-  || maybeToBool (lookupRule (\act -> True) (substInScope subst) f args)
+  || maybeToBool (lookupRule (\act -> True) (substInScope subst) emptyRuleBase f args)
        -- There's already a rule covering this call.  A typical case
        -- is where there's an explicit user-provided rule.  Then
        -- we don't want to create a specialised version 
@@ -1105,7 +1108,7 @@ splitUDs bndrs uds@(MkUD {dict_binds = orig_dbs,
     dump_db (free_dbs, dump_dbs, dump_idset) db@(bind, fvs)
        | dump_idset `intersectsVarSet` fvs     -- Dump it
        = (free_dbs, dump_dbs `snocBag` db,
-          dump_idset `unionVarSet` mkVarSet (bindersOf bind))
+          extendVarSetList dump_idset (bindersOf bind))
 
        | otherwise     -- Don't dump it
        = (free_dbs `snocBag` db, dump_dbs, dump_idset)
@@ -1138,20 +1141,20 @@ cloneBindSM :: Subst -> CoreBind -> SpecM (Subst, Subst, CoreBind)
 cloneBindSM subst (NonRec bndr rhs)
   = getUs      `thenUs` \ us ->
     let
-       (subst', bndr') = substAndCloneId subst us bndr
+       (subst', bndr') = cloneIdBndr subst us bndr
     in
     returnUs (subst, subst', NonRec bndr' rhs)
 
 cloneBindSM subst (Rec pairs)
   = getUs      `thenUs` \ us ->
     let
-       (subst', bndrs') = substAndCloneRecIds subst us (map fst pairs)
+       (subst', bndrs') = cloneRecIdBndrs subst us (map fst pairs)
     in
     returnUs (subst', subst', Rec (bndrs' `zip` map snd pairs))
 
 cloneBinders subst bndrs
   = getUs      `thenUs` \ us ->
-    returnUs (substAndCloneIds subst us bndrs)
+    returnUs (cloneIdBndrs subst us bndrs)
 
 newIdSM old_id new_ty
   = getUniqSM          `thenSM` \ uniq ->