From a65c03c6c07fa8b092f7cd8d02c81e9ef8cd4a50 Mon Sep 17 00:00:00 2001
From: Damien George <damien.p.george@gmail.com>
Date: Wed, 5 Nov 2014 16:30:34 +0000
Subject: [PATCH] py: Allow +, in, and compare ops between bytes and
 bytearray/array.

Eg b"123" + bytearray(2) now works.  This patch actually decreases code
size while adding functionality: 32-bit unix down by 128 bytes, stmhal
down by 84 bytes.
---
 py/objstr.c                    | 164 +++++++++++++++++----------------
 tests/basics/bytes_add.py      |   9 ++
 tests/basics/bytes_compare2.py |  11 ++-
 3 files changed, 103 insertions(+), 81 deletions(-)
 create mode 100644 tests/basics/bytes_add.py

diff --git a/py/objstr.c b/py/objstr.c
index 6a3a77c53..d6cfa8be1 100644
--- a/py/objstr.c
+++ b/py/objstr.c
@@ -285,77 +285,99 @@ STATIC const byte *find_subbytes(const byte *haystack, mp_uint_t hlen, const byt
 // works because both those types use it as their binary_op method.  Revisit
 // MP_OBJ_IS_STR_OR_BYTES if this fact changes.
 mp_obj_t mp_obj_str_binary_op(mp_uint_t op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
-    GET_STR_DATA_LEN(lhs_in, lhs_data, lhs_len);
+    // check for modulo
+    if (op == MP_BINARY_OP_MODULO) {
+        mp_obj_t *args;
+        mp_uint_t n_args;
+        mp_obj_t dict = MP_OBJ_NULL;
+        if (MP_OBJ_IS_TYPE(rhs_in, &mp_type_tuple)) {
+            // TODO: Support tuple subclasses?
+            mp_obj_tuple_get(rhs_in, &n_args, &args);
+        } else if (MP_OBJ_IS_TYPE(rhs_in, &mp_type_dict)) {
+            args = NULL;
+            n_args = 0;
+            dict = rhs_in;
+        } else {
+            args = &rhs_in;
+            n_args = 1;
+        }
+        return str_modulo_format(lhs_in, n_args, args, dict);
+    }
+
+    // from now on we need lhs type and data, so extract them
     mp_obj_type_t *lhs_type = mp_obj_get_type(lhs_in);
-    mp_obj_type_t *rhs_type = mp_obj_get_type(rhs_in);
-    switch (op) {
-        case MP_BINARY_OP_ADD:
-        case MP_BINARY_OP_INPLACE_ADD:
-            if (lhs_type == rhs_type) {
-                // add 2 strings or bytes
-
-                GET_STR_DATA_LEN(rhs_in, rhs_data, rhs_len);
-                mp_uint_t alloc_len = lhs_len + rhs_len;
-
-                /* code for making qstr
-                byte *q_ptr;
-                byte *val = qstr_build_start(alloc_len, &q_ptr);
-                memcpy(val, lhs_data, lhs_len);
-                memcpy(val + lhs_len, rhs_data, rhs_len);
-                return MP_OBJ_NEW_QSTR(qstr_build_end(q_ptr));
-                */
-
-                // code for non-qstr
-                byte *data;
-                mp_obj_t s = mp_obj_str_builder_start(lhs_type, alloc_len, &data);
-                memcpy(data, lhs_data, lhs_len);
-                memcpy(data + lhs_len, rhs_data, rhs_len);
-                return mp_obj_str_builder_end(s);
-            }
-            break;
+    GET_STR_DATA_LEN(lhs_in, lhs_data, lhs_len);
 
-        case MP_BINARY_OP_IN:
-            /* NOTE `a in b` is `b.__contains__(a)` */
-            if (lhs_type == rhs_type) {
-                GET_STR_DATA_LEN(rhs_in, rhs_data, rhs_len);
-                return MP_BOOL(find_subbytes(lhs_data, lhs_len, rhs_data, rhs_len, 1) != NULL);
+    // check for multiply
+    if (op == MP_BINARY_OP_MULTIPLY) {
+        mp_int_t n;
+        if (!mp_obj_get_int_maybe(rhs_in, &n)) {
+            return MP_OBJ_NULL; // op not supported
+        }
+        if (n <= 0) {
+            if (lhs_type == &mp_type_str) {
+                return MP_OBJ_NEW_QSTR(MP_QSTR_); // empty str
+            } else {
+                return mp_const_empty_bytes;
             }
-            break;
+        }
+        byte *data;
+        mp_obj_t s = mp_obj_str_builder_start(lhs_type, lhs_len * n, &data);
+        mp_seq_multiply(lhs_data, sizeof(*lhs_data), lhs_len, n, data);
+        return mp_obj_str_builder_end(s);
+    }
+
+    // From now on all operations allow:
+    //    - str with str
+    //    - bytes with bytes
+    //    - bytes with bytearray
+    //    - bytes with array.array
+    // To do this efficiently we use the buffer protocol to extract the raw
+    // data for the rhs, but only if the lhs is a bytes object.
+    //
+    // NOTE: CPython does not allow comparison between bytes ard array.array
+    // (even if the array is of type 'b'), even though it allows addition of
+    // such types.  We are not compatible with this (we do allow comparison
+    // of bytes with anything that has the buffer protocol).  It would be
+    // easy to "fix" this with a bit of extra logic below, but it costs code
+    // size and execution time so we don't.
+
+    const byte *rhs_data;
+    mp_uint_t rhs_len;
+    if (lhs_type == mp_obj_get_type(rhs_in)) {
+        GET_STR_DATA_LEN(rhs_in, rhs_data_, rhs_len_);
+        rhs_data = rhs_data_;
+        rhs_len = rhs_len_;
+    } else if (lhs_type == &mp_type_bytes) {
+        mp_buffer_info_t bufinfo;
+        if (!mp_get_buffer(rhs_in, &bufinfo, MP_BUFFER_READ)) {
+            goto incompatible;
+        }
+        rhs_data = bufinfo.buf;
+        rhs_len = bufinfo.len;
+    } else {
+        // incompatible types
+    incompatible:
+        if (op == MP_BINARY_OP_EQUAL) {
+            return mp_const_false; // can check for equality against every type
+        }
+        return MP_OBJ_NULL; // op not supported
+    }
 
-        case MP_BINARY_OP_MULTIPLY: {
-            mp_int_t n;
-            if (!mp_obj_get_int_maybe(rhs_in, &n)) {
-                return MP_OBJ_NULL; // op not supported
-            }
-            if (n <= 0) {
-                if (lhs_type == &mp_type_str) {
-                    return MP_OBJ_NEW_QSTR(MP_QSTR_); // empty str
-                }
-                n = 0;
-            }
+    switch (op) {
+        case MP_BINARY_OP_ADD:
+        case MP_BINARY_OP_INPLACE_ADD: {
+            mp_uint_t alloc_len = lhs_len + rhs_len;
             byte *data;
-            mp_obj_t s = mp_obj_str_builder_start(lhs_type, lhs_len * n, &data);
-            mp_seq_multiply(lhs_data, sizeof(*lhs_data), lhs_len, n, data);
+            mp_obj_t s = mp_obj_str_builder_start(lhs_type, alloc_len, &data);
+            memcpy(data, lhs_data, lhs_len);
+            memcpy(data + lhs_len, rhs_data, rhs_len);
             return mp_obj_str_builder_end(s);
         }
 
-        case MP_BINARY_OP_MODULO: {
-            mp_obj_t *args;
-            mp_uint_t n_args;
-            mp_obj_t dict = MP_OBJ_NULL;
-            if (MP_OBJ_IS_TYPE(rhs_in, &mp_type_tuple)) {
-                // TODO: Support tuple subclasses?
-                mp_obj_tuple_get(rhs_in, &n_args, &args);
-            } else if (MP_OBJ_IS_TYPE(rhs_in, &mp_type_dict)) {
-                args = NULL;
-                n_args = 0;
-                dict = rhs_in;
-            } else {
-                args = &rhs_in;
-                n_args = 1;
-            }
-            return str_modulo_format(lhs_in, n_args, args, dict);
-        }
+        case MP_BINARY_OP_IN:
+            /* NOTE `a in b` is `b.__contains__(a)` */
+            return MP_BOOL(find_subbytes(lhs_data, lhs_len, rhs_data, rhs_len, 1) != NULL);
 
         //case MP_BINARY_OP_NOT_EQUAL: // This is never passed here
         case MP_BINARY_OP_EQUAL: // This will be passed only for bytes, str is dealt with in mp_obj_equal()
@@ -363,21 +385,7 @@ mp_obj_t mp_obj_str_binary_op(mp_uint_t op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
         case MP_BINARY_OP_LESS_EQUAL:
         case MP_BINARY_OP_MORE:
         case MP_BINARY_OP_MORE_EQUAL:
-            if (lhs_type == rhs_type) {
-                GET_STR_DATA_LEN(rhs_in, rhs_data, rhs_len);
-                return MP_BOOL(mp_seq_cmp_bytes(op, lhs_data, lhs_len, rhs_data, rhs_len));
-            }
-            if (lhs_type == &mp_type_bytes) {
-                mp_buffer_info_t bufinfo;
-                if (!mp_get_buffer(rhs_in, &bufinfo, MP_BUFFER_READ)) {
-                    goto uncomparable;
-                }
-                return MP_BOOL(mp_seq_cmp_bytes(op, lhs_data, lhs_len, bufinfo.buf, bufinfo.len));
-            }
-uncomparable:
-            if (op == MP_BINARY_OP_EQUAL) {
-                return mp_const_false;
-            }
+            return MP_BOOL(mp_seq_cmp_bytes(op, lhs_data, lhs_len, rhs_data, rhs_len));
     }
 
     return MP_OBJ_NULL; // op not supported
diff --git a/tests/basics/bytes_add.py b/tests/basics/bytes_add.py
new file mode 100644
index 000000000..1288d5ac3
--- /dev/null
+++ b/tests/basics/bytes_add.py
@@ -0,0 +1,9 @@
+# test bytes + other
+
+print(b"123" + b"456")
+print(b"123" + bytearray(2))
+
+import array
+
+print(b"123" + array.array('i', [1]))
+print(b"\x01\x02" + array.array('b', [1, 2]))
diff --git a/tests/basics/bytes_compare2.py b/tests/basics/bytes_compare2.py
index 769d76b11..02516de93 100644
--- a/tests/basics/bytes_compare2.py
+++ b/tests/basics/bytes_compare2.py
@@ -1,7 +1,12 @@
-import array
-
 print(b"1" == 1)
 print(b"123" == bytearray(b"123"))
 print(b"123" == "123")
-# CPyhon gives False here
+print(b'123' < bytearray(b"124"))
+print(b'123' > bytearray(b"122"))
+print(bytearray(b"23") in b"1234")
+
+import array
+
+print(array.array('b', [1, 2]) in b'\x01\x02\x03')
+# CPython gives False here
 #print(b"\x01\x02\x03" == array.array("B", [1, 2, 3]))
-- 
GitLab