[project @ 2004-12-24 16:14:36 by simonpj]
[ghc-hetmet.git] / ghc / compiler / specialise / Specialise.lhs
index 1d172e9..980db08 100644 (file)
@@ -12,28 +12,28 @@ import CmdLineOpts  ( DynFlags, DynFlag(..) )
 import Id              ( Id, idName, idType, mkUserLocal ) 
 import TcType          ( Type, mkTyVarTy, tcSplitSigmaTy, 
                          tyVarsOfTypes, tyVarsOfTheta, isClassPred,
-                         mkForAllTys, tcCmpType
+                         tcCmpType, isUnLiftedType
                        )
-import Subst           ( Subst, mkSubst, substTy, mkSubst, extendSubstList, mkInScopeSet,
-                         simplBndr, simplBndrs, 
-                         substAndCloneId, substAndCloneIds, substAndCloneRecIds,
-                         lookupIdSubst, substInScope
+import CoreSubst       ( Subst, mkEmptySubst, extendTvSubstList, lookupIdSubst,
+                         substBndr, substBndrs, substTy, substInScope,
+                         cloneIdBndr, cloneIdBndrs, cloneRecIdBndrs
                        ) 
 import Var             ( zapSpecPragmaId )
 import VarSet
 import VarEnv
 import CoreSyn
-import CoreUtils       ( applyTypeToArgs )
+import CoreUtils       ( applyTypeToArgs, mkPiTypes )
 import CoreFVs         ( exprFreeVars, exprsFreeVars )
 import CoreTidy                ( pprTidyIdRules )
 import CoreLint                ( showPass, endPass )
-import Rules           ( addIdSpecialisations, lookupRule )
+import Rules           ( addIdSpecialisations, lookupRule, emptyRuleBase )
 
 import UniqSupply      ( UniqSupply,
                          UniqSM, initUs_, thenUs, returnUs, getUniqueUs, 
                          getUs, mapUs
                        )
 import Name            ( nameOccName, mkSpecOcc, getSrcLoc )
+import MkId            ( voidArgId, realWorldPrimId )
 import FiniteMap
 import Maybes          ( catMaybes, maybeToBool )
 import ErrUtils                ( dumpIfSet_dyn )
@@ -595,7 +595,7 @@ specProgram dflags us binds
        -- accidentally re-use a unique that's already in use
        -- Easiest thing is to do it all at once, as if all the top-level
        -- decls were mutually recursive
-    top_subst      = mkSubst (mkInScopeSet (mkVarSet (bindersOfBinds binds))) emptySubstEnv
+    top_subst      = mkEmptySubst (mkInScopeSet (mkVarSet (bindersOfBinds binds)))
 
     go []          = returnSM ([], emptyUDs)
     go (bind:binds) = go binds                                 `thenSM` \ (binds', uds) ->
@@ -611,9 +611,7 @@ specProgram dflags us binds
 
 \begin{code}
 specVar :: Subst -> Id -> CoreExpr
-specVar subst v = case lookupIdSubst subst v of
-                       DoneEx e   -> e
-                       DoneId v _ -> Var v
+specVar subst v = lookupIdSubst subst v
 
 specExpr :: Subst -> CoreExpr -> SpecM (CoreExpr, UsageDetails)
 -- We carry a substitution down:
@@ -654,16 +652,16 @@ specExpr subst e@(Lam _ _)
     returnSM (mkLams bndrs' body'', filtered_uds)
   where
     (bndrs, body) = collectBinders e
-    (subst', bndrs') = simplBndrs subst bndrs
+    (subst', bndrs') = substBndrs subst bndrs
        -- More efficient to collect a group of binders together all at once
        -- and we don't want to split a lambda group with dumped bindings
 
-specExpr subst (Case scrut case_bndr alts)
-  = specExpr subst scrut                       `thenSM` \ (scrut', uds_scrut) ->
+specExpr subst (Case scrut case_bndr ty alts)
+  = specExpr subst scrut               `thenSM` \ (scrut', uds_scrut) ->
     mapAndCombineSM spec_alt alts      `thenSM` \ (alts', uds_alts) ->
-    returnSM (Case scrut' case_bndr' alts', uds_scrut `plusUDs` uds_alts)
+    returnSM (Case scrut' case_bndr' (substTy subst ty) alts', uds_scrut `plusUDs` uds_alts)
   where
-    (subst_alt, case_bndr') = simplBndr subst case_bndr
+    (subst_alt, case_bndr') = substBndr subst case_bndr
        -- No need to clone case binder; it can't float like a let(rec)
 
     spec_alt (con, args, rhs)
@@ -673,7 +671,7 @@ specExpr subst (Case scrut case_bndr alts)
          in
          returnSM ((con, args', rhs''), uds')
        where
-         (subst_rhs, args') = simplBndrs subst_alt args
+         (subst_rhs, args') = substBndrs subst_alt args
 
 ---------------- Finally, let is the interesting case --------------------
 specExpr subst (Let bind body)
@@ -871,17 +869,22 @@ specDefn subst calls (fn, rhs)
                       where
                         mk_ty_arg rhs_tyvar Nothing   = Type (mkTyVarTy rhs_tyvar)
                         mk_ty_arg rhs_tyvar (Just ty) = Type ty
-          rhs_subst  = extendSubstList subst spec_tyvars [DoneTy ty | Just ty <- call_ts]
+          rhs_subst  = extendTvSubstList subst (spec_tyvars `zip` [ty | Just ty <- call_ts])
        in
        cloneBinders rhs_subst rhs_dicts                `thenSM` \ (rhs_subst', rhs_dicts') ->
        let
           inst_args = ty_args ++ map Var rhs_dicts'
 
                -- Figure out the type of the specialised function
-          spec_id_ty = mkForAllTys poly_tyvars (applyTypeToArgs rhs fn_type inst_args)
+          body_ty = applyTypeToArgs rhs fn_type inst_args
+          (lam_args, app_args)                 -- Add a dummy argument if body_ty is unlifted
+               | isUnLiftedType body_ty        -- C.f. WwLib.mkWorkerArgs
+               = (poly_tyvars ++ [voidArgId], poly_tyvars ++ [realWorldPrimId])
+               | otherwise = (poly_tyvars, poly_tyvars)
+          spec_id_ty = mkPiTypes lam_args body_ty
        in
        newIdSM fn spec_id_ty                           `thenSM` \ spec_f ->
-       specExpr rhs_subst' (mkLams poly_tyvars body)   `thenSM` \ (spec_rhs, rhs_uds) ->       
+       specExpr rhs_subst' (mkLams lam_args body)      `thenSM` \ (spec_rhs, rhs_uds) ->       
        let
                -- The rule to put in the function's specialisation is:
                --      forall b,d, d1',d2'.  f t1 b t3 d d1' d2' = f1 b d  
@@ -889,7 +892,7 @@ specDefn subst calls (fn, rhs)
                                AlwaysActive
                                (poly_tyvars ++ rhs_dicts')
                                inst_args 
-                               (mkTyApps (Var spec_f) (map mkTyVarTy poly_tyvars))
+                               (mkVarApps (Var spec_f) app_args)
 
                -- Add the { d1' = dx1; d2' = dx2 } usage stuff
           final_uds = foldr addDictBind rhs_uds (my_zipEqual "spec_call" rhs_dicts' call_ds)
@@ -1007,7 +1010,7 @@ mkCallUDs subst f args
        -- *don't* say what the value of the implicit param is!
   || not (spec_tys `lengthIs` n_tyvars)
   || not ( dicts   `lengthIs` n_dicts)
-  || maybeToBool (lookupRule (\act -> True) (substInScope subst) f args)
+  || maybeToBool (lookupRule (\act -> True) (substInScope subst) emptyRuleBase f args)
        -- There's already a rule covering this call.  A typical case
        -- is where there's an explicit user-provided rule.  Then
        -- we don't want to create a specialised version 
@@ -1138,20 +1141,20 @@ cloneBindSM :: Subst -> CoreBind -> SpecM (Subst, Subst, CoreBind)
 cloneBindSM subst (NonRec bndr rhs)
   = getUs      `thenUs` \ us ->
     let
-       (subst', bndr') = substAndCloneId subst us bndr
+       (subst', bndr') = cloneIdBndr subst us bndr
     in
     returnUs (subst, subst', NonRec bndr' rhs)
 
 cloneBindSM subst (Rec pairs)
   = getUs      `thenUs` \ us ->
     let
-       (subst', bndrs') = substAndCloneRecIds subst us (map fst pairs)
+       (subst', bndrs') = cloneRecIdBndrs subst us (map fst pairs)
     in
     returnUs (subst', subst', Rec (bndrs' `zip` map snd pairs))
 
 cloneBinders subst bndrs
   = getUs      `thenUs` \ us ->
-    returnUs (substAndCloneIds subst us bndrs)
+    returnUs (cloneIdBndrs subst us bndrs)
 
 newIdSM old_id new_ty
   = getUniqSM          `thenSM` \ uniq ->