diff --git a/py/emitbc.c b/py/emitbc.c
index 8d482afe52678b1cead2b4076e0b1046060e2539..b9304fd82d5ddc074845ed1ab6106c003d602c82 100644
--- a/py/emitbc.c
+++ b/py/emitbc.c
@@ -680,12 +680,15 @@ void mp_emit_bc_unwind_jump(emit_t *emit, mp_uint_t label, mp_uint_t except_dept
 }
 
 void mp_emit_bc_setup_with(emit_t *emit, mp_uint_t label) {
-    emit_bc_pre(emit, 7);
+    // TODO: the stack space reserved here can probably be reduced, since the
+    // 4 slots are not needed during the entire with block, only in the
+    // cleanup handler in certain cases.  It needs some thinking.
+    emit_bc_pre(emit, 4);
     emit_write_bytecode_byte_unsigned_label(emit, MP_BC_SETUP_WITH, label);
 }
 
 void mp_emit_bc_with_cleanup(emit_t *emit) {
-    emit_bc_pre(emit, -7);
+    emit_bc_pre(emit, -4); // balances the +4 adjustment made in mp_emit_bc_setup_with
     emit_write_bytecode_byte(emit, MP_BC_WITH_CLEANUP);
 }
 
diff --git a/py/vm.c b/py/vm.c
index 55203b0748685c53113e0dcd62b564a7ba58d6b4..1af3636f6bff489d267c66601ad805811720139d 100644
--- a/py/vm.c
+++ b/py/vm.c
@@ -548,68 +548,82 @@ dispatch_loop:
 
                 ENTRY(MP_BC_SETUP_WITH): {
                     MARK_EXC_IP_SELECTIVE();
+                    // stack: (..., ctx_mgr)
                     mp_obj_t obj = TOP();
-                    SET_TOP(mp_load_attr(obj, MP_QSTR___exit__));
-                    mp_load_method(obj, MP_QSTR___enter__, sp + 1);
-                    mp_obj_t ret = mp_call_method_n_kw(0, 0, sp + 1);
+                    mp_load_method(obj, MP_QSTR___exit__, sp); // (__exit__, ctx_mgr) at sp[0..1], per the stack comment below
+                    mp_load_method(obj, MP_QSTR___enter__, sp + 2); // __enter__ goes in the slots above; needed only for the call
+                    mp_obj_t ret = mp_call_method_n_kw(0, 0, sp + 2);
+                    sp += 1; // keep the (__exit__, ctx_mgr) pair on the stack; the __enter__ slots are dropped
                     PUSH_EXC_BLOCK(1);
                     PUSH(ret);
+                    // stack: (..., __exit__, ctx_mgr, as_value)
                     DISPATCH();
                 }
 
                 ENTRY(MP_BC_WITH_CLEANUP): {
                     MARK_EXC_IP_SELECTIVE();
                     // Arriving here, there's "exception control block" on top of stack,
-                    // and __exit__ bound method underneath it. Bytecode calls __exit__,
+                    // and __exit__ method (with self) underneath it. Bytecode calls __exit__,
                     // and "deletes" it off stack, shifting "exception control block"
                     // to its place.
-                    static const mp_obj_t no_exc[] = {mp_const_none, mp_const_none, mp_const_none};
                     if (TOP() == mp_const_none) {
-                        sp--;
-                        mp_obj_t obj = TOP();
+                        // stack: (..., __exit__, ctx_mgr, None)
+                        sp[1] = mp_const_none;
+                        sp[2] = mp_const_none;
+                        sp -= 2;
+                        mp_call_method_n_kw(3, 0, sp);
                         SET_TOP(mp_const_none);
-                        mp_call_function_n_kw(obj, 3, 0, no_exc);
                     } else if (MP_OBJ_IS_SMALL_INT(TOP())) {
                         mp_int_t cause_val = MP_OBJ_SMALL_INT_VALUE(TOP());
                         if (cause_val == UNWIND_RETURN) {
-                            mp_call_function_n_kw(sp[-2], 3, 0, no_exc);
+                            // stack: (..., __exit__, ctx_mgr, ret_val, UNWIND_RETURN)
+                            mp_obj_t ret_val = sp[-1];
+                            sp[-1] = mp_const_none;
+                            sp[0] = mp_const_none;
+                            sp[1] = mp_const_none;
+                            mp_call_method_n_kw(3, 0, sp - 3);
+                            sp[-3] = ret_val;
+                            sp[-2] = MP_OBJ_NEW_SMALL_INT(UNWIND_RETURN);
                         } else {
                             assert(cause_val == UNWIND_JUMP);
-                            mp_call_function_n_kw(sp[-3], 3, 0, no_exc);
-                            // Pop __exit__ boundmethod at sp[-3]
-                            sp[-3] = sp[-2];
+                            // stack: (..., __exit__, ctx_mgr, dest_ip, num_exc, UNWIND_JUMP)
+                            mp_obj_t dest_ip = sp[-2];
+                            mp_obj_t num_exc = sp[-1];
+                            sp[-2] = mp_const_none;
+                            sp[-1] = mp_const_none;
+                            sp[0] = mp_const_none;
+                            mp_call_method_n_kw(3, 0, sp - 4);
+                            sp[-4] = dest_ip;
+                            sp[-3] = num_exc;
+                            sp[-2] = MP_OBJ_NEW_SMALL_INT(UNWIND_JUMP);
                         }
-                        sp[-2] = sp[-1]; // copy retval down
-                        sp[-1] = sp[0]; // copy cause down
-                        sp--; // discard top value (was cause)
+                        sp -= 2; // we removed (__exit__, ctx_mgr)
                     } else {
                         assert(mp_obj_is_exception_type(TOP()));
+                        // stack: (..., __exit__, ctx_mgr, traceback, exc_val, exc_type)
                         // Need to pass (sp[0], sp[-1], sp[-2]) as arguments so must reverse the
                         // order of these on the value stack (don't want to create a temporary
                         // array because it increases stack footprint of the VM).
                         mp_obj_t obj = sp[-2];
                         sp[-2] = sp[0];
                         sp[0] = obj;
-                        mp_obj_t ret_value = mp_call_function_n_kw(sp[-3], 3, 0, &sp[-2]);
+                        mp_obj_t ret_value = mp_call_method_n_kw(3, 0, sp - 4);
                         if (mp_obj_is_true(ret_value)) {
-                            // This is what CPython does
-                            //PUSH(MP_OBJ_NEW_SMALL_INT(UNWIND_SILENCED));
-                            // But what we need to do is - pop exception from value stack...
+                            // We need to silence/swallow the exception.  This is done
+                            // by popping the exception values and the __exit__ method
+                            // (with its self) and leaving a single None in their place,
+                            // which signals END_FINALLY to execute the finally handler normally.
                             sp -= 4;
-                            // ... pop "with" exception handler, and signal END_FINALLY
-                            // to just execute finally handler normally (by pushing None
-                            // on value stack)
+                            SET_TOP(mp_const_none);
                             assert(exc_sp >= exc_stack);
                             POP_EXC_BLOCK();
-                            PUSH(mp_const_none);
                         } else {
-                            // Pop __exit__ boundmethod at sp[-3], remembering that top 3 values
-                            // are reversed.
-                            sp[-3] = sp[0];
-                            obj = sp[-2];
-                            sp[-2] = sp[-1];
-                            sp[-1] = obj;
-                            sp--;
+                            // We need to re-raise the exception.  We pop the __exit__ method
+                            // and its self by copying the exception values down over them,
+                            // which also undoes the argument-order reversal done above.
+                            sp[-4] = sp[0];
+                            sp[-3] = sp[-1];
+                            sp -= 2;
                         }
                     }
                     DISPATCH();
diff --git a/tests/cmdline/cmd_showbc.py.exp b/tests/cmdline/cmd_showbc.py.exp
index a7088894e711bf05e3a4cfed4f847d1c1ef43d84..48bf628affb07bcc39651e696e3439351f85bd0e 100644
--- a/tests/cmdline/cmd_showbc.py.exp
+++ b/tests/cmdline/cmd_showbc.py.exp
@@ -29,7 +29,7 @@ Raw bytecode (code_info_size=\\d\+, bytecode_size=\\d\+):
 ########
 \.\+5b
 arg names:
-(N_STATE 25)
+(N_STATE 22)
 (N_EXC_STACK 2)
 (INIT_CELL 14)
 (INIT_CELL 15)