diff --git a/py/vm.c b/py/vm.c
index b18a2b5e32a615155df474134a5425a7480f8d26..c629a61e2f686d1dc4ca800d606c04bdb5836930 100644
--- a/py/vm.c
+++ b/py/vm.c
@@ -48,14 +48,6 @@
 // top element.
 // Exception stack also grows up, top element is also pointed at.
 
-// Exception stack unwind reasons (WHY_* in CPython-speak)
-// TODO perhaps compress this to RETURN=0, JUMP>0, with number of unwinds
-// left to do encoded in the JUMP number
-typedef enum {
-    UNWIND_RETURN = 1,
-    UNWIND_JUMP,
-} mp_unwind_reason_t;
-
 #define DECODE_UINT \
     mp_uint_t unum = 0; \
     do { \
@@ -613,29 +605,18 @@ dispatch_loop:
                         mp_call_method_n_kw(3, 0, sp);
                         SET_TOP(mp_const_none);
                     } else if (MP_OBJ_IS_SMALL_INT(TOP())) {
-                        mp_int_t cause_val = MP_OBJ_SMALL_INT_VALUE(TOP());
-                        if (cause_val == UNWIND_RETURN) {
-                            // stack: (..., __exit__, ctx_mgr, ret_val, UNWIND_RETURN)
-                            mp_obj_t ret_val = sp[-1];
-                            sp[-1] = mp_const_none;
-                            sp[0] = mp_const_none;
-                            sp[1] = mp_const_none;
-                            mp_call_method_n_kw(3, 0, sp - 3);
-                            sp[-3] = ret_val;
-                            sp[-2] = MP_OBJ_NEW_SMALL_INT(UNWIND_RETURN);
-                        } else {
-                            assert(cause_val == UNWIND_JUMP);
-                            // stack: (..., __exit__, ctx_mgr, dest_ip, num_exc, UNWIND_JUMP)
-                            mp_obj_t dest_ip = sp[-2];
-                            mp_obj_t num_exc = sp[-1];
-                            sp[-2] = mp_const_none;
-                            sp[-1] = mp_const_none;
-                            sp[0] = mp_const_none;
-                            mp_call_method_n_kw(3, 0, sp - 4);
-                            sp[-4] = dest_ip;
-                            sp[-3] = num_exc;
-                            sp[-2] = MP_OBJ_NEW_SMALL_INT(UNWIND_JUMP);
-                        }
+                        // Getting here there are two distinct cases:
+                        //  - unwind return, stack: (..., __exit__, ctx_mgr, ret_val, SMALL_INT(-1))
+                        //  - unwind jump, stack:   (..., __exit__, ctx_mgr, dest_ip, SMALL_INT(num_exc))
+                        // For both cases we do exactly the same thing.
+                        mp_obj_t data = sp[-1];
+                        mp_obj_t cause = sp[0];
+                        sp[-1] = mp_const_none;
+                        sp[0] = mp_const_none;
+                        sp[1] = mp_const_none;
+                        mp_call_method_n_kw(3, 0, sp - 3);
+                        sp[-3] = data;
+                        sp[-2] = cause;
                         sp -= 2; // we removed (__exit__, ctx_mgr)
                     } else {
                         assert(mp_obj_is_exception_instance(TOP()));
@@ -680,10 +661,11 @@ unwind_jump:;
                             // of a "with" block contains the context manager info.
                             // We're going to run "finally" code as a coroutine
                             // (not calling it recursively). Set up a sentinel
-                            // on a stack so it can return back to us when it is
+                            // on the stack so it can return back to us when it is
                             // done (when WITH_CLEANUP or END_FINALLY reached).
-                            PUSH((mp_obj_t)unum); // push number of exception handlers left to unwind
-                            PUSH(MP_OBJ_NEW_SMALL_INT(UNWIND_JUMP)); // push sentinel
+                            // The sentinel is the number of exception handlers left to
+                            // unwind, which is a non-negative integer.
+                            PUSH(MP_OBJ_NEW_SMALL_INT(unum));
                             ip = exc_sp->handler; // get exception handler byte code address
                             exc_sp--; // pop exception handler
                             goto dispatch_loop; // run the exception handler
@@ -720,11 +702,14 @@ unwind_jump:;
                     } else if (MP_OBJ_IS_SMALL_INT(TOP())) {
                         // We finished "finally" coroutine and now dispatch back
                         // to our caller, based on TOS value
-                        mp_unwind_reason_t reason = MP_OBJ_SMALL_INT_VALUE(POP());
-                        if (reason == UNWIND_RETURN) {
+                        mp_int_t cause = MP_OBJ_SMALL_INT_VALUE(POP());
+                        if (cause < 0) {
+                            // A negative cause indicates unwind return
                             goto unwind_return;
                         } else {
-                            assert(reason == UNWIND_JUMP);
+                            // Otherwise it's an unwind jump and we must push as a raw
+                            // number the number of exception handlers to unwind
+                            PUSH((mp_obj_t)cause);
                             goto unwind_jump;
                         }
                     } else {
@@ -1101,7 +1086,7 @@ unwind_return:
                             // (not calling it recursively). Set up a sentinel
                             // on a stack so it can return back to us when it is
                             // done (when WITH_CLEANUP or END_FINALLY reached).
-                            PUSH(MP_OBJ_NEW_SMALL_INT(UNWIND_RETURN));
+                            PUSH(MP_OBJ_NEW_SMALL_INT(-1));
                             ip = exc_sp->handler;
                             exc_sp--;
                             goto dispatch_loop;