diff --git a/py/builtin.c b/py/builtin.c
index 1924e6080fc6fc6098d7b10c202236d40f4b40ab..4fea1fdb2105279d825cc8573fc6aedfb66da8b3 100644
--- a/py/builtin.c
+++ b/py/builtin.c
@@ -284,63 +284,50 @@ STATIC mp_obj_t mp_builtin_iter(mp_obj_t o_in) {
 
 MP_DEFINE_CONST_FUN_OBJ_1(mp_builtin_iter_obj, mp_builtin_iter);
 
-STATIC mp_obj_t mp_builtin_max(uint n_args, const mp_obj_t *args) {
+STATIC mp_obj_t mp_builtin_min_max(uint n_args, const mp_obj_t *args, mp_map_t *kwargs, int op) {
+    mp_map_elem_t *key_elem = mp_map_lookup(kwargs, MP_OBJ_NEW_QSTR(MP_QSTR_key), MP_MAP_LOOKUP);
+    mp_obj_t key_fn = key_elem == NULL ? MP_OBJ_NULL : key_elem->value;
     if (n_args == 1) {
         // given an iterable
         mp_obj_t iterable = mp_getiter(args[0]);
-        mp_obj_t max_obj = NULL;
+        mp_obj_t best_key = MP_OBJ_NULL;
+        mp_obj_t best_obj = MP_OBJ_NULL;
         mp_obj_t item;
         while ((item = mp_iternext(iterable)) != MP_OBJ_STOP_ITERATION) {
-            if (max_obj == NULL || (mp_binary_op(MP_BINARY_OP_LESS, max_obj, item) == mp_const_true)) {
-                max_obj = item;
+            mp_obj_t key = key_fn == MP_OBJ_NULL ? item : mp_call_function_1(key_fn, item);
+            if (best_obj == MP_OBJ_NULL || (mp_binary_op(op, key, best_key) == mp_const_true)) {
+                best_key = key;
+                best_obj = item;
             }
         }
-        if (max_obj == NULL) {
-            nlr_raise(mp_obj_new_exception_msg(&mp_type_ValueError, "max() arg is an empty sequence"));
+        if (best_obj == MP_OBJ_NULL) {
+            nlr_raise(mp_obj_new_exception_msg(&mp_type_ValueError, "arg is an empty sequence"));
         }
-        return max_obj;
+        return best_obj;
     } else {
         // given many args
-        mp_obj_t max_obj = args[0];
-        for (int i = 1; i < n_args; i++) {
-            if (mp_binary_op(MP_BINARY_OP_LESS, max_obj, args[i]) == mp_const_true) {
-                max_obj = args[i];
+        mp_obj_t best_key = MP_OBJ_NULL;
+        mp_obj_t best_obj = MP_OBJ_NULL;
+        for (mp_uint_t i = 0; i < n_args; i++) {
+            mp_obj_t key = key_fn == MP_OBJ_NULL ? args[i] : mp_call_function_1(key_fn, args[i]);
+            if (best_obj == MP_OBJ_NULL || (mp_binary_op(op, key, best_key) == mp_const_true)) {
+                best_key = key;
+                best_obj = args[i];
             }
         }
-        return max_obj;
+        return best_obj;
     }
 }
 
-MP_DEFINE_CONST_FUN_OBJ_VAR(mp_builtin_max_obj, 1, mp_builtin_max);
-
-STATIC mp_obj_t mp_builtin_min(uint n_args, const mp_obj_t *args) {
-    if (n_args == 1) {
-        // given an iterable
-        mp_obj_t iterable = mp_getiter(args[0]);
-        mp_obj_t min_obj = NULL;
-        mp_obj_t item;
-        while ((item = mp_iternext(iterable)) != MP_OBJ_STOP_ITERATION) {
-            if (min_obj == NULL || (mp_binary_op(MP_BINARY_OP_LESS, item, min_obj) == mp_const_true)) {
-                min_obj = item;
-            }
-        }
-        if (min_obj == NULL) {
-            nlr_raise(mp_obj_new_exception_msg(&mp_type_ValueError, "min() arg is an empty sequence"));
-        }
-        return min_obj;
-    } else {
-        // given many args
-        mp_obj_t min_obj = args[0];
-        for (int i = 1; i < n_args; i++) {
-            if (mp_binary_op(MP_BINARY_OP_LESS, args[i], min_obj) == mp_const_true) {
-                min_obj = args[i];
-            }
-        }
-        return min_obj;
-    }
+STATIC mp_obj_t mp_builtin_max(uint n_args, const mp_obj_t *args, mp_map_t *kwargs) {
+    return mp_builtin_min_max(n_args, args, kwargs, MP_BINARY_OP_MORE);
 }
+MP_DEFINE_CONST_FUN_OBJ_KW(mp_builtin_max_obj, 1, mp_builtin_max);
 
-MP_DEFINE_CONST_FUN_OBJ_VAR(mp_builtin_min_obj, 1, mp_builtin_min);
+STATIC mp_obj_t mp_builtin_min(uint n_args, const mp_obj_t *args, mp_map_t *kwargs) {
+    return mp_builtin_min_max(n_args, args, kwargs, MP_BINARY_OP_LESS);
+}
+MP_DEFINE_CONST_FUN_OBJ_KW(mp_builtin_min_obj, 1, mp_builtin_min);
 
 STATIC mp_obj_t mp_builtin_next(mp_obj_t o) {
     mp_obj_t ret = mp_iternext_allow_raise(o);
diff --git a/tests/basics/builtin_minmax.py b/tests/basics/builtin_minmax.py
index 8ee4bbca7d4e06558d6443a0d99dca1460da9c5e..a5f035b9094e2b787a6de2ce4c92d459e51abe92 100644
--- a/tests/basics/builtin_minmax.py
+++ b/tests/basics/builtin_minmax.py
@@ -13,3 +13,13 @@ print(max(-1,0))
 print(min([1,2,4,0,-1,2]))
 print(max([1,2,4,0,-1,2]))
 
+# test with key function
+lst = [2, 1, 3, 4]
+print(min(lst, key=lambda x:x))
+print(min(lst, key=lambda x:-x))
+print(min(1, 2, 3, 4, key=lambda x:-x))
+print(min(4, 3, 2, 1, key=lambda x:-x))
+print(max(lst, key=lambda x:x))
+print(max(lst, key=lambda x:-x))
+print(max(1, 2, 3, 4, key=lambda x:-x))
+print(max(4, 3, 2, 1, key=lambda x:-x))