STM invariants
[ghc-hetmet.git] / rts / Exception.cmm
index 0c1b664..1104706 100644 (file)
@@ -344,12 +344,25 @@ retry_pop_stack:
     if (frame_type == ATOMICALLY_FRAME) {
       /* The exception has reached the edge of a memory transaction.  Check that 
        * the transaction is valid.  If not then perhaps the exception should
-       * not have been thrown: re-run the transaction */
-      W_ trec;
+       * not have been thrown: re-run the transaction.  "trec" will either be
+       * a top-level transaction running the atomic block, or a nested 
+       * transaction running an invariant check.  In the latter case we
+       * abort and de-allocate the top-level transaction that encloses it
+       * as well (we could just abandon its transaction record, but this makes
+       * sure it's marked as aborted and available for re-use). */
+      W_ trec, outer;
       W_ r;
       trec = StgTSO_trec(CurrentTSO);
       r = foreign "C" stmValidateNestOfTransactions(trec "ptr");
+      "ptr" outer = foreign "C" stmGetEnclosingTRec(trec "ptr") [];
       foreign "C" stmAbortTransaction(MyCapability() "ptr", trec "ptr");
+      foreign "C" stmFreeAbortedTRec(MyCapability() "ptr", trec "ptr");
+
+      if (outer != NO_TREC) {
+        foreign "C" stmAbortTransaction(MyCapability() "ptr", outer "ptr");
+        foreign "C" stmFreeAbortedTRec(MyCapability() "ptr", outer "ptr");
+      }
+
       StgTSO_trec(CurrentTSO) = NO_TREC;
       if (r != 0) {
         // Transaction was valid: continue searching for a catch frame
@@ -400,6 +413,9 @@ retry_pop_stack:
      * If exceptions were unblocked, arrange that they are unblocked
      * again after executing the handler by pushing an
      * unblockAsyncExceptions_ret stack frame.
+     *
+     * If we've reached an STM catch frame then roll back the nested
+     * transaction we were using.
      */
     W_ frame;
     frame = Sp;
@@ -410,6 +426,12 @@ retry_pop_stack:
         Sp(0) = stg_unblockAsyncExceptionszh_ret_info;
       }
     } else {
+      W_ trec, outer;
+      trec = StgTSO_trec(CurrentTSO);
+      "ptr" outer = foreign "C" stmGetEnclosingTRec(trec "ptr") [];
+      foreign "C" stmAbortTransaction(MyCapability() "ptr", trec "ptr") [];
+      foreign "C" stmFreeAbortedTRec(MyCapability() "ptr", trec "ptr") [];
+      StgTSO_trec(CurrentTSO) = outer;
       Sp = Sp + SIZEOF_StgCatchSTMFrame;
     }