From c1bef21920d7fa03484647f2c339f53663fe0180 Mon Sep 17 00:00:00 2001
From: "John R. Lenton" <jlenton@gmail.com>
Date: Sat, 11 Jan 2014 12:39:33 +0000
Subject: [PATCH] Implemented support for `in` and `not in` operators.

---
 py/objdict.c                      | 16 +++++++++++
 py/objset.c                       | 20 ++++++++++++++
 py/objstr.c                       |  9 +++++++
 py/runtime.c                      | 45 +++++++++++++++++++++++++++----
 tests/basics/run-tests            |  3 +++
 tests/basics/tests/containment.py | 23 ++++++++++++++++
 6 files changed, 111 insertions(+), 5 deletions(-)
 create mode 100644 tests/basics/tests/containment.py

diff --git a/py/objdict.c b/py/objdict.c
index 8902e1020..5f8a04d05 100644
--- a/py/objdict.c
+++ b/py/objdict.c
@@ -57,6 +57,12 @@ static mp_obj_t dict_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
                 return elem->value;
             }
         }
+        case RT_COMPARE_OP_IN:
+        case RT_COMPARE_OP_NOT_IN:
+        {
+            mp_map_elem_t *elem = mp_map_lookup(&o->map, rhs_in, MP_MAP_LOOKUP);
+            return MP_BOOL((op == RT_COMPARE_OP_IN) ^ (elem == NULL));
+        }
         default:
             // op not supported
             return NULL;
@@ -339,10 +345,20 @@ static void dict_view_print(void (*print)(void *env, const char *fmt, ...), void
     print(env, "])");
 }
 
+static mp_obj_t dict_view_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
+    /* only supported for the 'keys' kind until sets and dicts are refactored */
+    mp_obj_dict_view_t *o = lhs_in;
+    if (o->kind != MP_DICT_VIEW_KEYS) return NULL;
+    if (op != RT_COMPARE_OP_IN && op != RT_COMPARE_OP_NOT_IN) return NULL;
+    return dict_binary_op(op, o->dict, rhs_in);
+}
+
+
 static const mp_obj_type_t dict_view_type = {
     { &mp_const_type },
     "dict_view",
     .print = dict_view_print,
+    .binary_op = dict_view_binary_op,
     .getiter = dict_view_getiter,
 };
 
diff --git a/py/objset.c b/py/objset.c
index 67dab11df..71ed55335 100644
--- a/py/objset.c
+++ b/py/objset.c
@@ -8,6 +8,7 @@
 #include "mpqstr.h"
 #include "obj.h"
 #include "runtime.h"
+#include "runtime0.h"
 #include "map.h"
 
 typedef struct _mp_obj_set_t {
@@ -31,6 +32,24 @@ void set_print(void (*print)(void *env, const char *fmt, ...), void *env, mp_obj
     print(env, "}");
 }
 
+
+static mp_obj_t set_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
+    mp_obj_set_t *o = lhs_in;
+    switch (op) {
+        case RT_COMPARE_OP_IN:
+        case RT_COMPARE_OP_NOT_IN:
+        {
+            mp_obj_t elem = mp_set_lookup(&o->set, rhs_in, false);
+            return ((op == RT_COMPARE_OP_IN) ^ (elem == NULL))
+                ? mp_const_true : mp_const_false;
+        }
+        default:
+            // op not supported
+            return NULL;
+    }
+}
+
+
 static mp_obj_t set_make_new(mp_obj_t type_in, int n_args, const mp_obj_t *args) {
     switch (n_args) {
         case 0:
@@ -58,6 +77,7 @@ const mp_obj_type_t set_type = {
     { &mp_const_type },
     "set",
     .print = set_print,
+    .binary_op = set_binary_op,
     .make_new = set_make_new,
 };
 
diff --git a/py/objstr.c b/py/objstr.c
index ea4f5ead2..eb8b4c4be 100644
--- a/py/objstr.c
+++ b/py/objstr.c
@@ -85,6 +85,15 @@ mp_obj_t str_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
                 return mp_obj_new_str(qstr_from_str_take(val, alloc_len));
             }
             break;
+        case RT_COMPARE_OP_IN:
+        case RT_COMPARE_OP_NOT_IN:
+            /* NOTE `a in b` is `b.__contains__(a)` */
+            if (MP_OBJ_IS_TYPE(rhs_in, &str_type)) {
+                const char *rhs_str = qstr_str(((mp_obj_str_t*)rhs_in)->qstr);
+                /* FIXME \0 in strs */
+                return MP_BOOL((op == RT_COMPARE_OP_IN) ^ (strstr(lhs_str, rhs_str) == NULL));
+            }
+            break;
     }
 
     return MP_OBJ_NULL; // op not supported
diff --git a/py/runtime.c b/py/runtime.c
index 53861f1e4..53aea4cbb 100644
--- a/py/runtime.c
+++ b/py/runtime.c
@@ -558,22 +558,57 @@ mp_obj_t rt_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) {
                 }
             }
         }
+    }
 
-        if (MP_OBJ_IS_OBJ(lhs)) {
-            mp_obj_base_t *o = lhs;
+    /* deal with `in` and `not in`
+     *
+     * NOTE `a in b` is `b.__contains__(a)`, hence why the generic dispatch
+     * needs to go below
+     */
+    if (op == RT_COMPARE_OP_IN || op == RT_COMPARE_OP_NOT_IN) {
+        if (!MP_OBJ_IS_SMALL_INT(rhs)) {
+            mp_obj_base_t *o = rhs;
             if (o->type->binary_op != NULL) {
-                mp_obj_t result = o->type->binary_op(op, lhs, rhs);
-                if (result != NULL) {
-                    return result;
+                mp_obj_t res = o->type->binary_op(op, rhs, lhs);
+                if (res != NULL) {
+                    return res;
+                }
+            }
+            if (o->type->getiter != NULL) {
+                /* second attempt, walk the iterator */
+                mp_obj_t next = NULL;
+                mp_obj_t iter = rt_getiter(rhs);
+                while ((next = rt_iternext(iter)) != mp_const_stop_iteration) {
+                    if (mp_obj_equal(next, lhs)) {
+                        return MP_BOOL(op == RT_COMPARE_OP_IN);
+                    }
                 }
+                return MP_BOOL(op != RT_COMPARE_OP_IN);
+            }
+        }
+
+        nlr_jump(mp_obj_new_exception_msg_varg(
+                     MP_QSTR_TypeError, "'%s' object is not iterable",
+                     mp_obj_get_type_str(rhs)));
+        return mp_const_none;
+    }
+
+    if (MP_OBJ_IS_OBJ(lhs)) {
+        mp_obj_base_t *o = lhs;
+        if (o->type->binary_op != NULL) {
+            mp_obj_t result = o->type->binary_op(op, lhs, rhs);
+            if (result != NULL) {
+                return result;
             }
         }
+        // TODO implement dispatch for reverse binary ops
     }
 
     // TODO specify in error message what the operator is
     nlr_jump(mp_obj_new_exception_msg_varg(MP_QSTR_TypeError,
         "unsupported operand types for binary operator: '%s', '%s'",
         mp_obj_get_type_str(lhs), mp_obj_get_type_str(rhs)));
+    return mp_const_none;
 }
 
 mp_obj_t rt_make_function_from_id(int unique_code_id) {
diff --git a/tests/basics/run-tests b/tests/basics/run-tests
index 0c3995da1..bc2969ae3 100755
--- a/tests/basics/run-tests
+++ b/tests/basics/run-tests
@@ -42,4 +42,7 @@ echo "$numpassed tests passed"
 if [[ $numfailed != 0 ]]
 then
     echo "$numfailed tests failed -$namefailed"
+    exit 1
+else
+    exit 0
 fi
diff --git a/tests/basics/tests/containment.py b/tests/basics/tests/containment.py
new file mode 100644
index 000000000..84d40b4e8
--- /dev/null
+++ b/tests/basics/tests/containment.py
@@ -0,0 +1,23 @@
+for i in 1, 2:
+    for o in {1:2}, {1}, {1:2}.keys():
+        print("{} in {}: {}".format(i, o, i in o))
+        print("{} not in {}: {}".format(i, o, i not in o))
+
+haystack = "supercalifragilistc"
+for needle in (haystack[i:] for i in range(len(haystack))):
+    print(needle, "in", haystack, "::", needle in haystack)
+    print(needle, "not in", haystack, "::", needle not in haystack)
+    print(haystack, "in", needle, "::", haystack in needle)
+    print(haystack, "not in", needle, "::", haystack not in needle)
+for needle in (haystack[:i+1] for i in range(len(haystack))):
+    print(needle, "in", haystack, "::", needle in haystack)
+    print(needle, "not in", haystack, "::", needle not in haystack)
+    print(haystack, "in", needle, "::", haystack in needle)
+    print(haystack, "not in", needle, "::", haystack not in needle)
+
+# until here, the tests would work without the 'second attempt' iteration thing.
+
+for i in 1, 2:
+    for o in [], [1], [1, 2]:
+        print("{} in {}: {}".format(i, o, i in o))
+        print("{} not in {}: {}".format(i, o, i not in o))
-- 
GitLab