From deed087e2c083821f22a849fdb4de62004bd010f Mon Sep 17 00:00:00 2001
From: Damien George <damien.p.george@gmail.com>
Date: Sun, 6 Apr 2014 11:11:15 +0100
Subject: [PATCH] py: str.split: handle non-default separator.

---
 py/objstr.c                  | 82 +++++++++++++++++++++++++-----------
 tests/basics/string_split.py | 21 +++++++++
 2 files changed, 78 insertions(+), 25 deletions(-)

diff --git a/py/objstr.c b/py/objstr.c
index 7000ed1fb..329dfe6dd 100644
--- a/py/objstr.c
+++ b/py/objstr.c
@@ -33,6 +33,7 @@ const mp_obj_t mp_const_empty_bytes;
 STATIC mp_obj_t mp_obj_new_str_iterator(mp_obj_t str);
 STATIC mp_obj_t mp_obj_new_bytes_iterator(mp_obj_t str);
 STATIC mp_obj_t str_new(const mp_obj_type_t *type, const byte* data, uint len);
+STATIC void bad_implicit_conversion(mp_obj_t self_in) __attribute__((noreturn));
 
 /******************************************************************************/
 /* str                                                                        */
@@ -367,38 +368,71 @@ bad_arg:
 #define is_ws(c) ((c) == ' ' || (c) == '\t')
 
 STATIC mp_obj_t str_split(uint n_args, const mp_obj_t *args) {
-    int splits = -1;
+    machine_int_t splits = -1;
     mp_obj_t sep = mp_const_none;
     if (n_args > 1) {
         sep = args[1];
         if (n_args > 2) {
-            splits = MP_OBJ_SMALL_INT_VALUE(args[2]);
+            splits = mp_obj_get_int(args[2]);
         }
     }
-    assert(sep == mp_const_none);
-    (void)sep; // unused; to hush compiler warning
+
     mp_obj_t res = mp_obj_new_list(0, NULL);
     GET_STR_DATA_LEN(args[0], s, len);
     const byte *top = s + len;
-    const byte *start;
-
-    // Initial whitespace is not counted as split, so we pre-do it
-    while (s < top && is_ws(*s)) s++;
-    while (s < top && splits != 0) {
-        start = s;
-        while (s < top && !is_ws(*s)) s++;
-        mp_obj_list_append(res, mp_obj_new_str(start, s - start, false));
-        if (s >= top) {
-            break;
-        }
+
+    if (sep == mp_const_none) {
+        // sep not given (None): split on runs of whitespace as defined by is_ws()
+
+        // Leading whitespace never produces an empty first item, so skip it up front
         while (s < top && is_ws(*s)) s++;
-        if (splits > 0) {
-            splits--;
+        while (s < top && splits != 0) { // a negative limit never reaches 0, i.e. unlimited
+            const byte *start = s; // first byte of the current word
+            while (s < top && !is_ws(*s)) s++;
+            mp_obj_list_append(res, mp_obj_new_str(start, s - start, false));
+            if (s >= top) { // the word ran to the end of the data, nothing left to split
+                break;
+            }
+            while (s < top && is_ws(*s)) s++; // a whole whitespace run counts as one split
+            if (splits > 0) { // only a positive limit is counted down
+                splits--;
+            }
         }
-    }
 
-    if (s < top) {
-        mp_obj_list_append(res, mp_obj_new_str(s, top - s, false));
+        if (s < top) { // split limit used up: the unsplit remainder becomes the last item
+            mp_obj_list_append(res, mp_obj_new_str(s, top - s, false));
+        }
+
+    } else {
+        // sep given: split on each exact occurrence of the separator's bytes
+
+        uint sep_len;
+        const char *sep_str = mp_obj_str_get_data(sep, &sep_len); // NOTE(review): presumably raises for a non-str sep - confirm
+
+        if (sep_len == 0) { // a zero-length sep would match everywhere without ever advancing s
+            nlr_raise(mp_obj_new_exception_msg(&mp_type_ValueError, "empty separator"));
+        }
+
+        for (;;) { // each pass appends exactly one item to res
+            const byte *start = s; // first byte of the current item
+            for (;;) { // scan forward to the next separator, or to the end of the data
+                if (splits == 0 || s + sep_len > top) { // limit used up, or sep can no longer fit
+                    s = top; // the item takes everything that remains
+                    break;
+                } else if (memcmp(s, sep_str, sep_len) == 0) {
+                    break; // s now points at the start of a separator
+                }
+                s++;
+            }
+            mp_obj_list_append(res, mp_obj_new_str(start, s - start, false));
+            if (s >= top) { // no separator was found, so that was the final item
+                break;
+            }
+            s += sep_len; // step over the separator; if s == top now, the next pass adds a trailing empty item
+            if (splits > 0) { // only a positive limit is counted down; negative means unlimited
+                splits--;
+            }
+        }
     }
 
     return res;
@@ -1052,7 +1086,7 @@ STATIC mp_obj_t str_modulo_format(mp_obj_t pattern, uint n_args, const mp_obj_t
                 }
                 pfenv_print_int(&pfenv_vstr, arg_as_int(arg), 1, 16, 'A', flags, fill, width);
                 break;
-            
+
             default:
                 nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_ValueError,
                     "unsupported format character '%c' (0x%x) at index %d",
@@ -1191,8 +1225,7 @@ STATIC mp_obj_t str_count(uint n_args, const mp_obj_t *args) {
 STATIC mp_obj_t str_partitioner(mp_obj_t self_in, mp_obj_t arg, machine_int_t direction) {
     assert(MP_OBJ_IS_STR(self_in));
     if (!MP_OBJ_IS_STR(arg)) {
-        nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_TypeError,
-                                               "Can't convert '%s' object to str implicitly", mp_obj_get_type_str(arg)));
+        bad_implicit_conversion(arg);
     }
 
     GET_STR_DATA_LEN(self_in, str, str_len);
@@ -1365,8 +1398,7 @@ bool mp_obj_str_equal(mp_obj_t s1, mp_obj_t s2) {
     }
 }
 
-void bad_implicit_conversion(mp_obj_t self_in) __attribute__((noreturn));
-void bad_implicit_conversion(mp_obj_t self_in) {
+STATIC void bad_implicit_conversion(mp_obj_t self_in) {
     nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_TypeError, "Can't convert '%s' object to str implicitly", mp_obj_get_type_str(self_in)));
 }
 
diff --git a/tests/basics/string_split.py b/tests/basics/string_split.py
index f73cb4291..398a11539 100644
--- a/tests/basics/string_split.py
+++ b/tests/basics/string_split.py
@@ -1,3 +1,4 @@
+# default separator (whitespace)
 print("a b".split())
 print("   a   b    ".split(None))
 print("   a   b    ".split(None, 1))
@@ -5,3 +6,23 @@ print("   a   b    ".split(None, 2))
 print("   a   b  c  ".split(None, 1))
 print("   a   b  c  ".split(None, 0))
 print("   a   b  c  ".split(None, -1))
+
+# an empty separator is rejected with ValueError
+try:
+    "abc".split('')
+except ValueError:
+    print("ValueError")
+
+# non-empty separator; the trailing comments give the expected printed list
+print("abc".split("a"))  # ['', 'bc']  (separator at the start)
+print("abc".split("b"))  # ['a', 'c']
+print("abc".split("c"))  # ['ab', '']  (separator at the end)
+print("abc".split("z"))  # ['abc']  (separator absent)
+print("abc".split("ab"))  # ['', 'c']
+print("abc".split("bc"))  # ['a', '']
+print("abc".split("abc"))  # ['', '']  (separator is the whole string)
+print("abc".split("abcd"))  # ['abc']  (separator longer than the string)
+print("abcabc".split("bc"))  # ['a', 'a', '']
+print("abcabc".split("bc", 0))  # ['abcabc']  (limit 0: no split performed)
+print("abcabc".split("bc", 1))  # ['a', 'abc']
+print("abcabc".split("bc", 2))  # ['a', 'a', '']
-- 
GitLab