From 5e5d69b35ed6c50286cf4b936e41b78df9a691e5 Mon Sep 17 00:00:00 2001
From: Paul Sokolovsky <pfalcon@users.sourceforge.net>
Date: Sun, 11 May 2014 21:13:01 +0300
Subject: [PATCH] objstr: Make .join() support bytes.

---
 py/objstr.c                 | 10 ++++++----
 tests/basics/string-join.py | 12 ++++++++++++
 2 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/py/objstr.c b/py/objstr.c
index 33bfcc375..7549dedb7 100644
--- a/py/objstr.c
+++ b/py/objstr.c
@@ -357,7 +357,8 @@ STATIC mp_obj_t str_subscr(mp_obj_t self_in, mp_obj_t index, mp_obj_t value) {
 }
 
 STATIC mp_obj_t str_join(mp_obj_t self_in, mp_obj_t arg) {
-    assert(MP_OBJ_IS_STR(self_in));
+    assert(is_str_or_bytes(self_in));
+    const mp_obj_type_t *self_type = mp_obj_get_type(self_in);
 
     // get separation string
     GET_STR_DATA_LEN(self_in, sep_str, sep_len);
@@ -379,8 +380,9 @@ STATIC mp_obj_t str_join(mp_obj_t self_in, mp_obj_t arg) {
     // count required length
     int required_len = 0;
     for (int i = 0; i < seq_len; i++) {
-        if (!MP_OBJ_IS_STR(seq_items[i])) {
-            nlr_raise(mp_obj_new_exception_msg(&mp_type_TypeError, "join expected a list of str's"));
+        if (mp_obj_get_type(seq_items[i]) != self_type) {
+            nlr_raise(mp_obj_new_exception_msg(&mp_type_TypeError,
+                "join expects a list of str/bytes objects consistent with self object"));
         }
         if (i > 0) {
             required_len += sep_len;
@@ -391,7 +393,7 @@ STATIC mp_obj_t str_join(mp_obj_t self_in, mp_obj_t arg) {
 
     // make joined string
     byte *data;
-    mp_obj_t joined_str = mp_obj_str_builder_start(mp_obj_get_type(self_in), required_len, &data);
+    mp_obj_t joined_str = mp_obj_str_builder_start(self_type, required_len, &data);
     for (int i = 0; i < seq_len; i++) {
         if (i > 0) {
             memcpy(data, sep_str, sep_len);
diff --git a/tests/basics/string-join.py b/tests/basics/string-join.py
index 275a804c6..49bbfc5ca 100644
--- a/tests/basics/string-join.py
+++ b/tests/basics/string-join.py
@@ -10,3 +10,15 @@ print(''.join(''))
 print(''.join('abc'))
 print(','.join('abc'))
 print(','.join('abc' for i in range(5)))
+
+print(b','.join([b'abc', b'123']))
+
+try:
+    print(b','.join(['abc', b'123']))
+except TypeError:
+    print("TypeError")
+
+try:
+    print(','.join([b'abc', b'123']))
+except TypeError:
+    print("TypeError")
-- 
GitLab