better comments/documentation
[coinductive-monad.git] / src / Computation / Termination.v
diff --git a/src/Computation/Termination.v b/src/Computation/Termination.v
new file mode 100644 (file)
index 0000000..1497fbf
--- /dev/null
@@ -0,0 +1,190 @@
+Require Import Computation.Monad.
+Require Import Coq.Logic.JMeq.
+
+Section Termination.
+
+  (** An inductive predicate proving that a given computation terminates with a particular value *)
+  Reserved Notation "c ! y" (at level 30).
+  Inductive TerminatesWith (A:Set) : #A -> A -> Prop :=
+  | TerminatesReturnWith :
+    forall (a:A),
+      (Return a)!a
+      
+  | TerminatesBindWith :
+    forall (B:Set) (b:B) (a:A) (f:B->#A) (c:#B),
+      c!b
+      -> (f b)!a
+      -> (c >>= f)!a
+     where "c ! y" := (TerminatesWith _ c y)
+      .
+  
+  (** An inductive predicate proving that a given computation terminates with /some/ value *)  
+  Reserved Notation "c !?" (at level 30).
+  Inductive Terminates (A:Set)(c:#A) : Prop :=
+    | Terminates_intro : (exists a:A, c!a) -> c!?
+      where "c !?" := (Terminates _ c).
+
+  (** A predicate indicating which computations might be subcomputations of a given computation *)
+  Inductive InvokedBy (A B C:Set) : #A -> #B -> #C -> Prop :=
+  | invokesPrev : forall
+    (z:#C)
+    (c:#B)
+    (f:B->#A),
+    InvokedBy A B C (@Bind A B f c) c z
+  | invokesFunc : forall
+    (c:#C)
+    (b:C)
+    (f:C->#A)
+    (pf:#A->#B)
+    (eqpf:A=B)
+    (_:JMeq (pf (f b)) (f b))
+    (_:c!b),
+    InvokedBy A B C (@Bind A C f c) (pf (f b)) c.
+
+  (** A predicate asserting that it is safe to evaluate a computation (this is the single-constructor Prop type) *)
+  Inductive Safe : forall (A:Set) (c:#A), Prop :=
+    Safe_intro :
+    forall (A:Set) (c:#A),
+      (forall (B C:Set) (c':#B)(z:#C), InvokedBy A B C c c' z -> Safe B c')
+      -> Safe A c.
+
+  (** Inversion principle for Safe *)
+  Definition Safe_inv
+    : forall (A B C:Set)(c:#A)(_:Safe A c)(c':#B)(z:#C)(_:InvokedBy A B C c c' z), Safe B c'.
+    destruct 1.
+    apply (H B).
+  Defined.
+
+  Notation "{ c }!" := {a:_|TerminatesWith _ c a} (at level 5).
+  Notation "'!Let' x := y 'in' z" := ((fun x => z)y)(at level 100).
+  Definition eval' CC cc (Z:Set) (z:#Z) (s:Safe CC cc) : {cc}!.
+    refine(
+      !Let eval_one_step :=
+        fun C c Z z =>
+          match c return (forall PRED pred Z z, InvokedBy C PRED Z c pred z -> {pred}!) -> {c}! with
+            | Return x => fun _ => exist _ x _
+            | Bind CN f cn =>
+              fun eval_pred =>
+                match eval_pred CN cn Z z (invokesPrev C CN Z z cn f) with
+                  | exist b pf =>
+                    match eval_pred C (f b) CN cn _ with
+                      | exist a' pf' => exist _ a' _
+                    end
+                end
+          end
+          in
+          fix eval_all C c Z z (s:Safe C c) {struct s} : {c}! :=
+          eval_one_step C c Z z (fun C' c' Z z icc => eval_all C' c' Z z (Safe_inv C C' Z c s c' z icc))
+    ).
+    
+    constructor.
+
+    refine (invokesFunc C C CN cn b f (fun x:#C=>x) _ _ _).
+    auto.
+    auto.
+    assumption.
+    
+    apply (TerminatesBindWith C CN b a' f cn).
+    assumption.
+    assumption.
+  Defined.
+
+  (** A lemma to help apply JMeq *)
+  Theorem jmeq_lemma : forall (A1 A2 B:Set)(c1:#A1)(c2:#A2)(f1:A1->#B)(f2:A2->#B),
+    ((c1>>=f1)=(c2>>=f2))
+    -> (JMeq c1 c2) /\ (JMeq f1 f2) /\ (A1=A2).
+    intros.
+    inversion H.
+    split.
+    dependent rewrite H3.
+    simpl.
+    auto.
+    split.
+    dependent rewrite H2.
+    simpl.
+    auto.
+    auto.
+  Qed.
+
+  (** If we can prove that a given computation terminates with two different values, they must be the same *)
+  Lemma computation_is_deterministic :
+    forall (A:Set) (c:#A) (x y:A), c!x -> c!y -> x=y.
+    intros.
+    generalize H0.
+    clear H0.
+    induction H.
+    intros.
+    inversion H0.
+    auto.
+
+    intros.
+    simple inversion H1.
+    inversion H2.
+    
+    apply jmeq_lemma in H4.
+    destruct H4.
+    destruct H3.
+    subst.
+    apply JMeq_eq in H2.
+    apply JMeq_eq in H3.
+    subst.
+    intros.
+    apply IHTerminatesWith2.
+    apply IHTerminatesWith1 in H2.
+    subst.
+    auto.
+  Qed.
+
+  (** Any terminating computation is Safe *)
+  Theorem termination_is_safe : forall (A:Set) (c:#A), c!? -> Safe A c.
+    intros.
+    destruct H.
+    destruct H.
+    induction H.
+    apply Safe_intro.
+    intros.
+    inversion H.
+
+    apply Safe_intro.
+    intros.
+    simple inversion H1.
+    apply jmeq_lemma in H2.
+    destruct H2.
+    destruct H5.
+    subst.
+    apply JMeq_eq in H2.
+    subst.
+    auto.
+
+    intros.
+    apply jmeq_lemma in H4.
+    destruct H4.
+    destruct H7.
+    rewrite <- H5.
+    rewrite <- H5 in H1.
+    clear H5.
+    generalize H2.
+    clear H2.
+    subst.
+    apply JMeq_eq in H4.
+    subst.
+
+    assert (b=b0).
+    apply (computation_is_deterministic B c b b0 H H3).
+    subst.
+    apply JMeq_eq in H7.
+    subst.
+
+    intros.
+    apply JMeq_eq in H2.
+    rewrite H2.
+    apply IHTerminatesWith2.
+  Defined.
+
+End Termination.
+
+Implicit Arguments Terminates [A].
+Implicit Arguments TerminatesReturnWith [A].
+Implicit Arguments TerminatesBindWith [A].
+Implicit Arguments eval' [CC].
+