From c1bef21920d7fa03484647f2c339f53663fe0180 Mon Sep 17 00:00:00 2001 From: "John R. Lenton" <jlenton@gmail.com> Date: Sat, 11 Jan 2014 12:39:33 +0000 Subject: [PATCH] Implemented support for `in` and `not in` operators. --- py/objdict.c | 16 +++++++++++ py/objset.c | 20 ++++++++++++++ py/objstr.c | 9 +++++++ py/runtime.c | 45 +++++++++++++++++++++++++++---- tests/basics/run-tests | 3 +++ tests/basics/tests/containment.py | 23 ++++++++++++++++ 6 files changed, 111 insertions(+), 5 deletions(-) create mode 100644 tests/basics/tests/containment.py diff --git a/py/objdict.c b/py/objdict.c index 8902e1020..5f8a04d05 100644 --- a/py/objdict.c +++ b/py/objdict.c @@ -57,6 +57,12 @@ static mp_obj_t dict_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) { return elem->value; } } + case RT_COMPARE_OP_IN: + case RT_COMPARE_OP_NOT_IN: + { + mp_map_elem_t *elem = mp_map_lookup(&o->map, rhs_in, MP_MAP_LOOKUP); + return MP_BOOL((op == RT_COMPARE_OP_IN) ^ (elem == NULL)); + } default: // op not supported return NULL; @@ -339,10 +345,20 @@ static void dict_view_print(void (*print)(void *env, const char *fmt, ...), void print(env, "])"); } +static mp_obj_t dict_view_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) { + /* only supported for the 'keys' kind until sets and dicts are refactored */ + mp_obj_dict_view_t *o = lhs_in; + if (o->kind != MP_DICT_VIEW_KEYS) return NULL; + if (op != RT_COMPARE_OP_IN && op != RT_COMPARE_OP_NOT_IN) return NULL; + return dict_binary_op(op, o->dict, rhs_in); +} + + static const mp_obj_type_t dict_view_type = { { &mp_const_type }, "dict_view", .print = dict_view_print, + .binary_op = dict_view_binary_op, .getiter = dict_view_getiter, }; diff --git a/py/objset.c b/py/objset.c index 67dab11df..71ed55335 100644 --- a/py/objset.c +++ b/py/objset.c @@ -8,6 +8,7 @@ #include "mpqstr.h" #include "obj.h" #include "runtime.h" +#include "runtime0.h" #include "map.h" typedef struct _mp_obj_set_t { @@ -31,6 +32,24 @@ void set_print(void (*print)(void *env, const char *fmt, ...), void *env, mp_obj print(env, "}"); } + +static mp_obj_t set_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) { + mp_obj_set_t *o = lhs_in; + switch (op) { + case RT_COMPARE_OP_IN: + case RT_COMPARE_OP_NOT_IN: + { + mp_obj_t elem = mp_set_lookup(&o->set, rhs_in, false); + return ((op == RT_COMPARE_OP_IN) ^ (elem == NULL)) + ? mp_const_true : mp_const_false; + } + default: + // op not supported + return NULL; + } +} + + static mp_obj_t set_make_new(mp_obj_t type_in, int n_args, const mp_obj_t *args) { switch (n_args) { case 0: @@ -58,6 +77,7 @@ const mp_obj_type_t set_type = { { &mp_const_type }, "set", .print = set_print, + .binary_op = set_binary_op, .make_new = set_make_new, }; diff --git a/py/objstr.c b/py/objstr.c index ea4f5ead2..eb8b4c4be 100644 --- a/py/objstr.c +++ b/py/objstr.c @@ -85,6 +85,15 @@ mp_obj_t str_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) { return mp_obj_new_str(qstr_from_str_take(val, alloc_len)); } break; + case RT_COMPARE_OP_IN: + case RT_COMPARE_OP_NOT_IN: + /* NOTE `a in b` is `b.__contains__(a)` */ + if (MP_OBJ_IS_TYPE(rhs_in, &str_type)) { + const char *rhs_str = qstr_str(((mp_obj_str_t*)rhs_in)->qstr); + /* FIXME \0 in strs */ + return MP_BOOL((op == RT_COMPARE_OP_IN) ^ (strstr(lhs_str, rhs_str) == NULL)); + } + break; } return MP_OBJ_NULL; // op not supported diff --git a/py/runtime.c b/py/runtime.c index 53861f1e4..53aea4cbb 100644 --- a/py/runtime.c +++ b/py/runtime.c @@ -558,22 +558,57 @@ mp_obj_t rt_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) { } } } + } - if (MP_OBJ_IS_OBJ(lhs)) { - mp_obj_base_t *o = lhs; + /* deal with `in` and `not in` + * + * NOTE `a in b` is `b.__contains__(a)`, hence why the generic dispatch + * needs to go below + */ + if (op == RT_COMPARE_OP_IN || op == RT_COMPARE_OP_NOT_IN) { + if (!MP_OBJ_IS_SMALL_INT(rhs)) { + mp_obj_base_t *o = rhs; if (o->type->binary_op != NULL) { - mp_obj_t result = o->type->binary_op(op, lhs, rhs); - if (result != NULL) { - return result; + mp_obj_t res = o->type->binary_op(op, rhs, lhs); + if (res != NULL) { + return res; + } + } + if (o->type->getiter != NULL) { + /* second attempt, walk the iterator */ + mp_obj_t next = NULL; + mp_obj_t iter = rt_getiter(rhs); + while ((next = rt_iternext(iter)) != mp_const_stop_iteration) { + if (mp_obj_equal(next, lhs)) { + return MP_BOOL(op == RT_COMPARE_OP_IN); + } } + return MP_BOOL(op != RT_COMPARE_OP_IN); + } + } + + nlr_jump(mp_obj_new_exception_msg_varg( + MP_QSTR_TypeError, "'%s' object is not iterable", + mp_obj_get_type_str(rhs))); + return mp_const_none; + } + + if (MP_OBJ_IS_OBJ(lhs)) { + mp_obj_base_t *o = lhs; + if (o->type->binary_op != NULL) { + mp_obj_t result = o->type->binary_op(op, lhs, rhs); + if (result != NULL) { + return result; } } + // TODO implement dispatch for reverse binary ops } // TODO specify in error message what the operator is nlr_jump(mp_obj_new_exception_msg_varg(MP_QSTR_TypeError, "unsupported operand types for binary operator: '%s', '%s'", mp_obj_get_type_str(lhs), mp_obj_get_type_str(rhs))); + return mp_const_none; } mp_obj_t rt_make_function_from_id(int unique_code_id) { diff --git a/tests/basics/run-tests b/tests/basics/run-tests index 0c3995da1..bc2969ae3 100755 --- a/tests/basics/run-tests +++ b/tests/basics/run-tests @@ -42,4 +42,7 @@ echo "$numpassed tests passed" if [[ $numfailed != 0 ]] then echo "$numfailed tests failed -$namefailed" + exit 1 +else + exit 0 fi diff --git a/tests/basics/tests/containment.py b/tests/basics/tests/containment.py new file mode 100644 index 000000000..84d40b4e8 --- /dev/null +++ b/tests/basics/tests/containment.py @@ -0,0 +1,23 @@ +for i in 1, 2: + for o in {1:2}, {1}, {1:2}.keys(): + print("{} in {}: {}".format(i, o, i in o)) + print("{} not in {}: {}".format(i, o, i not in o)) + +haystack = "supercalifragilistc" +for needle in (haystack[i:] for i in range(len(haystack))): + print(needle, "in", haystack, "::", needle in haystack) + print(needle, "not in", haystack, "::", needle not in haystack) + print(haystack, "in", needle, "::", haystack in needle) + print(haystack, "not in", needle, "::", haystack not in needle) +for needle in (haystack[:i+1] for i in range(len(haystack))): + print(needle, "in", haystack, "::", needle in haystack) + print(needle, "not in", haystack, "::", needle not in haystack) + print(haystack, "in", needle, "::", haystack in needle) + print(haystack, "not in", needle, "::", haystack not in needle) + +# until here, the tests would work without the 'second attempt' iteration thing. + +for i in 1, 2: + for o in [], [1], [1, 2]: + print("{} in {}: {}".format(i, o, i in o)) + print("{} not in {}: {}".format(i, o, i not in o)) -- GitLab