diff --git a/py/compile.c b/py/compile.c
index 7daf91103516c89db63dcd1e2809992f564e4030..98c09b2107cafb60599a49a236464fdb7c149c70 100644
--- a/py/compile.c
+++ b/py/compile.c
@@ -1766,46 +1766,71 @@ STATIC void compile_async_with_stmt_helper(compiler_t *comp, int n, mp_parse_nod
         // no more pre-bits, compile the body of the with
         compile_node(comp, body);
     } else {
-        uint try_exception_label = comp_next_label(comp);
-        uint no_reraise_label = comp_next_label(comp);
-        uint try_else_label = comp_next_label(comp);
-        uint end_label = comp_next_label(comp);
-        qstr context;
+        uint l_finally_block = comp_next_label(comp);
+        uint l_aexit_no_exc = comp_next_label(comp);
+        uint l_ret_unwind_jump = comp_next_label(comp);
+        uint l_end = comp_next_label(comp);
 
         if (MP_PARSE_NODE_IS_STRUCT_KIND(nodes[0], PN_with_item)) {
             // this pre-bit is of the form "a as b"
             mp_parse_node_struct_t *pns = (mp_parse_node_struct_t*)nodes[0];
             compile_node(comp, pns->nodes[0]);
-            context = MP_PARSE_NODE_LEAF_ARG(pns->nodes[0]);
-            compile_store_id(comp, context);
-            compile_load_id(comp, context);
+            EMIT(dup_top);
             compile_await_object_method(comp, MP_QSTR___aenter__);
             c_assign(comp, pns->nodes[1], ASSIGN_STORE);
         } else {
             // this pre-bit is just an expression
             compile_node(comp, nodes[0]);
-            context = MP_PARSE_NODE_LEAF_ARG(nodes[0]);
-            compile_store_id(comp, context);
-            compile_load_id(comp, context);
+            EMIT(dup_top);
             compile_await_object_method(comp, MP_QSTR___aenter__);
             EMIT(pop_top);
         }
 
-        compile_load_id(comp, context);
-        EMIT_ARG(load_method, MP_QSTR___aexit__, false);
+        // To keep the Python stack size down, and because we can't access values on
+        // this stack further down than 3 elements (via rot_three), we don't preload
+        // __aexit__ (as a normal "with" does) but rather wait until we need it below.
 
-        EMIT_ARG(setup_block, try_exception_label, MP_EMIT_SETUP_BLOCK_EXCEPT);
+        // Start the try-finally statement
+        EMIT_ARG(setup_block, l_finally_block, MP_EMIT_SETUP_BLOCK_FINALLY);
         compile_increase_except_level(comp);
-        // compile additional pre-bits and the body
+
+        // Compile any additional pre-bits of the "async with", and also the body
+        EMIT_ARG(adjust_stack_size, 3); // stack adjust for possible UNWIND_JUMP state
         compile_async_with_stmt_helper(comp, n - 1, nodes + 1, body);
-        // finish this with block
+        EMIT_ARG(adjust_stack_size, -3);
+
+        // Finish the "try" block
         EMIT(pop_block);
-        EMIT_ARG(jump, try_else_label); // jump over exception handler
 
-        EMIT_ARG(label_assign, try_exception_label); // start of exception handler
-        EMIT(start_except_handler);
+        // At this point, after the with body has executed, we have 3 cases:
+        // 1. no exception, we just fall through to this point; stack: (..., ctx_mgr)
+        // 2. exception propagating out, we get to the finally block; stack: (..., ctx_mgr, exc)
+        // 3. return or unwind jump, we get to the finally block; stack: (..., ctx_mgr, X, INT)
+
+        // Handle case 1: call __aexit__
+        // Stack: (..., ctx_mgr)
+        EMIT_ARG(load_const_tok, MP_TOKEN_KW_NONE); // to tell end_finally there's no exception
+        EMIT(rot_two);
+        EMIT_ARG(jump, l_aexit_no_exc); // jump to code below to call __aexit__
+
+        // Start of "finally" block
+        // At this point we have case 2 or 3; we detect which one by checking whether TOS is an exception
+        EMIT_ARG(label_assign, l_finally_block);
 
-        // at this point the stack contains: ..., __aexit__, self, exc
+        // Detect if TOS is an exception or not
+        EMIT(dup_top);
+        EMIT_LOAD_GLOBAL(MP_QSTR_Exception);
+        EMIT_ARG(binary_op, MP_BINARY_OP_EXCEPTION_MATCH);
+        EMIT_ARG(pop_jump_if, false, l_ret_unwind_jump); // if not an exception then we have case 3
+
+        // Handle case 2: call __aexit__ and either swallow or re-raise the exception
+        // Stack: (..., ctx_mgr, exc)
+        EMIT(dup_top);
+        EMIT(rot_three);
+        EMIT(rot_two);
+        EMIT_ARG(load_method, MP_QSTR___aexit__, false);
+        EMIT(rot_three);
+        EMIT(rot_three);
         EMIT(dup_top);
         #if MICROPY_CPYTHON_COMPAT
         EMIT_ARG(attr, MP_QSTR___class__, MP_EMIT_ATTR_LOAD); // get type(exc)
@@ -1816,32 +1841,38 @@ STATIC void compile_async_with_stmt_helper(compiler_t *comp, int n, mp_parse_nod
         #endif
         EMIT(rot_two);
         EMIT_ARG(load_const_tok, MP_TOKEN_KW_NONE); // dummy traceback value
-        // at this point the stack contains: ..., __aexit__, self, type(exc), exc, None
+        // Stack: (..., exc, __aexit__, ctx_mgr, type(exc), exc, None)
         EMIT_ARG(call_method, 3, 0, 0);
-
         compile_yield_from(comp);
-        EMIT_ARG(pop_jump_if, true, no_reraise_label);
-        EMIT_ARG(raise_varargs, 0);
-
-        EMIT_ARG(label_assign, no_reraise_label);
-        EMIT(pop_except);
-        EMIT_ARG(jump, end_label);
-
-        EMIT_ARG(adjust_stack_size, 3); // adjust for __aexit__, self, exc
-        compile_decrease_except_level(comp);
-        EMIT(end_finally);
-        EMIT(end_except_handler);
-
-        EMIT_ARG(label_assign, try_else_label); // start of try-else handler
+        EMIT_ARG(pop_jump_if, false, l_end);
+        EMIT(pop_top); // pop exception
+        EMIT_ARG(load_const_tok, MP_TOKEN_KW_NONE); // replace with None to swallow exception
+        EMIT_ARG(jump, l_end);
+        EMIT_ARG(adjust_stack_size, 2);
+
+        // Handle case 3: call __aexit__
+        // Stack: (..., ctx_mgr, X, INT)
+        EMIT_ARG(label_assign, l_ret_unwind_jump);
+        EMIT(rot_three);
+        EMIT(rot_three);
+        EMIT_ARG(label_assign, l_aexit_no_exc);
+        EMIT_ARG(load_method, MP_QSTR___aexit__, false);
         EMIT_ARG(load_const_tok, MP_TOKEN_KW_NONE);
         EMIT(dup_top);
         EMIT(dup_top);
         EMIT_ARG(call_method, 3, 0, 0);
         compile_yield_from(comp);
         EMIT(pop_top);
+        EMIT_ARG(adjust_stack_size, -1);
 
-        EMIT_ARG(label_assign, end_label);
-
+        // End of "finally" block
+        // Stack can have one of three configurations:
+        // a. (..., None) - from either case 1, or case 2 with swallowed exception
+        // b. (..., exc) - from case 2 with re-raised exception
+        // c. (..., X, INT) - from case 3
+        EMIT_ARG(label_assign, l_end);
+        compile_decrease_except_level(comp);
+        EMIT(end_finally);
     }
 }
 
diff --git a/tests/basics/async_with_break.py b/tests/basics/async_with_break.py
new file mode 100644
index 0000000000000000000000000000000000000000..39bcbccb025e61f8662554fdf76859c78b4d30c3
--- /dev/null
+++ b/tests/basics/async_with_break.py
@@ -0,0 +1,59 @@
+# Test "async with" whose body is escaped by a "break": __aexit__ must still be awaited, with no exception info.
+
+class AContext:
+    async def __aenter__(self):
+        print('enter')
+        return 1
+    async def __aexit__(self, exc_type, exc, tb):
+        print('exit', exc_type, exc)  # the expected output shows "exit None None" for every break
+
+async def f1():  # simplest case: break straight out of the async with
+    while 1:
+        async with AContext():
+            print('body')
+            break
+            print('no 1')  # placed after the break; absent from the expected output
+        print('no 2')  # after the with block but still inside the loop; absent from the expected output
+
+o = f1()
+try:
+    print(o.send(None))  # coroutine is driven by hand; expected output shows only 'finished' here
+except StopIteration:
+    print('finished')
+
+async def f2():  # same as f1, but the break also has to unwind through one enclosing try/finally
+    while 1:
+        try:
+            async with AContext():
+                print('body')
+                break
+                print('no 1')  # placed after the break; absent from the expected output
+        finally:
+            print('finally')  # expected output shows this after "exit None None"
+        print('no 2')  # absent from the expected output
+
+o = f2()
+try:
+    print(o.send(None))  # coroutine is driven by hand; expected output shows only 'finished' here
+except StopIteration:
+    print('finished')
+
+async def f3():  # same again, with two nested try/finally blocks around the async with
+    while 1:
+        try:
+            try:
+                async with AContext():
+                    print('body')
+                    break
+                    print('no 1')  # placed after the break; absent from the expected output
+            finally:
+                print('finally inner')  # expected output shows inner before outer
+        finally:
+            print('finally outer')
+        print('no 2')  # absent from the expected output
+
+o = f3()
+try:
+    print(o.send(None))  # coroutine is driven by hand; expected output shows only 'finished' here
+except StopIteration:
+    print('finished')
diff --git a/tests/basics/async_with_break.py.exp b/tests/basics/async_with_break.py.exp
new file mode 100644
index 0000000000000000000000000000000000000000..d077a88fad0e4a8ce5f6092965f3f1fa1c535474
--- /dev/null
+++ b/tests/basics/async_with_break.py.exp
@@ -0,0 +1,15 @@
+enter
+body
+exit None None
+finished
+enter
+body
+exit None None
+finally
+finished
+enter
+body
+exit None None
+finally inner
+finally outer
+finished
diff --git a/tests/basics/async_with_return.py b/tests/basics/async_with_return.py
new file mode 100644
index 0000000000000000000000000000000000000000..9af88b839f88c682443ac3b1c8ee5b3de6cd26f5
--- /dev/null
+++ b/tests/basics/async_with_return.py
@@ -0,0 +1,50 @@
+# Test "async with" whose body is escaped by a "return": __aexit__ must still be awaited, with no exception info.
+
+class AContext:
+    async def __aenter__(self):
+        print('enter')
+        return 1
+    async def __aexit__(self, exc_type, exc, tb):
+        print('exit', exc_type, exc)  # the expected output shows "exit None None" for every return
+
+async def f1():  # simplest case: return straight out of the async with
+    async with AContext():
+        print('body')
+        return
+
+o = f1()
+try:
+    o.send(None)  # coroutine is driven by hand; expected output shows 'finished' after the exit line
+except StopIteration:
+    print('finished')
+
+async def f2():  # same as f1, but the return also has to unwind through one enclosing try/finally
+    try:
+        async with AContext():
+            print('body')
+            return
+    finally:
+        print('finally')  # expected output shows this after "exit None None"
+
+o = f2()
+try:
+    o.send(None)  # coroutine is driven by hand; expected output shows 'finished' after 'finally'
+except StopIteration:
+    print('finished')
+
+async def f3():  # same again, with two nested try/finally blocks around the async with
+    try:
+        try:
+            async with AContext():
+                print('body')
+                return
+        finally:
+            print('finally inner')  # expected output shows inner before outer
+    finally:
+        print('finally outer')
+
+o = f3()
+try:
+    o.send(None)  # coroutine is driven by hand; expected output shows 'finished' after both finally lines
+except StopIteration:
+    print('finished')
diff --git a/tests/basics/async_with_return.py.exp b/tests/basics/async_with_return.py.exp
new file mode 100644
index 0000000000000000000000000000000000000000..d077a88fad0e4a8ce5f6092965f3f1fa1c535474
--- /dev/null
+++ b/tests/basics/async_with_return.py.exp
@@ -0,0 +1,15 @@
+enter
+body
+exit None None
+finished
+enter
+body
+exit None None
+finally
+finished
+enter
+body
+exit None None
+finally inner
+finally outer
+finished
diff --git a/tests/run-tests b/tests/run-tests
index cfd7c40379a92a16dbb6bf2915e7384b39fb0820..e1b594edfdc7ce905ae1bac04964a18a6d73b428 100755
--- a/tests/run-tests
+++ b/tests/run-tests
@@ -338,7 +338,7 @@ def run_tests(pyb, tests, args, base_path="."):
     if args.emit == 'native':
         skip_tests.update({'basics/%s.py' % t for t in 'gen_yield_from gen_yield_from_close gen_yield_from_ducktype gen_yield_from_exc gen_yield_from_executing gen_yield_from_iter gen_yield_from_send gen_yield_from_stopped gen_yield_from_throw gen_yield_from_throw2 gen_yield_from_throw3 generator1 generator2 generator_args generator_close generator_closure generator_exc generator_pend_throw generator_return generator_send'.split()}) # require yield
         skip_tests.update({'basics/%s.py' % t for t in 'bytes_gen class_store_class globals_del string_join'.split()}) # require yield
-        skip_tests.update({'basics/async_%s.py' % t for t in 'def await await2 for for2 with with2'.split()}) # require yield
+        skip_tests.update({'basics/async_%s.py' % t for t in 'def await await2 for for2 with with2 with_break with_return'.split()}) # require yield
         skip_tests.update({'basics/%s.py' % t for t in 'try_reraise try_reraise2'.split()}) # require raise_varargs
         skip_tests.update({'basics/%s.py' % t for t in 'with_break with_continue with_return'.split()}) # require complete with support
         skip_tests.add('basics/array_construct2.py') # requires generators