diff --git a/py/objstr.c b/py/objstr.c
index 247cfde6d5dd7166ee53812eb4e5a870bdc5e131..c44e9ebf16e3ee47ded23e4efea57f5318f96e24 100644
--- a/py/objstr.c
+++ b/py/objstr.c
@@ -540,7 +540,8 @@ enum { LSTRIP, RSTRIP, STRIP };
 
 STATIC mp_obj_t str_uni_strip(int type, uint n_args, const mp_obj_t *args) {
     assert(1 <= n_args && n_args <= 2);
-    assert(MP_OBJ_IS_STR(args[0]));
+    assert(is_str_or_bytes(args[0]));
+    const mp_obj_type_t *self_type = mp_obj_get_type(args[0]);
 
     const byte *chars_to_del;
     uint chars_to_del_len;
@@ -550,7 +551,9 @@ STATIC mp_obj_t str_uni_strip(int type, uint n_args, const mp_obj_t *args) {
         chars_to_del = whitespace;
         chars_to_del_len = sizeof(whitespace);
     } else {
-        assert(MP_OBJ_IS_STR(args[1]));
+        if (mp_obj_get_type(args[1]) != self_type) {
+            arg_type_mixup();
+        }
         GET_STR_DATA_LEN(args[1], s, l);
         chars_to_del = s;
         chars_to_del_len = l;
@@ -594,7 +597,7 @@ STATIC mp_obj_t str_uni_strip(int type, uint n_args, const mp_obj_t *args) {
     assert(last_good_char_pos >= first_good_char_pos);
     //+1 to accomodate the last character
     machine_uint_t stripped_len = last_good_char_pos - first_good_char_pos + 1;
-    return mp_obj_new_str(orig_str + first_good_char_pos, stripped_len, false);
+    return str_new(self_type, orig_str + first_good_char_pos, stripped_len);
 }
 
 STATIC mp_obj_t str_strip(uint n_args, const mp_obj_t *args) {
diff --git a/tests/basics/string_strip.py b/tests/basics/string_strip.py
index 8e03eff93af221e0ab6e305016ee41562b19d6b4..4684c2a2480d9502f7d35fa1cee762a0759ea12e 100644
--- a/tests/basics/string_strip.py
+++ b/tests/basics/string_strip.py
@@ -10,3 +10,13 @@ print('www.example.com'.lstrip('cmowz.'))
 
 print('   spacious   '.rstrip())
 print('mississippi'.rstrip('ipz'))
+
+print(b'mississippi'.rstrip(b'ipz'))
+try:
+    print(b'mississippi'.rstrip('ipz'))
+except TypeError:
+    print("TypeError")
+try:
+    print('mississippi'.rstrip(b'ipz'))
+except TypeError:
+    print("TypeError")