From d8dc918deb8d4b13b8919706f9f208542c9ef2e6 Mon Sep 17 00:00:00 2001
From: Damien George <damien.p.george@gmail.com>
Date: Sat, 23 Jun 2018 22:32:09 +1000
Subject: [PATCH] py/compile: Handle return/break/continue correctly in async
 with.

Before this patch the context manager's __aexit__() method would not be
executed if a return/break/continue statement was used to exit an async
with block.  async with now has the same semantics as normal with.

The fix here applies purely to the compiler, and does not modify the
runtime at all. It might (eventually) be better to define new bytecode(s)
to handle async with (and maybe other async constructs) in a cleaner, more
efficient way.

One minor drawback with addressing this issue purely in the compiler is
that it wasn't possible to get 100% CPython semantics.  The thing that is
different here to CPython is that the __aexit__ method is not looked up in
the context manager until it is needed, which is after the body of the
async with statement has executed.  So if a context manager doesn't have
__aexit__ then CPython raises an exception before the async with is
executed, whereas uPy will raise it after it is executed.  Note that
__aenter__ is looked up at the beginning in uPy because it needs to be
called straightaway, so if the context manager isn't a context manager then
it'll still raise an exception at the same location as CPython.  The only
difference is if the context manager has the __aenter__ method but not the
__aexit__ method, then in that case uPy has different behaviour.  But this
is a very minor, and acceptable, difference.
---
 py/compile.c                          | 105 +++++++++++++++++---------
 tests/basics/async_with_break.py      |  59 +++++++++++++++
 tests/basics/async_with_break.py.exp  |  15 ++++
 tests/basics/async_with_return.py     |  50 ++++++++++++
 tests/basics/async_with_return.py.exp |  15 ++++
 tests/run-tests                       |   2 +-
 6 files changed, 208 insertions(+), 38 deletions(-)
 create mode 100644 tests/basics/async_with_break.py
 create mode 100644 tests/basics/async_with_break.py.exp
 create mode 100644 tests/basics/async_with_return.py
 create mode 100644 tests/basics/async_with_return.py.exp

diff --git a/py/compile.c b/py/compile.c
index 7daf91103..98c09b210 100644
--- a/py/compile.c
+++ b/py/compile.c
@@ -1766,46 +1766,71 @@ STATIC void compile_async_with_stmt_helper(compiler_t *comp, int n, mp_parse_nod
         // no more pre-bits, compile the body of the with
         compile_node(comp, body);
     } else {
-        uint try_exception_label = comp_next_label(comp);
-        uint no_reraise_label = comp_next_label(comp);
-        uint try_else_label = comp_next_label(comp);
-        uint end_label = comp_next_label(comp);
-        qstr context;
+        uint l_finally_block = comp_next_label(comp);
+        uint l_aexit_no_exc = comp_next_label(comp);
+        uint l_ret_unwind_jump = comp_next_label(comp);
+        uint l_end = comp_next_label(comp);
 
         if (MP_PARSE_NODE_IS_STRUCT_KIND(nodes[0], PN_with_item)) {
             // this pre-bit is of the form "a as b"
             mp_parse_node_struct_t *pns = (mp_parse_node_struct_t*)nodes[0];
             compile_node(comp, pns->nodes[0]);
-            context = MP_PARSE_NODE_LEAF_ARG(pns->nodes[0]);
-            compile_store_id(comp, context);
-            compile_load_id(comp, context);
+            EMIT(dup_top);
             compile_await_object_method(comp, MP_QSTR___aenter__);
             c_assign(comp, pns->nodes[1], ASSIGN_STORE);
         } else {
             // this pre-bit is just an expression
             compile_node(comp, nodes[0]);
-            context = MP_PARSE_NODE_LEAF_ARG(nodes[0]);
-            compile_store_id(comp, context);
-            compile_load_id(comp, context);
+            EMIT(dup_top);
             compile_await_object_method(comp, MP_QSTR___aenter__);
             EMIT(pop_top);
         }
 
-        compile_load_id(comp, context);
-        EMIT_ARG(load_method, MP_QSTR___aexit__, false);
+        // To keep the Python stack size down, and because we can't access values on
+        // this stack further down than 3 elements (via rot_three), we don't preload
+        // __aexit__ (as per normal with) but rather wait until we need it below.
 
-        EMIT_ARG(setup_block, try_exception_label, MP_EMIT_SETUP_BLOCK_EXCEPT);
+        // Start the try-finally statement
+        EMIT_ARG(setup_block, l_finally_block, MP_EMIT_SETUP_BLOCK_FINALLY);
         compile_increase_except_level(comp);
-        // compile additional pre-bits and the body
+
+        // Compile any additional pre-bits of the "async with", and also the body
+        EMIT_ARG(adjust_stack_size, 3); // stack adjust for possible UNWIND_JUMP state
         compile_async_with_stmt_helper(comp, n - 1, nodes + 1, body);
-        // finish this with block
+        EMIT_ARG(adjust_stack_size, -3);
+
+        // Finish the "try" block
         EMIT(pop_block);
-        EMIT_ARG(jump, try_else_label); // jump over exception handler
 
-        EMIT_ARG(label_assign, try_exception_label); // start of exception handler
-        EMIT(start_except_handler);
+        // At this point, after the with body has executed, we have 3 cases:
+        // 1. no exception, we just fall through to this point; stack: (..., ctx_mgr)
+        // 2. exception propagating out, we get to the finally block; stack: (..., ctx_mgr, exc)
+        // 3. return or unwind jump, we get to the finally block; stack: (..., ctx_mgr, X, INT)
+
+        // Handle case 1: call __aexit__
+        // Stack: (..., ctx_mgr)
+        EMIT_ARG(load_const_tok, MP_TOKEN_KW_NONE); // to tell end_finally there's no exception
+        EMIT(rot_two);
+        EMIT_ARG(jump, l_aexit_no_exc); // jump to code below to call __aexit__
+
+        // Start of "finally" block
+        // At this point we have case 2 or 3, we detect which one by the TOS being an exception or not
+        EMIT_ARG(label_assign, l_finally_block);
 
-        // at this point the stack contains: ..., __aexit__, self, exc
+        // Detect if TOS an exception or not
+        EMIT(dup_top);
+        EMIT_LOAD_GLOBAL(MP_QSTR_Exception);
+        EMIT_ARG(binary_op, MP_BINARY_OP_EXCEPTION_MATCH);
+        EMIT_ARG(pop_jump_if, false, l_ret_unwind_jump); // if not an exception then we have case 3
+
+        // Handle case 2: call __aexit__ and either swallow or re-raise the exception
+        // Stack: (..., ctx_mgr, exc)
+        EMIT(dup_top);
+        EMIT(rot_three);
+        EMIT(rot_two);
+        EMIT_ARG(load_method, MP_QSTR___aexit__, false);
+        EMIT(rot_three);
+        EMIT(rot_three);
         EMIT(dup_top);
         #if MICROPY_CPYTHON_COMPAT
         EMIT_ARG(attr, MP_QSTR___class__, MP_EMIT_ATTR_LOAD); // get type(exc)
@@ -1816,32 +1841,38 @@ STATIC void compile_async_with_stmt_helper(compiler_t *comp, int n, mp_parse_nod
         #endif
         EMIT(rot_two);
         EMIT_ARG(load_const_tok, MP_TOKEN_KW_NONE); // dummy traceback value
-        // at this point the stack contains: ..., __aexit__, self, type(exc), exc, None
+        // Stack: (..., exc, __aexit__, ctx_mgr, type(exc), exc, None)
         EMIT_ARG(call_method, 3, 0, 0);
-
         compile_yield_from(comp);
-        EMIT_ARG(pop_jump_if, true, no_reraise_label);
-        EMIT_ARG(raise_varargs, 0);
-
-        EMIT_ARG(label_assign, no_reraise_label);
-        EMIT(pop_except);
-        EMIT_ARG(jump, end_label);
-
-        EMIT_ARG(adjust_stack_size, 3); // adjust for __aexit__, self, exc
-        compile_decrease_except_level(comp);
-        EMIT(end_finally);
-        EMIT(end_except_handler);
-
-        EMIT_ARG(label_assign, try_else_label); // start of try-else handler
+        EMIT_ARG(pop_jump_if, false, l_end);
+        EMIT(pop_top); // pop exception
+        EMIT_ARG(load_const_tok, MP_TOKEN_KW_NONE); // replace with None to swallow exception
+        EMIT_ARG(jump, l_end);
+        EMIT_ARG(adjust_stack_size, 2);
+
+        // Handle case 3: call __aexit__
+        // Stack: (..., ctx_mgr, X, INT)
+        EMIT_ARG(label_assign, l_ret_unwind_jump);
+        EMIT(rot_three);
+        EMIT(rot_three);
+        EMIT_ARG(label_assign, l_aexit_no_exc);
+        EMIT_ARG(load_method, MP_QSTR___aexit__, false);
         EMIT_ARG(load_const_tok, MP_TOKEN_KW_NONE);
         EMIT(dup_top);
         EMIT(dup_top);
         EMIT_ARG(call_method, 3, 0, 0);
         compile_yield_from(comp);
         EMIT(pop_top);
+        EMIT_ARG(adjust_stack_size, -1);
 
-        EMIT_ARG(label_assign, end_label);
-
+        // End of "finally" block
+        // Stack can have one of three configurations:
+        // a. (..., None) - from either case 1, or case 2 with swallowed exception
+        // b. (..., exc) - from case 2 with re-raised exception
+        // c. (..., X, INT) - from case 3
+        EMIT_ARG(label_assign, l_end);
+        compile_decrease_except_level(comp);
+        EMIT(end_finally);
     }
 }
 
diff --git a/tests/basics/async_with_break.py b/tests/basics/async_with_break.py
new file mode 100644
index 000000000..39bcbccb0
--- /dev/null
+++ b/tests/basics/async_with_break.py
@@ -0,0 +1,59 @@
+# test async with, escaped by a break
+
+class AContext:
+    async def __aenter__(self):
+        print('enter')
+        return 1
+    async def __aexit__(self, exc_type, exc, tb):
+        print('exit', exc_type, exc)
+
+async def f1():
+    while 1:
+        async with AContext():
+            print('body')
+            break
+            print('no 1')
+        print('no 2')
+
+o = f1()
+try:
+    print(o.send(None))
+except StopIteration:
+    print('finished')
+
+async def f2():
+    while 1:
+        try:
+            async with AContext():
+                print('body')
+                break
+                print('no 1')
+        finally:
+            print('finally')
+        print('no 2')
+
+o = f2()
+try:
+    print(o.send(None))
+except StopIteration:
+    print('finished')
+
+async def f3():
+    while 1:
+        try:
+            try:
+                async with AContext():
+                    print('body')
+                    break
+                    print('no 1')
+            finally:
+                print('finally inner')
+        finally:
+            print('finally outer')
+        print('no 2')
+
+o = f3()
+try:
+    print(o.send(None))
+except StopIteration:
+    print('finished')
diff --git a/tests/basics/async_with_break.py.exp b/tests/basics/async_with_break.py.exp
new file mode 100644
index 000000000..d077a88fa
--- /dev/null
+++ b/tests/basics/async_with_break.py.exp
@@ -0,0 +1,15 @@
+enter
+body
+exit None None
+finished
+enter
+body
+exit None None
+finally
+finished
+enter
+body
+exit None None
+finally inner
+finally outer
+finished
diff --git a/tests/basics/async_with_return.py b/tests/basics/async_with_return.py
new file mode 100644
index 000000000..9af88b839
--- /dev/null
+++ b/tests/basics/async_with_return.py
@@ -0,0 +1,50 @@
+# test async with, escaped by a return
+
+class AContext:
+    async def __aenter__(self):
+        print('enter')
+        return 1
+    async def __aexit__(self, exc_type, exc, tb):
+        print('exit', exc_type, exc)
+
+async def f1():
+    async with AContext():
+        print('body')
+        return
+
+o = f1()
+try:
+    o.send(None)
+except StopIteration:
+    print('finished')
+
+async def f2():
+    try:
+        async with AContext():
+            print('body')
+            return
+    finally:
+        print('finally')
+
+o = f2()
+try:
+    o.send(None)
+except StopIteration:
+    print('finished')
+
+async def f3():
+    try:
+        try:
+            async with AContext():
+                print('body')
+                return
+        finally:
+            print('finally inner')
+    finally:
+        print('finally outer')
+
+o = f3()
+try:
+    o.send(None)
+except StopIteration:
+    print('finished')
diff --git a/tests/basics/async_with_return.py.exp b/tests/basics/async_with_return.py.exp
new file mode 100644
index 000000000..d077a88fa
--- /dev/null
+++ b/tests/basics/async_with_return.py.exp
@@ -0,0 +1,15 @@
+enter
+body
+exit None None
+finished
+enter
+body
+exit None None
+finally
+finished
+enter
+body
+exit None None
+finally inner
+finally outer
+finished
diff --git a/tests/run-tests b/tests/run-tests
index cfd7c4037..e1b594edf 100755
--- a/tests/run-tests
+++ b/tests/run-tests
@@ -338,7 +338,7 @@ def run_tests(pyb, tests, args, base_path="."):
     if args.emit == 'native':
         skip_tests.update({'basics/%s.py' % t for t in 'gen_yield_from gen_yield_from_close gen_yield_from_ducktype gen_yield_from_exc gen_yield_from_executing gen_yield_from_iter gen_yield_from_send gen_yield_from_stopped gen_yield_from_throw gen_yield_from_throw2 gen_yield_from_throw3 generator1 generator2 generator_args generator_close generator_closure generator_exc generator_pend_throw generator_return generator_send'.split()}) # require yield
         skip_tests.update({'basics/%s.py' % t for t in 'bytes_gen class_store_class globals_del string_join'.split()}) # require yield
-        skip_tests.update({'basics/async_%s.py' % t for t in 'def await await2 for for2 with with2'.split()}) # require yield
+        skip_tests.update({'basics/async_%s.py' % t for t in 'def await await2 for for2 with with2 with_break with_return'.split()}) # require yield
         skip_tests.update({'basics/%s.py' % t for t in 'try_reraise try_reraise2'.split()}) # require raise_varargs
         skip_tests.update({'basics/%s.py' % t for t in 'with_break with_continue with_return'.split()}) # require complete with support
         skip_tests.add('basics/array_construct2.py') # requires generators
-- 
GitLab