diff --git a/py/compile.c b/py/compile.c
index 841b8f90c0fc833626beacb0bff5948c094d03b8..c84d23e943a90bd1be49a9dc217eacb1bfd3bdd8 100644
--- a/py/compile.c
+++ b/py/compile.c
@@ -662,6 +662,13 @@ STATIC void compile_funcdef_lambdef_param(compiler_t *comp, mp_parse_node_t pn)
 }
 
 STATIC void compile_funcdef_lambdef(compiler_t *comp, scope_t *scope, mp_parse_node_t pn_params, pn_kind_t pn_list_kind) {
+    // When we call compile_funcdef_lambdef_param below it can compile an arbitrary
+    // expression for default arguments, which may contain a lambda.  The lambda will
+    // call here in a nested way, so we must save and restore the relevant state.
+    bool orig_have_star = comp->have_star;
+    uint16_t orig_num_dict_params = comp->num_dict_params;
+    uint16_t orig_num_default_params = comp->num_default_params;
+
     // compile default parameters
     comp->have_star = false;
     comp->num_dict_params = 0;
@@ -681,6 +688,11 @@ STATIC void compile_funcdef_lambdef(compiler_t *comp, scope_t *scope, mp_parse_n
 
     // make the function
     close_over_variables_etc(comp, scope, comp->num_default_params, comp->num_dict_params);
+
+    // restore state
+    comp->have_star = orig_have_star;
+    comp->num_dict_params = orig_num_dict_params;
+    comp->num_default_params = orig_num_default_params;
 }
 
 // leaves function object on stack
diff --git a/tests/basics/fun_defargs.py b/tests/basics/fun_defargs.py
index ed25f5739d3d0f3c4e534e467738b4fc0d330535..1466c44094a46911eca42530407a0d4335f75464 100644
--- a/tests/basics/fun_defargs.py
+++ b/tests/basics/fun_defargs.py
@@ -1,3 +1,5 @@
+# testing default args to a function
+
 def fun1(val=5):
     print(val)
 
@@ -18,3 +20,10 @@ try:
     fun2(1, 2, 3, 4)
 except TypeError:
     print("TypeError")
+
+# lambda as default arg (exposes nested behaviour in compiler)
+def f(x=lambda:1):
+    return x()
+print(f())
+print(f(f))
+print(f(lambda:2))
diff --git a/tests/basics/fun_kwonly.py b/tests/basics/fun_kwonly.py
index bdff3a8210b4e1bb48012301610c93c0584e3044..7694c8ddcad87dd4515112c6aaa7cafb842390ac 100644
--- a/tests/basics/fun_kwonly.py
+++ b/tests/basics/fun_kwonly.py
@@ -57,3 +57,10 @@ def f(a, *b, c):
 f(1, c=2)
 f(1, 2, c=3)
 f(a=1, c=3)
+
+# lambda as kw-only arg (exposes nested behaviour in compiler)
+def f(*, x=lambda:1):
+    return x()
+print(f())
+print(f(x=f))
+print(f(x=lambda:2))