From c55a4d82cf03d22933028bd9db4031ef349fdc1f Mon Sep 17 00:00:00 2001
From: Damien George <damien.p.george@gmail.com>
Date: Wed, 24 Dec 2014 20:28:30 +0000
Subject: [PATCH] py: Make bytes objs work with more str methods; add tests.

---
 py/objstr.c                     | 44 +++++++++++++++++-------------
 tests/basics/bytes_count.py     | 48 +++++++++++++++++++++++++++++++++
 tests/basics/bytes_find.py      | 23 ++++++++++++++++
 tests/basics/bytes_partition.py | 29 ++++++++++++++++++++
 tests/basics/bytes_replace.py   | 13 +++++++++
 tests/basics/bytes_split.py     | 28 +++++++++++++++++++
 tests/basics/bytes_strip.py     | 16 +++++++++++
 7 files changed, 183 insertions(+), 18 deletions(-)
 create mode 100644 tests/basics/bytes_count.py
 create mode 100644 tests/basics/bytes_find.py
 create mode 100644 tests/basics/bytes_partition.py
 create mode 100644 tests/basics/bytes_replace.py
 create mode 100644 tests/basics/bytes_split.py
 create mode 100644 tests/basics/bytes_strip.py

diff --git a/py/objstr.c b/py/objstr.c
index 7cd44471e..16ee2ea1d 100644
--- a/py/objstr.c
+++ b/py/objstr.c
@@ -46,7 +46,6 @@ STATIC mp_obj_t str_modulo_format(mp_obj_t pattern, mp_uint_t n_args, const mp_o
 mp_obj_t mp_obj_new_str_iterator(mp_obj_t str);
 STATIC mp_obj_t mp_obj_new_bytes_iterator(mp_obj_t str);
 STATIC NORETURN void bad_implicit_conversion(mp_obj_t self_in);
-STATIC NORETURN void arg_type_mixup(void);
 
 /******************************************************************************/
 /* str                                                                        */
@@ -525,7 +524,7 @@ STATIC mp_obj_t str_split(mp_uint_t n_args, const mp_obj_t *args) {
     } else {
         // sep given
         if (mp_obj_get_type(sep) != self_type) {
-            arg_type_mixup();
+            bad_implicit_conversion(sep);
         }
 
         mp_uint_t sep_len;
@@ -627,7 +626,7 @@ STATIC mp_obj_t str_finder(mp_uint_t n_args, const mp_obj_t *args, mp_int_t dire
     assert(MP_OBJ_IS_STR_OR_BYTES(args[0]));
 
     // check argument type
-    if (!MP_OBJ_IS_STR(args[1])) {
+    if (mp_obj_get_type(args[1]) != self_type) {
         bad_implicit_conversion(args[1]);
     }
 
@@ -720,7 +719,7 @@ STATIC mp_obj_t str_uni_strip(int type, mp_uint_t n_args, const mp_obj_t *args)
         chars_to_del_len = sizeof(whitespace);
     } else {
         if (mp_obj_get_type(args[1]) != self_type) {
-            arg_type_mixup();
+            bad_implicit_conversion(args[1]);
         }
         GET_STR_DATA_LEN(args[1], s, l);
         chars_to_del = s;
@@ -759,7 +758,11 @@ STATIC mp_obj_t str_uni_strip(int type, mp_uint_t n_args, const mp_obj_t *args)
 
     if (!first_good_char_pos_set) {
         // string is all whitespace, return ''
-        return MP_OBJ_NEW_QSTR(MP_QSTR_);
+        if (self_type == &mp_type_str) {
+            return MP_OBJ_NEW_QSTR(MP_QSTR_);
+        } else {
+            return mp_const_empty_bytes;
+        }
     }
 
     assert(last_good_char_pos >= first_good_char_pos);
@@ -1470,11 +1473,13 @@ STATIC mp_obj_t str_replace(mp_uint_t n_args, const mp_obj_t *args) {
 
     // check argument types
 
-    if (!MP_OBJ_IS_STR(args[1])) {
+    const mp_obj_type_t *self_type = mp_obj_get_type(args[0]);
+
+    if (mp_obj_get_type(args[1]) != self_type) {
         bad_implicit_conversion(args[1]);
     }
 
-    if (!MP_OBJ_IS_STR(args[2])) {
+    if (mp_obj_get_type(args[2]) != self_type) {
         bad_implicit_conversion(args[2]);
     }
 
@@ -1543,7 +1548,7 @@ STATIC mp_obj_t str_replace(mp_uint_t n_args, const mp_obj_t *args) {
                 return args[0];
             } else {
                 // substr found, allocate new string
-                replaced_str = mp_obj_str_builder_start(mp_obj_get_type(args[0]), replaced_str_index, &data);
+                replaced_str = mp_obj_str_builder_start(self_type, replaced_str_index, &data);
                 assert(data != NULL);
             }
         } else {
@@ -1561,7 +1566,7 @@ STATIC mp_obj_t str_count(mp_uint_t n_args, const mp_obj_t *args) {
     assert(MP_OBJ_IS_STR_OR_BYTES(args[0]));
 
     // check argument type
-    if (!MP_OBJ_IS_STR(args[1])) {
+    if (mp_obj_get_type(args[1]) != self_type) {
         bad_implicit_conversion(args[1]);
     }
 
@@ -1597,12 +1602,10 @@ STATIC mp_obj_t str_count(mp_uint_t n_args, const mp_obj_t *args) {
 }
 
 STATIC mp_obj_t str_partitioner(mp_obj_t self_in, mp_obj_t arg, mp_int_t direction) {
-    if (!MP_OBJ_IS_STR_OR_BYTES(self_in)) {
-        assert(0);
-    }
+    assert(MP_OBJ_IS_STR_OR_BYTES(self_in));
     mp_obj_type_t *self_type = mp_obj_get_type(self_in);
     if (self_type != mp_obj_get_type(arg)) {
-        arg_type_mixup();
+        bad_implicit_conversion(arg);
     }
 
     GET_STR_DATA_LEN(self_in, str, str_len);
@@ -1612,7 +1615,16 @@ STATIC mp_obj_t str_partitioner(mp_obj_t self_in, mp_obj_t arg, mp_int_t directi
         nlr_raise(mp_obj_new_exception_msg(&mp_type_ValueError, "empty separator"));
     }
 
-    mp_obj_t result[] = {MP_OBJ_NEW_QSTR(MP_QSTR_), MP_OBJ_NEW_QSTR(MP_QSTR_), MP_OBJ_NEW_QSTR(MP_QSTR_)};
+    mp_obj_t result[3];
+    if (self_type == &mp_type_str) {
+        result[0] = MP_OBJ_NEW_QSTR(MP_QSTR_);
+        result[1] = MP_OBJ_NEW_QSTR(MP_QSTR_);
+        result[2] = MP_OBJ_NEW_QSTR(MP_QSTR_);
+    } else {
+        result[0] = mp_const_empty_bytes;
+        result[1] = mp_const_empty_bytes;
+        result[2] = mp_const_empty_bytes;
+    }
 
     if (direction > 0) {
         result[0] = self_in;
@@ -1953,10 +1965,6 @@ STATIC void bad_implicit_conversion(mp_obj_t self_in) {
     }
 }
 
-STATIC void arg_type_mixup(void) {
-    nlr_raise(mp_obj_new_exception_msg(&mp_type_TypeError, "Can't mix str and bytes arguments"));
-}
-
 mp_uint_t mp_obj_str_get_hash(mp_obj_t self_in) {
     // TODO: This has too big overhead for hash accessor
     if (MP_OBJ_IS_STR_OR_BYTES(self_in)) {
diff --git a/tests/basics/bytes_count.py b/tests/basics/bytes_count.py
new file mode 100644
index 000000000..95bcfe310
--- /dev/null
+++ b/tests/basics/bytes_count.py
@@ -0,0 +1,48 @@
+print(b"".count(b""))  # bytes.count() tests: empty and single-byte haystacks/needles
+print(b"".count(b"a"))
+print(b"a".count(b""))
+print(b"a".count(b"a"))
+print(b"a".count(b"b"))
+print(b"b".count(b"a"))
+
+print(b"aaa".count(b""))  # 3-byte haystack; needle grows from empty to longer than it
+print(b"aaa".count(b"a"))
+print(b"aaa".count(b"aa"))
+print(b"aaa".count(b"aaa"))
+print(b"aaa".count(b"aaaa"))
+
+print(b"aaaa".count(b""))  # same progression with a 4-byte haystack
+print(b"aaaa".count(b"a"))
+print(b"aaaa".count(b"aa"))
+print(b"aaaa".count(b"aaa"))
+print(b"aaaa".count(b"aaaa"))
+print(b"aaaa".count(b"aaaaa"))
+
+print(b"aaa".count(b"", 1))  # empty needle with an explicit start index
+print(b"aaa".count(b"", 2))
+print(b"aaa".count(b"", 3))
+
+print(b"aaa".count(b"", 1, 2))  # empty needle with both start and end indices
+
+print(b"asdfasdfaaa".count(b"asdf", -100))  # negative start beyond the haystack length
+print(b"asdfasdfaaa".count(b"asdf", -8))  # negative start within the haystack
+print(b"asdf".count(b's', True))  # bool values passed as the start index
+print(b"asdf".count(b'a', True))
+print(b"asdf".count(b'a', False))
+print(b"asdf".count(b'a', 1 == 2))
+print(b"hello world".count(b'l'))
+print(b"hello world".count(b'l', 5))
+print(b"hello world".count(b'l', 3))
+print(b"hello world".count(b'z', 3, 6))  # needle absent from the haystack
+print(b"aaaa".count(b'a'))
+print(b"aaaa".count(b'a', 0, 3))
+print(b"aaaa".count(b'a', 0, 4))
+print(b"aaaa".count(b'a', 0, 5))  # end index past the end of the haystack
+print(b"aaaa".count(b'a', 1, 5))
+print(b"aaaa".count(b'a', -1, 5))
+print(b"abbabba".count(b"abba"))  # the two occurrences of the needle overlap
+
+def t():
+    return True
+
+print(b"0000".count(b'0', t()))  # start index supplied by a call that returns a bool
diff --git a/tests/basics/bytes_find.py b/tests/basics/bytes_find.py
new file mode 100644
index 000000000..434669a90
--- /dev/null
+++ b/tests/basics/bytes_find.py
@@ -0,0 +1,23 @@
+print(b"hello world".find(b"ll"))  # bytes.find() tests: needle only, no start/end
+print(b"hello world".find(b"ll", None))  # None given as start
+print(b"hello world".find(b"ll", 1))
+print(b"hello world".find(b"ll", 1, None))  # None given as end
+print(b"hello world".find(b"ll", None, None))
+print(b"hello world".find(b"ll", 1, -1))  # negative end index
+print(b"hello world".find(b"ll", 1, 1))  # end index stepped from 1 to 5 with start fixed at 1
+print(b"hello world".find(b"ll", 1, 2))
+print(b"hello world".find(b"ll", 1, 3))
+print(b"hello world".find(b"ll", 1, 4))
+print(b"hello world".find(b"ll", 1, 5))
+print(b"hello world".find(b"ll", -100))  # negative start beyond the haystack length
+print(b"0000".find(b'0'))
+print(b"0000".find(b'0', 0))  # start index stepped from 0 to 5 (haystack is 4 bytes)
+print(b"0000".find(b'0', 1))
+print(b"0000".find(b'0', 2))
+print(b"0000".find(b'0', 3))
+print(b"0000".find(b'0', 4))
+print(b"0000".find(b'0', 5))
+print(b"0000".find(b'-1', 3))  # needles that do not occur in the haystack
+print(b"0000".find(b'1', 3))
+print(b"0000".find(b'1', 4))
+print(b"0000".find(b'1', 5))
diff --git a/tests/basics/bytes_partition.py b/tests/basics/bytes_partition.py
new file mode 100644
index 000000000..3868a81a5
--- /dev/null
+++ b/tests/basics/bytes_partition.py
@@ -0,0 +1,29 @@
+print(b"asdf".partition(b'g'))  # bytes.partition() tests: separator absent from the haystack
+print(b"asdf".partition(b'a'))  # single-byte separator at each position
+print(b"asdf".partition(b's'))
+print(b"asdf".partition(b'f'))
+print(b"asdf".partition(b'd'))
+print(b"asdf".partition(b'asd'))  # multi-byte separators
+print(b"asdf".partition(b'sdf'))
+print(b"asdf".partition(b'as'))
+print(b"asdf".partition(b'df'))
+print(b"asdf".partition(b'asdf'))  # separator equal to the whole haystack
+print(b"asdf".partition(b'asdfa'))  # separators longer than the haystack
+print(b"asdf".partition(b'fasdf'))
+print(b"asdf".partition(b'fasdfa'))
+print(b"abba".partition(b'a'))  # separator occurs more than once
+print(b"abba".partition(b'b'))
+
+try:
+    print(b"asdf".partition(1))  # non-bytes separator; TypeError is the expected outcome
+except TypeError:
+    print("Raised TypeError")
+else:
+    print("Did not raise TypeError")
+
+try:
+    print(b"asdf".partition(b''))  # empty separator; ValueError is the expected outcome
+except ValueError:
+    print("Raised ValueError")
+else:
+    print("Did not raise ValueError")
diff --git a/tests/basics/bytes_replace.py b/tests/basics/bytes_replace.py
new file mode 100644
index 000000000..24f03e61c
--- /dev/null
+++ b/tests/basics/bytes_replace.py
@@ -0,0 +1,13 @@
+print(b"".replace(b"a", b"b"))  # bytes.replace() tests: empty haystack
+print(b"aaa".replace(b"a", b"b", 0))  # count argument of zero
+print(b"aaa".replace(b"a", b"b", -5))  # negative count argument
+print(b"asdfasdf".replace(b"a", b"b"))
+print(b"aabbaabbaabbaa".replace(b"aa", b"cc", 3))  # count smaller than the number of occurrences
+print(b"a".replace(b"aa", b"bb"))  # pattern longer than the haystack
+print(b"testingtesting".replace(b"ing", b""))  # empty replacement
+print(b"testINGtesting".replace(b"ing", b"ING!"))  # replacement longer than the pattern
+
+print(b"".replace(b"", b"1"))  # empty pattern against haystacks of length 0, 1 and 2
+print(b"A".replace(b"", b"1"))
+print(b"AB".replace(b"", b"1"))
+print(b"AB".replace(b"", b"12"))
diff --git a/tests/basics/bytes_split.py b/tests/basics/bytes_split.py
new file mode 100644
index 000000000..a9dda1ee8
--- /dev/null
+++ b/tests/basics/bytes_split.py
@@ -0,0 +1,28 @@
+# bytes.split() with the default separator: no argument, or None passed explicitly
+print(b"a b".split())
+print(b"   a   b    ".split(None))
+print(b"   a   b    ".split(None, 1))  # explicit maxsplit values, including 0 and -1
+print(b"   a   b    ".split(None, 2))
+print(b"   a   b  c  ".split(None, 1))
+print(b"   a   b  c  ".split(None, 0))
+print(b"   a   b  c  ".split(None, -1))
+
+# an empty separator is expected to raise ValueError
+try:
+    b"abc".split(b'')
+except ValueError:
+    print("ValueError")
+
+# non-empty separator, first without and then with a maxsplit argument
+print(b"abc".split(b"a"))
+print(b"abc".split(b"b"))
+print(b"abc".split(b"c"))
+print(b"abc".split(b"z"))  # separator absent from the haystack
+print(b"abc".split(b"ab"))
+print(b"abc".split(b"bc"))
+print(b"abc".split(b"abc"))  # separator equal to the whole haystack
+print(b"abc".split(b"abcd"))  # separator longer than the haystack
+print(b"abcabc".split(b"bc"))
+print(b"abcabc".split(b"bc", 0))
+print(b"abcabc".split(b"bc", 1))
+print(b"abcabc".split(b"bc", 2))
diff --git a/tests/basics/bytes_strip.py b/tests/basics/bytes_strip.py
new file mode 100644
index 000000000..71e4ac185
--- /dev/null
+++ b/tests/basics/bytes_strip.py
@@ -0,0 +1,16 @@
+print(b"".strip())  # bytes strip()/lstrip()/rstrip() tests: empty input
+print(b" \t\n\r\v\f".strip())  # input made up only of space, \t, \n, \r, \v and \f
+print(b" T E S T".strip())
+print(b"abcabc".strip(b"ce"))  # explicit set of bytes to strip
+print(b"aaa".strip(b"b"))  # strip set shares no bytes with the input
+print(b"abc  efg ".strip(b"g a"))
+
+print(b'   spacious   '.lstrip())
+print(b'www.example.com'.lstrip(b'cmowz.'))
+
+print(b'   spacious   '.rstrip())
+print(b'mississippi'.rstrip(b'ipz'))
+
+# Check that strip() on a bytes object with nothing to strip returns the original object
+s = b"abc"
+print(id(s.strip()) == id(s))
-- 
GitLab