diff --git a/py/vm.c b/py/vm.c
index 2611be683a3fc723d3605ce88a8a6d17945398cb..393b8a1db7de8ecb26c579bdd69241d352205b2b 100644
--- a/py/vm.c
+++ b/py/vm.c
@@ -637,10 +637,14 @@ unwind_jump:;
                         unum -= 1;
                         assert(exc_sp >= exc_stack);
                         if (MP_TAGPTR_TAG1(exc_sp->val_sp)) {
+                            // Getting here the stack looks like:
+                            //     (..., X, dest_ip)
+                            // where X is pointed to by exc_sp->val_sp and in the case
+                            // of a "with" block contains the context manager info.
                             // We're going to run "finally" code as a coroutine
                             // (not calling it recursively). Set up a sentinel
                             // on a stack so it can return back to us when it is
-                            // done (when END_FINALLY reached).
+                            // done (when WITH_CLEANUP or END_FINALLY reached).
                             PUSH((void*)unum); // push number of exception handlers left to unwind
                             PUSH(MP_OBJ_NEW_SMALL_INT(UNWIND_JUMP)); // push sentinel
                             ip = exc_sp->handler; // get exception handler byte code address
@@ -1016,15 +1020,24 @@ unwind_jump:;
 unwind_return:
                     while (exc_sp >= exc_stack) {
                         if (MP_TAGPTR_TAG1(exc_sp->val_sp)) {
+                            // Getting here the stack looks like:
+                            //     (..., X, [iter0, iter1, ...,] ret_val)
+                            // where X is pointed to by exc_sp->val_sp and in the case
+                            // of a "with" block contains the context manager info.
+                            // There may be 0 or more for-iterators between X and the
+                            // return value, and these must be removed before control can
+                            // pass to the finally code.  We simply copy the ret_value down
+                            // over these iterators, if they exist.  If they don't then the
+                            // following is a null operation.
+                            mp_obj_t *finally_sp = MP_TAGPTR_PTR(exc_sp->val_sp);
+                            finally_sp[1] = sp[0];
+                            sp = &finally_sp[1];
                             // We're going to run "finally" code as a coroutine
                             // (not calling it recursively). Set up a sentinel
                             // on a stack so it can return back to us when it is
-                            // done (when END_FINALLY reached).
+                            // done (when WITH_CLEANUP or END_FINALLY reached).
                             PUSH(MP_OBJ_NEW_SMALL_INT(UNWIND_RETURN));
                             ip = exc_sp->handler;
-                            // We don't need to do anything with sp, finally is just
-                            // syntactic sugar for sequential execution??
-                            // sp =
                             exc_sp--;
                             goto dispatch_loop;
                         }
diff --git a/tests/basics/try_finally_return.py b/tests/basics/try_finally_return.py
index 4adf3f0977efdfb0301b43ca54fe4c5e0f53ce87..31a507e8d0753a85b8d56c2087a6dbcd34b045ac 100644
--- a/tests/basics/try_finally_return.py
+++ b/tests/basics/try_finally_return.py
@@ -21,3 +21,52 @@ def func3():
         print("finally 3")
 
 print(func3())
+
+# for loop within try-finally
+def f():
+    try:
+        for i in [1, 2]:
+            return i
+    finally:
+        print('finally')
+print(f())
+
+# multiple for loops within try-finally
+def f():
+    try:
+        for i in [1, 2]:
+            for j in [3, 4]:
+                return (i, j)
+    finally:
+        print('finally')
+print(f())
+
+# multiple for loops and nested try-finally's
+def f():
+    try:
+        for i in [1, 2]:
+            for j in [3, 4]:
+                try:
+                    for k in [5, 6]:
+                        for l in [7, 8]:
+                            return (i, j, k, l)
+                finally:
+                    print('finally 2')
+    finally:
+        print('finally 1')
+print(f())
+
+# multiple for loops that are optimised, and nested try-finally's
+def f():
+    try:
+        for i in range(1, 3):
+            for j in range(3, 5):
+                try:
+                    for k in range(5, 7):
+                        for l in range(7, 9):
+                            return (i, j, k, l)
+                finally:
+                    print('finally 2')
+    finally:
+        print('finally 1')
+print(f())
diff --git a/tests/basics/with_return.py b/tests/basics/with_return.py
index cb0135c8b3aa6d012d9814169d349613e4aeff76..fd848f13313600bbeadde59ee2e864fb6f8028c4 100644
--- a/tests/basics/with_return.py
+++ b/tests/basics/with_return.py
@@ -1,14 +1,53 @@
 class CtxMgr:
+    def __init__(self, id):
+        self.id = id
 
     def __enter__(self):
-        print("__enter__")
+        print("__enter__", self.id)
         return self
 
     def __exit__(self, a, b, c):
-        print("__exit__", repr(a), repr(b))
+        print("__exit__", self.id, repr(a), repr(b))
 
+# simple case
 def foo():
-    with CtxMgr():
+    with CtxMgr(1):
         return 4
-
 print(foo())
+
+# for loop within with (iterator needs removing upon return)
+def f():
+    with CtxMgr(1):
+        for i in [1, 2]:
+            return i
+print(f())
+
+# multiple for loops within with
+def f():
+    with CtxMgr(1):
+        for i in [1, 2]:
+            for j in [3, 4]:
+                return (i, j)
+print(f())
+
+# multiple for loops within nested withs
+def f():
+    with CtxMgr(1):
+        for i in [1, 2]:
+            for j in [3, 4]:
+                with CtxMgr(2):
+                    for k in [5, 6]:
+                        for l in [7, 8]:
+                            return (i, j, k, l)
+print(f())
+
+# multiple for loops that are optimised, and nested withs
+def f():
+    with CtxMgr(1):
+        for i in range(1, 3):
+            for j in range(3, 5):
+                with CtxMgr(2):
+                    for k in range(5, 7):
+                        for l in range(7, 9):
+                            return (i, j, k, l)
+print(f())