From 4ebb32fb952c02eb554311cbbd0acac7e858570b Mon Sep 17 00:00:00 2001
From: Damien <damien.p.george@gmail.com>
Date: Sat, 2 Nov 2013 14:33:10 +0000
Subject: [PATCH] Implement: str.join, more float support, ROT_TWO in VM.

---
 py/runtime.c | 174 +++++++++++++++++++++++++++++++++++++++++----------
 py/runtime.h |   7 ++-
 py/vm.c      |   9 ++-
 3 files changed, 154 insertions(+), 36 deletions(-)

diff --git a/py/runtime.c b/py/runtime.c
index 2d8fa0206d..b506c6a549 100644
--- a/py/runtime.c
+++ b/py/runtime.c
@@ -165,15 +165,27 @@ struct _py_obj_base_t {
     };
 };
 
+static qstr q_append;
+static qstr q_join;
+static qstr q_print;
+static qstr q_len;
+static qstr q___build_class__;
+static qstr q___next__;
+static qstr q_AttributeError;
+static qstr q_IndexError;
+static qstr q_KeyError;
+static qstr q_NameError;
+static qstr q_TypeError;
+
 py_obj_t py_const_none;
 py_obj_t py_const_false;
 py_obj_t py_const_true;
 py_obj_t py_const_stop_iteration;
 
 // locals and globals need to be pointers because they can be the same in outer module scope
-py_map_t *map_locals;
-py_map_t *map_globals;
-py_map_t map_builtins;
+static py_map_t *map_locals;
+static py_map_t *map_globals;
+static py_map_t map_builtins;
 
 // approximatelly doubling primes; made with Mathematica command: Table[Prime[Floor[(1.7)^n]], {n, 3, 24}]
 static int doubling_primes[] = {7, 19, 43, 89, 179, 347, 647, 1229, 2297, 4243, 7829, 14347, 26017, 47149, 84947, 152443, 273253, 488399, 869927, 1547173, 2745121, 4861607};
@@ -319,7 +331,7 @@ static bool fit_small_int(py_small_int_t o) {
     return true;
 }
 
-py_obj_t py_obj_new_int(int value) {
+py_obj_t py_obj_new_int(machine_int_t value) {
     return TO_SMALL_INT(value);
 }
 
@@ -400,6 +412,39 @@ py_obj_t py_obj_new_list_iterator(py_obj_base_t *list, int cur) {
     return o;
 }
 
+py_obj_t rt_str_join(py_obj_t self_in, py_obj_t arg) {
+    assert(IS_O(self_in, O_STR));
+    py_obj_base_t *self = self_in;
+    int required_len = strlen(qstr_str(self->u_str));
+
+    // process arg, count required chars
+    if (!IS_O(arg, O_TUPLE) && !IS_O(arg, O_LIST)) {
+        goto bad_arg;
+    }
+    py_obj_base_t *tuple_list = arg;
+    for (int i = 0; i < tuple_list->u_tuple_list.len; i++) {
+        if (!IS_O(tuple_list->u_tuple_list.items[i], O_STR)) {
+            goto bad_arg;
+        }
+        required_len += strlen(qstr_str(((py_obj_base_t*)tuple_list->u_tuple_list.items[i])->u_str));
+    }
+
+    // make joined string
+    char *joined_str = m_new(char, required_len + 1);
+    joined_str[0] = 0;
+    for (int i = 0; i < tuple_list->u_tuple_list.len; i++) {
+        const char *s2 = qstr_str(((py_obj_base_t*)tuple_list->u_tuple_list.items[i])->u_str);
+        if (i > 0) {
+            strcat(joined_str, qstr_str(self->u_str));
+        }
+        strcat(joined_str, s2);
+    }
+    return py_obj_new_str(qstr_from_str_take(joined_str));
+
+bad_arg:
+    nlr_jump(py_obj_new_exception_2(q_TypeError, "?str.join expecting a list of str's", NULL, NULL));
+}
+
 py_obj_t rt_list_append(py_obj_t self_in, py_obj_t arg) {
     assert(IS_O(self_in, O_LIST));
     py_obj_base_t *self = self_in;
@@ -420,17 +465,6 @@ py_obj_t rt_gen_instance_next(py_obj_t self_in) {
     }
 }
 
-static qstr q_append;
-static qstr q_print;
-static qstr q_len;
-static qstr q___build_class__;
-static qstr q___next__;
-static qstr q_AttributeError;
-static qstr q_IndexError;
-static qstr q_KeyError;
-static qstr q_NameError;
-static qstr q_TypeError;
-
 typedef enum {
     PY_CODE_NONE,
     PY_CODE_BYTE,
@@ -461,6 +495,7 @@ typedef struct _py_code_t {
 static int next_unique_code_id;
 static py_code_t *unique_codes;
 
+py_obj_t fun_str_join;
 py_obj_t fun_list_append;
 py_obj_t fun_gen_instance_next;
 
@@ -527,6 +562,7 @@ FILE *fp_native = NULL;
 
 void rt_init(void) {
     q_append = qstr_from_str_static("append");
+    q_join = qstr_from_str_static("join");
     q_print = qstr_from_str_static("print");
     q_len = qstr_from_str_static("len");
     q___build_class__ = qstr_from_str_static("__build_class__");
@@ -556,6 +592,7 @@ void rt_init(void) {
     next_unique_code_id = 2; // 1 is reserved for the __main__ module scope
     unique_codes = NULL;
 
+    fun_str_join = rt_make_function_2(rt_str_join);
     fun_list_append = rt_make_function_2(rt_list_append);
     fun_gen_instance_next = rt_make_function_1(rt_gen_instance_next);
 
@@ -849,13 +886,28 @@ int rt_is_true(py_obj_t arg) {
     }
 }
 
-int py_get_int(py_obj_t arg) {
+machine_int_t py_get_int(py_obj_t arg) {
+    if (arg == py_const_false) {
+        return 0;
+    } else if (arg == py_const_true) {
+        return 1;
+    } else if (IS_SMALL_INT(arg)) {
+        return FROM_SMALL_INT(arg);
+    } else {
+        assert(0);
+        return 0;
+    }
+}
+
+machine_float_t py_obj_get_float(py_obj_t arg) {
     if (arg == py_const_false) {
         return 0;
     } else if (arg == py_const_true) {
         return 1;
     } else if (IS_SMALL_INT(arg)) {
         return FROM_SMALL_INT(arg);
+    } else if (IS_O(arg, O_FLOAT)) {
+        return ((py_obj_base_t*)arg)->u_flt;
     } else {
         assert(0);
         return 0;
@@ -871,7 +923,7 @@ qstr py_get_qstr(py_obj_t arg) {
     }
 }
 
-py_obj_t *py_get_array_fixed_n(py_obj_t o_in, int n) {
+py_obj_t *py_get_array_fixed_n(py_obj_t o_in, machine_int_t n) {
     if (IS_O(o_in, O_TUPLE) || IS_O(o_in, O_LIST)) {
         py_obj_base_t *o = o_in;
         if (o->u_tuple_list.len != n) {
@@ -979,35 +1031,66 @@ py_obj_t rt_binary_op(int op, py_obj_t lhs, py_obj_t rhs) {
             assert(0);
         }
     } else if (IS_SMALL_INT(lhs) && IS_SMALL_INT(rhs)) {
+        py_small_int_t lhs_val = FROM_SMALL_INT(lhs);
+        py_small_int_t rhs_val = FROM_SMALL_INT(rhs);
         py_small_int_t val;
         switch (op) {
             case RT_BINARY_OP_OR:
-            case RT_BINARY_OP_INPLACE_OR: val = FROM_SMALL_INT(lhs) | FROM_SMALL_INT(rhs); break;
+            case RT_BINARY_OP_INPLACE_OR: val = lhs_val | rhs_val; break;
             case RT_BINARY_OP_XOR:
-            case RT_BINARY_OP_INPLACE_XOR: val = FROM_SMALL_INT(lhs) ^ FROM_SMALL_INT(rhs); break;
+            case RT_BINARY_OP_INPLACE_XOR: val = lhs_val ^ rhs_val; break;
             case RT_BINARY_OP_AND:
-            case RT_BINARY_OP_INPLACE_AND: val = FROM_SMALL_INT(lhs) & FROM_SMALL_INT(rhs); break;
+            case RT_BINARY_OP_INPLACE_AND: val = lhs_val & rhs_val; break;
             case RT_BINARY_OP_LSHIFT:
-            case RT_BINARY_OP_INPLACE_LSHIFT: val = FROM_SMALL_INT(lhs) << FROM_SMALL_INT(rhs); break;
+            case RT_BINARY_OP_INPLACE_LSHIFT: val = lhs_val << rhs_val; break;
             case RT_BINARY_OP_RSHIFT:
-            case RT_BINARY_OP_INPLACE_RSHIFT: val = FROM_SMALL_INT(lhs) >> FROM_SMALL_INT(rhs); break;
+            case RT_BINARY_OP_INPLACE_RSHIFT: val = lhs_val >> rhs_val; break;
             case RT_BINARY_OP_ADD:
-            case RT_BINARY_OP_INPLACE_ADD: val = FROM_SMALL_INT(lhs) + FROM_SMALL_INT(rhs); break;
+            case RT_BINARY_OP_INPLACE_ADD: val = lhs_val + rhs_val; break;
             case RT_BINARY_OP_SUBTRACT:
-            case RT_BINARY_OP_INPLACE_SUBTRACT: val = FROM_SMALL_INT(lhs) - FROM_SMALL_INT(rhs); break;
+            case RT_BINARY_OP_INPLACE_SUBTRACT: val = lhs_val - rhs_val; break;
             case RT_BINARY_OP_MULTIPLY:
-            case RT_BINARY_OP_INPLACE_MULTIPLY: val = FROM_SMALL_INT(lhs) * FROM_SMALL_INT(rhs); break;
+            case RT_BINARY_OP_INPLACE_MULTIPLY: val = lhs_val * rhs_val; break;
             case RT_BINARY_OP_FLOOR_DIVIDE:
-            case RT_BINARY_OP_INPLACE_FLOOR_DIVIDE: val = FROM_SMALL_INT(lhs) / FROM_SMALL_INT(rhs); break;
+            case RT_BINARY_OP_INPLACE_FLOOR_DIVIDE: val = lhs_val / rhs_val; break;
 #if MICROPY_ENABLE_FLOAT
             case RT_BINARY_OP_TRUE_DIVIDE:
-            case RT_BINARY_OP_INPLACE_TRUE_DIVIDE: return py_obj_new_float((float_t)FROM_SMALL_INT(lhs) / (float_t)FROM_SMALL_INT(rhs));
+            case RT_BINARY_OP_INPLACE_TRUE_DIVIDE: return py_obj_new_float((float_t)lhs_val / (float_t)rhs_val);
 #endif
+            case RT_BINARY_OP_POWER:
+            case RT_BINARY_OP_INPLACE_POWER:
+                // TODO
+                if (rhs_val == 2) {
+                    val = lhs_val * lhs_val;
+                    break;
+                }
             default: printf("%d\n", op); assert(0); val = 0;
         }
         if (fit_small_int(val)) {
             return TO_SMALL_INT(val);
         }
+#if MICROPY_ENABLE_FLOAT
+    } else if (IS_O(lhs, O_FLOAT) || IS_O(rhs, O_FLOAT)) {
+        float_t lhs_val = py_obj_get_float(lhs);
+        float_t rhs_val = py_obj_get_float(rhs);
+        float_t val;
+        switch (op) {
+            case RT_BINARY_OP_ADD:
+            case RT_BINARY_OP_INPLACE_ADD: val = lhs_val + rhs_val; break;
+            case RT_BINARY_OP_SUBTRACT:
+            case RT_BINARY_OP_INPLACE_SUBTRACT: val = lhs_val - rhs_val; break;
+            case RT_BINARY_OP_MULTIPLY:
+            case RT_BINARY_OP_INPLACE_MULTIPLY: val = lhs_val * rhs_val; break;
+            /* TODO floor(?) the value
+            case RT_BINARY_OP_FLOOR_DIVIDE:
+            case RT_BINARY_OP_INPLACE_FLOOR_DIVIDE: val = lhs_val / rhs_val; break;
+            */
+            case RT_BINARY_OP_TRUE_DIVIDE:
+            case RT_BINARY_OP_INPLACE_TRUE_DIVIDE: val = lhs_val / rhs_val; break;
+            default: printf("%d\n", op); assert(0); val = 0;
+        }
+        return py_obj_new_float(val);
+#endif
     } else if (IS_O(lhs, O_STR) && IS_O(rhs, O_STR)) {
         const char *lhs_str = qstr_str(((py_obj_base_t*)lhs)->u_str);
         const char *rhs_str = qstr_str(((py_obj_base_t*)rhs)->u_str);
@@ -1045,12 +1128,34 @@ py_obj_t rt_compare_op(int op, py_obj_t lhs, py_obj_t rhs) {
 
     // deal with small ints
     if (IS_SMALL_INT(lhs) && IS_SMALL_INT(rhs)) {
+        py_small_int_t lhs_val = FROM_SMALL_INT(lhs);
+        py_small_int_t rhs_val = FROM_SMALL_INT(rhs);
+        int cmp;
+        switch (op) {
+            case RT_COMPARE_OP_LESS: cmp = lhs_val < rhs_val; break;
+            case RT_COMPARE_OP_MORE: cmp = lhs_val > rhs_val; break;
+            case RT_COMPARE_OP_LESS_EQUAL: cmp = lhs_val <= rhs_val; break;
+            case RT_COMPARE_OP_MORE_EQUAL: cmp = lhs_val >= rhs_val; break;
+            default: assert(0); cmp = 0;
+        }
+        if (cmp) {
+            return py_const_true;
+        } else {
+            return py_const_false;
+        }
+    }
+
+#if MICROPY_ENABLE_FLOAT
+    // deal with floats
+    if (IS_O(lhs, O_FLOAT) || IS_O(rhs, O_FLOAT)) {
+        float_t lhs_val = py_obj_get_float(lhs);
+        float_t rhs_val = py_obj_get_float(rhs);
         int cmp;
         switch (op) {
-            case RT_COMPARE_OP_LESS: cmp = FROM_SMALL_INT(lhs) < FROM_SMALL_INT(rhs); break;
-            case RT_COMPARE_OP_MORE: cmp = FROM_SMALL_INT(lhs) > FROM_SMALL_INT(rhs); break;
-            case RT_COMPARE_OP_LESS_EQUAL: cmp = FROM_SMALL_INT(lhs) <= FROM_SMALL_INT(rhs); break;
-            case RT_COMPARE_OP_MORE_EQUAL: cmp = FROM_SMALL_INT(lhs) >= FROM_SMALL_INT(rhs); break;
+            case RT_COMPARE_OP_LESS: cmp = lhs_val < rhs_val; break;
+            case RT_COMPARE_OP_MORE: cmp = lhs_val > rhs_val; break;
+            case RT_COMPARE_OP_LESS_EQUAL: cmp = lhs_val <= rhs_val; break;
+            case RT_COMPARE_OP_MORE_EQUAL: cmp = lhs_val >= rhs_val; break;
             default: assert(0); cmp = 0;
         }
         if (cmp) {
@@ -1059,6 +1164,7 @@ py_obj_t rt_compare_op(int op, py_obj_t lhs, py_obj_t rhs) {
             return py_const_false;
         }
     }
+#endif
 
     // not implemented
     assert(0);
@@ -1482,7 +1588,11 @@ no_attr:
 
 void rt_load_method(py_obj_t base, qstr attr, py_obj_t *dest) {
     DEBUG_OP_printf("load method %s\n", qstr_str(attr));
-    if (IS_O(base, O_GEN_INSTANCE) && attr == q___next__) {
+    if (IS_O(base, O_STR) && attr == q_join) {
+        dest[1] = fun_str_join;
+        dest[0] = base;
+        return;
+    } else if (IS_O(base, O_GEN_INSTANCE) && attr == q___next__) {
         dest[1] = fun_gen_instance_next;
         dest[0] = base;
         return;
diff --git a/py/runtime.h b/py/runtime.h
index 33d7bed266..f8c4972b09 100644
--- a/py/runtime.h
+++ b/py/runtime.h
@@ -97,10 +97,11 @@ void rt_assign_native_code(int unique_code_id, py_fun_t f, uint len, int n_args)
 void rt_assign_inline_asm_code(int unique_code_id, py_fun_t f, uint len, int n_args);
 void py_obj_print(py_obj_t o);
 int rt_is_true(py_obj_t arg);
-int py_get_int(py_obj_t arg);
+machine_int_t py_get_int(py_obj_t arg);
+machine_float_t py_obj_get_float(py_obj_t arg);
 qstr py_get_qstr(py_obj_t arg);
-py_obj_t *py_get_array_fixed_n(py_obj_t o, int n);
-py_obj_t py_obj_new_int(int value);
+py_obj_t *py_get_array_fixed_n(py_obj_t o, machine_int_t n);
+py_obj_t py_obj_new_int(machine_int_t value);
 py_obj_t rt_load_const_str(qstr qstr);
 py_obj_t rt_load_name(qstr qstr);
 py_obj_t rt_load_global(qstr qstr);
diff --git a/py/vm.c b/py/vm.c
index 33970a0f6a..9530a65fd2 100644
--- a/py/vm.c
+++ b/py/vm.c
@@ -32,7 +32,8 @@ py_obj_t py_execute_byte_code(const byte *code, const py_obj_t *args, uint n_arg
         // it shouldn't yield
         assert(0);
     }
-    assert(sp == &state[17]);
+    // TODO check fails if, eg, return from within for loop
+    //assert(sp == &state[17]);
     return *sp;
 }
 
@@ -182,6 +183,12 @@ bool py_execute_byte_code_2(const byte *code, const byte **ip_in_out, py_obj_t *
                         ++sp;
                         break;
 
+                    case PYBC_ROT_TWO:
+                        obj1 = sp[0];
+                        sp[0] = sp[1];
+                        sp[1] = obj1;
+                        break;
+
                     case PYBC_ROT_THREE:
                         obj1 = sp[0];
                         sp[0] = sp[1];
-- 
GitLab