Extend eta reduction to work with casted arguments
authorsimonpj@microsoft.com <unknown>
Wed, 15 Sep 2010 22:12:29 +0000 (22:12 +0000)
committersimonpj@microsoft.com <unknown>
Wed, 15 Sep 2010 22:12:29 +0000 (22:12 +0000)
See Trac #4201, and
Note [Eta reduction with casted arguments]

Thanks to Louis Wasserman for suggesting this, and
implementing an early version of the patch

compiler/coreSyn/CoreUtils.lhs

index 103b294..8284702 100644 (file)
@@ -1231,18 +1231,55 @@ There are some particularly delicate points here:
 These delicacies are why we don't use exprIsTrivial and exprIsHNF here.
 Alas.
 
+Note [Eta reduction with casted arguments]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Consider  
+    (\(x:t3). f (x |> g)) :: t3 -> t2
+  where
+    f :: t1 -> t2
+    g :: t3 ~ t1
+This should be eta-reduced to
+
+    f |> (sym g -> t2)
+
+So we need to accumulate a coercion, pushing it inward (past
+variable arguments only) thus:
+   f (x |> co_arg) |> co  -->  (f |> (sym co_arg -> co)) x
+   f (x:t)         |> co  -->  (f |> (t -> co)) x
+   f @ a           |> co  -->  (f |> (forall a.co)) @ a
+   f @ (g:t1~t2)   |> co  -->  (f |> (t1~t2 => co)) @ (g:t1~t2)
+These are the equations for ok_arg.
+
+It's true that we could also hope to eta reduce these:
+    (\xy. (f x |> g) y)
+    (\xy. (f x y) |> g)
+But the simplifier pushes those casts outwards, so we don't
+need to address that here.
+
 \begin{code}
 tryEtaReduce :: [Var] -> CoreExpr -> Maybe CoreExpr
 tryEtaReduce bndrs body 
-  = go (reverse bndrs) body
+  = go (reverse bndrs) body (IdCo (exprType body))
   where
     incoming_arity = count isId bndrs
 
-    go (b : bs) (App fun arg) | ok_arg b arg = go bs fun       -- Loop round
-    go []       fun           | ok_fun fun   = Just fun                -- Success!
-    go _        _                           = Nothing          -- Failure!
+    go :: [Var]                   -- Binders, innermost first, types [a3,a2,a1]
+       -> CoreExpr         -- Of type tr
+       -> CoercionI        -- Of type tr ~ ts
+       -> Maybe CoreExpr   -- Of type a1 -> a2 -> a3 -> ts
+    -- See Note [Eta reduction with casted arguments]
+    -- for why we have an accumulating coercion
+    go [] fun co
+      | ok_fun fun = Just (mkCoerceI co fun)
+
+    go (b : bs) (App fun arg) co
+      | Just co' <- ok_arg b arg co
+      = go bs fun co'
 
-       -- Note [Eta reduction conditions]
+    go _ _ _  = Nothing                -- Failure!
+
+    ---------------
+    -- Note [Eta reduction conditions]
     ok_fun (App fun (Type ty)) 
        | not (any (`elemVarSet` tyVarsOfType ty) bndrs)
        =  ok_fun fun
@@ -1251,17 +1288,37 @@ tryEtaReduce bndrs body
        && (ok_fun_id fun_id || all ok_lam bndrs)
     ok_fun _fun = False
 
+    ---------------
     ok_fun_id fun = fun_arity fun >= incoming_arity
 
+    ---------------
     fun_arity fun            -- See Note [Arity care]
        | isLocalId fun && isLoopBreaker (idOccInfo fun) = 0
        | otherwise = idArity fun             
 
+    ---------------
     ok_lam v = isTyCoVar v || isDictId v
 
-    ok_arg b arg = varToCoreExpr b `cheapEqExpr` arg
+    ---------------
+    ok_arg :: Var              -- Of type bndr_t
+           -> CoreExpr          -- Of type arg_t
+           -> CoercionI         -- Of kind (t1~t2)
+           -> Maybe CoercionI   -- Of type (arg_t -> t1 ~  bndr_t -> t2)
+                               --   (and similarly for tyvars, coercion args)
+    -- See Note [Eta reduction with casted arguments]
+    ok_arg bndr (Type ty) co
+       | Just tv <- getTyVar_maybe ty
+       , bndr == tv  = Just (mkForAllTyCoI tv co)
+    ok_arg bndr (Var v) co
+       | bndr == v   = Just (mkFunTyCoI (IdCo (idType bndr)) co)
+    ok_arg bndr (Cast (Var v) co_arg) co
+       | bndr == v  = Just (mkFunTyCoI (ACo (mkSymCoercion co_arg)) co)
+       -- The simplifier combines multiple casts into one, 
+       -- so we can have a simple-minded pattern match here
+    ok_arg _ _ _ = Nothing
 \end{code}
 
+
 %************************************************************************
 %*                                                                     *
 \subsection{Determining non-updatable right-hand-sides}