HaskFlattener: simplify the flattener
[coq-hetmet.git] / src / HaskFlattener.v
index 1f74250..5ebd6ba 100644 (file)
@@ -43,6 +43,9 @@ Set Printing Width 130.
  *)
 Section HaskFlattener.
 
+  Definition getlev {Γ}{κ}(lht:LeveledHaskType Γ κ) : HaskLevel Γ :=
+    match lht with t @@ l => l end.
+
   Definition arrange :
     forall {T} (Σ:Tree ??T) (f:T -> bool),
       Arrange Σ (dropT (mkFlags (liftBoolFunc false f) Σ),,( (dropT (mkFlags (liftBoolFunc false (bnot ○ f)) Σ)))).
@@ -195,8 +198,11 @@ Section HaskFlattener.
   Context {prodTy : forall TV, RawHaskType TV ECKind  -> RawHaskType TV ★  -> RawHaskType TV ★ -> RawHaskType TV ★ }.
   Context {gaTy   : forall TV, RawHaskType TV ECKind  -> RawHaskType TV ★ -> RawHaskType TV ★  -> RawHaskType TV ★ }.
 
+  Definition ga_mk_tree' {TV}(ec:RawHaskType TV ECKind)(tr:Tree ??(RawHaskType TV ★)) : RawHaskType TV ★ :=
+    reduceTree (unitTy TV ec) (prodTy TV ec) tr.
+
   Definition ga_mk_tree {Γ}(ec:HaskType Γ ECKind)(tr:Tree ??(HaskType Γ ★)) : HaskType Γ ★ :=
-    fun TV ite => reduceTree (unitTy TV (ec TV ite)) (prodTy TV (ec TV ite)) (mapOptionTree (fun x => x TV ite) tr).
+    fun TV ite => ga_mk_tree' (ec TV ite) (mapOptionTree (fun x => x TV ite) tr).
 
   Definition ga_mk {Γ}(ec:HaskType Γ ECKind )(ant suc:Tree ??(HaskType Γ ★)) : HaskType Γ ★ :=
     fun TV ite => gaTy TV (ec TV ite) (ga_mk_tree ec ant TV ite) (ga_mk_tree ec suc TV ite).
@@ -215,7 +221,11 @@ Section HaskFlattener.
     | TCon       tc         => TCon      tc
     | TCoerc _ t1 t2 t      => TCoerc    (flatten_rawtype t1) (flatten_rawtype t2) (flatten_rawtype t)
     | TArrow                => TArrow
-    | TCode      v e        => gaTy TV v (unitTy TV v) (flatten_rawtype e)
+    | TCode     ec e        => let     e'   := flatten_rawtype e
+                               (* in let args := take_arg_types e'
+                                in let t    := drop_arg_types e'
+                                in     gaTy TV ec (ga_mk_tree' ec (unleaves args)) t*)
+                                in gaTy TV ec (unitTy TV ec) e'
     | TyFunApp  tfc kl k lt => TyFunApp tfc kl k (flatten_rawtype_list _ lt)
     end
     with flatten_rawtype_list {TV}(lk:list Kind)(exp:@RawHaskTypeList TV lk) : @RawHaskTypeList TV lk :=
@@ -225,20 +235,16 @@ Section HaskFlattener.
     end.
 
   Definition flatten_type {Γ}{κ}(ht:HaskType Γ κ) : HaskType Γ κ :=
-    fun TV ite =>
-      flatten_rawtype (ht TV ite).
+    fun TV ite => flatten_rawtype (ht TV ite).
 
-  Fixpoint flatten_leveled_type' {Γ}(ht:HaskType Γ ★)(lev:HaskLevel Γ) : HaskType Γ ★ :=
+  Fixpoint levels_to_tcode {Γ}(ht:HaskType Γ ★)(lev:HaskLevel Γ) : HaskType Γ ★ :=
     match lev with
-      | nil      => flatten_type ht
-      | ec::lev' => @ga_mk _ (v2t ec) [] [flatten_leveled_type' ht lev']
+      | nil      => ht
+      | ec::lev' => fun TV ite => TCode (v2t ec TV ite) (levels_to_tcode ht lev' TV ite)
     end.
 
   Definition flatten_leveled_type {Γ}(ht:LeveledHaskType Γ ★) : LeveledHaskType Γ ★ :=
-    match ht with
-      htt @@ lev =>
-      flatten_leveled_type' htt lev @@ nil
-    end.
+    flatten_type (levels_to_tcode (unlev ht) (getlev ht)) @@ nil.
 
   (* AXIOMS *)
 
@@ -638,6 +644,9 @@ Section HaskFlattener.
     inversion e; subst.
     simpl.
     apply nd_rule.
+    unfold flatten_leveled_type.
+    simpl.
+    unfold flatten_type.
     apply RVar.
     simpl.
     apply ga_id.
@@ -677,7 +686,6 @@ Section HaskFlattener.
       simpl.
       destruct (General.list_eq_dec h0 (ec :: nil)).
         simpl.
-        unfold flatten_leveled_type'.
         rewrite e.
         apply nd_id.
         simpl.
@@ -776,7 +784,7 @@ Section HaskFlattener.
       simpl.
       destruct lev.
       simpl.
-      change ([flatten_type (<[ ec |- t ]>) @@  nil])
+      change ([flatten_leveled_type (<[ ec |- t ]> @@  nil)])
         with ([ga_mk (v2t ec) [] [flatten_type  t] @@  nil]).
       eapply nd_comp; [ apply arrange_esc | idtac ].
       set (decide_tree_empty (take_lev (ec :: nil) succ)) as q'.
@@ -813,6 +821,8 @@ Section HaskFlattener.
     destruct case_RLit.
       simpl.
       destruct l0; simpl.
+        unfold flatten_leveled_type.
+        simpl.
         rewrite literal_types_unchanged.
           apply nd_rule; apply RLit.
         unfold take_lev; simpl.
@@ -854,7 +864,8 @@ Section HaskFlattener.
           apply nd_rule.
           apply q.
           apply INil.
-        apply nd_rule; rewrite globals_do_not_have_code_types.
+        unfold flatten_leveled_type. simpl.
+          apply nd_rule; rewrite globals_do_not_have_code_types.
           apply RGlobal.
       apply (Prelude_error "found RGlobal at depth >0; globals should never appear inside code brackets unless escaped").
 
@@ -890,7 +901,11 @@ Section HaskFlattener.
       simpl.
 
       destruct lev as [|ec lev]. simpl. apply nd_rule.
-        replace (flatten_type  (tx ---> te)) with ((flatten_type  tx) ---> (flatten_type  te)).
+        unfold flatten_leveled_type at 4.
+        unfold flatten_leveled_type at 2.
+        simpl.
+        replace (flatten_type (tx ---> te))
+          with (flatten_type tx ---> flatten_type te).
         apply RApp.
         reflexivity.
 
@@ -944,6 +959,8 @@ Section HaskFlattener.
         
     destruct case_RAppT.
       simpl. destruct lev; simpl.
+      unfold flatten_leveled_type.
+      simpl.
       rewrite flatten_commutes_with_HaskTAll.
       rewrite flatten_commutes_with_substT.
       apply nd_rule.
@@ -954,6 +971,9 @@ Section HaskFlattener.
 
     destruct case_RAbsT.
       simpl. destruct lev; simpl.
+      unfold flatten_leveled_type at 4.
+      unfold flatten_leveled_type at 2.
+      simpl.
       rewrite flatten_commutes_with_HaskTAll.
       rewrite flatten_commutes_with_HaskTApp.
       eapply nd_comp; [ idtac | eapply nd_rule; eapply RAbsT ].
@@ -982,6 +1002,8 @@ Section HaskFlattener.
 
     destruct case_RAppCo.
       simpl. destruct lev; simpl.
+      unfold flatten_leveled_type at 4.
+      unfold flatten_leveled_type at 2.
       unfold flatten_type.
       simpl.
       apply nd_rule.