From 66028ab6dcd1b2ec9504c3473d817649935a4a1e Mon Sep 17 00:00:00 2001
From: Damien George <damien.p.george@gmail.com>
Date: Fri, 3 Jan 2014 14:03:48 +0000
Subject: [PATCH] Basic implementation of import.

import works for simple cases.  Still work to do on finding the right
script, and setting globals/locals correctly when running an imported
function.
---
 py/builtin.c                   | 16 ++-----
 py/builtinimport.c             | 84 ++++++++++++++++++++++++++++++++++
 py/lexer.h                     |  3 ++
 py/lexerunix.c                 | 20 ++++++++
 py/lexerunix.h                 |  2 +
 py/map.h                       |  4 --
 py/obj.h                       |  5 +-
 py/objfun.c                    | 39 ++++++++++------
 py/runtime.c                   | 27 +++++++----
 py/runtime.h                   |  6 +++
 py/showbc.c                    | 14 ++++--
 py/vstr.c                      |  6 ++-
 stm/Makefile                   |  1 +
 stm/lexerstm.c                 |  5 ++
 tests/basics/run-tests         |  2 +-
 tests/basics/tests/import1a.py |  2 +
 tests/basics/tests/import1b.py |  1 +
 unix-cpy/Makefile              |  1 +
 unix/Makefile                  |  1 +
 unix/main.c                    | 12 +++++
 20 files changed, 207 insertions(+), 44 deletions(-)
 create mode 100644 py/builtinimport.c
 create mode 100644 tests/basics/tests/import1a.py
 create mode 100644 tests/basics/tests/import1b.py

diff --git a/py/builtin.c b/py/builtin.c
index 2b94163f1..d29a2bf8c 100644
--- a/py/builtin.c
+++ b/py/builtin.c
@@ -16,30 +16,20 @@
 
 mp_obj_t mp_builtin___build_class__(mp_obj_t o_class_fun, mp_obj_t o_class_name) {
     // we differ from CPython: we set the new __locals__ object here
-    mp_map_t *old_locals = rt_get_map_locals();
+    mp_map_t *old_locals = rt_locals_get();
     mp_map_t *class_locals = mp_map_new(MP_MAP_QSTR, 0);
-    rt_set_map_locals(class_locals);
+    rt_locals_set(class_locals);
 
     // call the class code
     rt_call_function_1(o_class_fun, (mp_obj_t)0xdeadbeef);
 
     // restore old __locals__ object
-    rt_set_map_locals(old_locals);
+    rt_locals_set(old_locals);
 
     // create and return the new class
     return mp_obj_new_class(class_locals);
 }
 
-mp_obj_t mp_builtin___import__(int n, mp_obj_t *args) {
-    printf("import:\n");
-    for (int i = 0; i < n; i++) {
-    printf("  ");
-    mp_obj_print(args[i]);
-    printf("\n");
-    }
-    return mp_const_none;
-}
-
 mp_obj_t mp_builtin___repl_print__(mp_obj_t o) {
     if (o != mp_const_none) {
         mp_obj_print(o);
diff --git a/py/builtinimport.c b/py/builtinimport.c
new file mode 100644
index 000000000..f1479ab12
--- /dev/null
+++ b/py/builtinimport.c
@@ -0,0 +1,84 @@
+#include <stdint.h>
+#include <stdlib.h>
+#include <stdio.h>
+#include <stdarg.h>
+#include <string.h>
+#include <assert.h>
+
+#include "nlr.h"
+#include "misc.h"
+#include "mpconfig.h"
+#include "lexer.h"
+#include "lexerunix.h"
+#include "parse.h"
+#include "compile.h"
+#include "obj.h"
+#include "runtime0.h"
+#include "runtime.h"
+#include "map.h"
+#include "builtin.h"
+
+mp_obj_t mp_builtin___import__(int n, mp_obj_t *args) {
+    /*
+    printf("import:\n");
+    for (int i = 0; i < n; i++) {
+        printf("  ");
+        mp_obj_print(args[i]);
+        printf("\n");
+    }
+    */
+
+    // find the file to import
+    qstr mod_name = mp_obj_get_qstr(args[0]);
+    mp_lexer_t *lex = mp_import_open_file(mod_name);
+    if (lex == NULL) {
+        // TODO handle lexer error correctly
+        return mp_const_none;
+    }
+
+    // create a new module object
+    mp_obj_t module_obj = mp_obj_new_module(mp_obj_get_qstr(args[0]));
+
+    // save the old context
+    mp_map_t *old_locals = rt_locals_get();
+    mp_map_t *old_globals = rt_globals_get();
+
+    // set the new context
+    rt_locals_set(mp_obj_module_get_globals(module_obj));
+    rt_globals_set(mp_obj_module_get_globals(module_obj));
+
+    // parse the imported script
+    mp_parse_node_t pn = mp_parse(lex, MP_PARSE_FILE_INPUT);
+    mp_lexer_free(lex);
+
+    if (pn == MP_PARSE_NODE_NULL) {
+        // TODO handle parse error correctly
+        rt_locals_set(old_locals);
+        rt_globals_set(old_globals);
+        return mp_const_none;
+    }
+
+    if (!mp_compile(pn, false)) {
+        // TODO handle compile error correctly
+        rt_locals_set(old_locals);
+        rt_globals_set(old_globals);
+        return mp_const_none;
+    }
+
+    // compiled successfully, execute it
+    mp_obj_t module_fun = rt_make_function_from_id(1); // TODO we should return from mp_compile the unique_code_id for the module
+    nlr_buf_t nlr;
+    if (nlr_push(&nlr) == 0) {
+        rt_call_function_0(module_fun);
+        nlr_pop();
+    } else {
+        // exception; restore context and re-raise same exception
+        rt_locals_set(old_locals);
+        rt_globals_set(old_globals);
+        nlr_jump(nlr.ret_val);
+    }
+    rt_locals_set(old_locals);
+    rt_globals_set(old_globals);
+
+    return module_obj;
+}
diff --git a/py/lexer.h b/py/lexer.h
index f58a38e92..27244fde9 100644
--- a/py/lexer.h
+++ b/py/lexer.h
@@ -138,3 +138,6 @@ bool mp_lexer_opt_str(mp_lexer_t *lex, const char *str);
 */
 bool mp_lexer_show_error(mp_lexer_t *lex, const char *msg);
 bool mp_lexer_show_error_pythonic(mp_lexer_t *lex, const char *msg);
+
+// used to import a module; must be implemented for a specific port
+mp_lexer_t *mp_import_open_file(qstr mod_name);
diff --git a/py/lexerunix.c b/py/lexerunix.c
index 398cb792a..14c28c16d 100644
--- a/py/lexerunix.c
+++ b/py/lexerunix.c
@@ -58,3 +58,23 @@ mp_lexer_t *mp_lexer_new_from_file(const char *filename) {
 
     return mp_lexer_new_from_str_len(filename, data, size, true);
 }
+
+/******************************************************************************/
+/* unix implementation of import                                              */
+
+// TODO properly!
+
+static const char *import_base_dir = NULL;
+
+void mp_import_set_directory(const char *dir) {
+    import_base_dir = dir;
+}
+
+mp_lexer_t *mp_import_open_file(qstr mod_name) {
+    vstr_t *vstr = vstr_new();
+    if (import_base_dir != NULL) {
+        vstr_printf(vstr, "%s/", import_base_dir);
+    }
+    vstr_printf(vstr, "%s.py", qstr_str(mod_name));
+    return mp_lexer_new_from_file(vstr_str(vstr)); // TODO does lexer need to copy the string? can we free it here?
+}
diff --git a/py/lexerunix.h b/py/lexerunix.h
index d86f202d5..b422a4306 100644
--- a/py/lexerunix.h
+++ b/py/lexerunix.h
@@ -1,2 +1,4 @@
 mp_lexer_t *mp_lexer_new_from_str_len(const char *src_name, const char *str, uint len, bool free_str);
 mp_lexer_t *mp_lexer_new_from_file(const char *filename);
+
+void mp_import_set_directory(const char *dir);
diff --git a/py/map.h b/py/map.h
index 8ee8429b5..f8ca886aa 100644
--- a/py/map.h
+++ b/py/map.h
@@ -23,10 +23,6 @@ typedef struct _mp_set_t {
     mp_obj_t *table;
 } mp_set_t;
 
-// these are defined in runtime.c
-mp_map_t *rt_get_map_locals(void);
-void rt_set_map_locals(mp_map_t *m);
-
 int get_doubling_prime_greater_or_equal_to(int x);
 void mp_map_init(mp_map_t *map, mp_map_kind_t kind, int n);
 mp_map_t *mp_map_new(mp_map_kind_t kind, int n);
diff --git a/py/obj.h b/py/obj.h
index 6a0cefd91..7b4b0656f 100644
--- a/py/obj.h
+++ b/py/obj.h
@@ -215,11 +215,14 @@ mp_obj_t mp_obj_dict_store(mp_obj_t self_in, mp_obj_t key, mp_obj_t value);
 void mp_obj_set_store(mp_obj_t self_in, mp_obj_t item);
 
 // functions
-typedef struct _mp_obj_fun_native_t { // need this so we can define static objects
+typedef struct _mp_obj_fun_native_t { // need this so we can define const objects (to go in ROM)
     mp_obj_base_t base;
     machine_uint_t n_args_min; // inclusive
     machine_uint_t n_args_max; // inclusive
     void *fun;
+    // TODO add mp_map_t *globals
+    // for const function objects, make an empty, const map
+    // such functions won't be able to access the global scope, but that's probably okay
 } mp_obj_fun_native_t;
 extern const mp_obj_type_t fun_native_type;
 extern const mp_obj_type_t fun_bc_type;
diff --git a/py/objfun.c b/py/objfun.c
index cefc9a95f..e998bd28d 100644
--- a/py/objfun.c
+++ b/py/objfun.c
@@ -7,6 +7,7 @@
 #include "misc.h"
 #include "mpconfig.h"
 #include "obj.h"
+#include "map.h"
 #include "runtime.h"
 #include "bc.h"
 
@@ -129,9 +130,10 @@ mp_obj_t rt_make_function_var_between(int n_args_min, int n_args_max, mp_fun_var
 
 typedef struct _mp_obj_fun_bc_t {
     mp_obj_base_t base;
-    int n_args;
-    uint n_state;
-    const byte *code;
+    mp_map_t *globals;      // the context within which this function was defined
+    int n_args;             // number of arguments this function takes
+    uint n_state;           // total state size for the executing function (incl args, locals, stack)
+    const byte *bytecode;   // bytecode for the function
 } mp_obj_fun_bc_t;
 
 // args are in reverse order in the array
@@ -142,15 +144,17 @@ mp_obj_t fun_bc_call_n(mp_obj_t self_in, int n_args, const mp_obj_t *args) {
         nlr_jump(mp_obj_new_exception_msg_2_args(rt_q_TypeError, "function takes %d positional arguments but %d were given", (const char*)(machine_int_t)self->n_args, (const char*)(machine_int_t)n_args));
     }
 
-    return mp_execute_byte_code(self->code, args, n_args, self->n_state);
-}
-
-void mp_obj_fun_bc_get(mp_obj_t self_in, int *n_args, uint *n_state, const byte **code) {
-    assert(MP_OBJ_IS_TYPE(self_in, &fun_bc_type));
-    mp_obj_fun_bc_t *self = self_in;
-    *n_args = self->n_args;
-    *n_state = self->n_state;
-    *code = self->code;
+    // optimisation: allow the compiler to optimise this tail call for
+    // the common case when the globals don't need to be changed
+    mp_map_t *old_globals = rt_globals_get();
+    if (self->globals == old_globals) {
+        return mp_execute_byte_code(self->bytecode, args, n_args, self->n_state);
+    } else {
+        rt_globals_set(self->globals);
+        mp_obj_t result = mp_execute_byte_code(self->bytecode, args, n_args, self->n_state);
+        rt_globals_set(old_globals);
+        return result;
+    }
 }
 
 const mp_obj_type_t fun_bc_type = {
@@ -170,12 +174,21 @@ const mp_obj_type_t fun_bc_type = {
 mp_obj_t mp_obj_new_fun_bc(int n_args, uint n_state, const byte *code) {
     mp_obj_fun_bc_t *o = m_new_obj(mp_obj_fun_bc_t);
     o->base.type = &fun_bc_type;
+    o->globals = rt_globals_get();
     o->n_args = n_args;
     o->n_state = n_state;
-    o->code = code;
+    o->bytecode = code;
     return o;
 }
 
+void mp_obj_fun_bc_get(mp_obj_t self_in, int *n_args, uint *n_state, const byte **code) {
+    assert(MP_OBJ_IS_TYPE(self_in, &fun_bc_type));
+    mp_obj_fun_bc_t *self = self_in;
+    *n_args = self->n_args;
+    *n_state = self->n_state;
+    *code = self->bytecode;
+}
+
 /******************************************************************************/
 /* inline assembler functions                                                 */
 
diff --git a/py/runtime.c b/py/runtime.c
index 3fae61f6f..a8e55467b 100644
--- a/py/runtime.c
+++ b/py/runtime.c
@@ -281,14 +281,6 @@ void rt_assign_inline_asm_code(int unique_code_id, void *fun, uint len, int n_ar
 #endif
 }
 
-mp_map_t *rt_get_map_locals(void) {
-    return map_locals;
-}
-
-void rt_set_map_locals(mp_map_t *m) {
-    map_locals = m;
-}
-
 static bool fit_small_int(mp_small_int_t o) {
     return true;
 }
@@ -786,6 +778,7 @@ mp_obj_t rt_load_attr(mp_obj_t base, qstr attr) {
     } else if (MP_OBJ_IS_TYPE(base, &instance_type)) {
         return mp_obj_instance_load_attr(base, attr);
     } else if (MP_OBJ_IS_TYPE(base, &module_type)) {
+        DEBUG_OP_printf("lookup module map %p\n", mp_obj_module_get_globals(base));
         mp_map_elem_t *elem = mp_qstr_map_lookup(mp_obj_module_get_globals(base), attr, false);
         if (elem == NULL) {
             // TODO what about generic method lookup?
@@ -913,6 +906,24 @@ mp_obj_t rt_import_from(mp_obj_t module, qstr name) {
     return x;
 }
 
+mp_map_t *rt_locals_get(void) {
+    return map_locals;
+}
+
+void rt_locals_set(mp_map_t *m) {
+    DEBUG_OP_printf("rt_locals_set(%p)\n", m);
+    map_locals = m;
+}
+
+mp_map_t *rt_globals_get(void) {
+    return map_globals;
+}
+
+void rt_globals_set(mp_map_t *m) {
+    DEBUG_OP_printf("rt_globals_set(%p)\n", m);
+    map_globals = m;
+}
+
 // these must correspond to the respective enum
 void *const rt_fun_table[RT_F_NUMBER_OF] = {
     rt_load_const_dec,
diff --git a/py/runtime.h b/py/runtime.h
index 37b036852..cf9180275 100644
--- a/py/runtime.h
+++ b/py/runtime.h
@@ -57,3 +57,9 @@ mp_obj_t rt_getiter(mp_obj_t o);
 mp_obj_t rt_iternext(mp_obj_t o);
 mp_obj_t rt_import_name(qstr name, mp_obj_t fromlist, mp_obj_t level);
 mp_obj_t rt_import_from(mp_obj_t module, qstr name);
+
+struct _mp_map_t;
+struct _mp_map_t *rt_locals_get(void);
+void rt_locals_set(struct _mp_map_t *m);
+struct _mp_map_t *rt_globals_get(void);
+void rt_globals_set(struct _mp_map_t *m);
diff --git a/py/showbc.c b/py/showbc.c
index 15cd05642..a3bfa2833 100644
--- a/py/showbc.c
+++ b/py/showbc.c
@@ -142,12 +142,10 @@ void mp_show_byte_code(const byte *ip, int len) {
                 printf("STORE_NAME %s", qstr_str(qstr));
                 break;
 
-                /*
             case MP_BC_STORE_GLOBAL:
                 DECODE_QSTR;
-                rt_store_global(qstr, POP());
+                printf("STORE_GLOBAL %s", qstr_str(qstr));
                 break;
-                */
 
             case MP_BC_STORE_ATTR:
                 DECODE_QSTR;
@@ -343,6 +341,16 @@ void mp_show_byte_code(const byte *ip, int len) {
                 printf("YIELD_VALUE");
                 break;
 
+            case MP_BC_IMPORT_NAME:
+                DECODE_QSTR;
+                printf("IMPORT_NAME %s", qstr_str(qstr));
+                break;
+
+            case MP_BC_IMPORT_FROM:
+                DECODE_QSTR;
+                printf("IMPORT_FROM %s", qstr_str(qstr));
+                break;
+
             default:
                 printf("code %p, byte code 0x%02x not implemented\n", ip, op);
                 assert(0);
diff --git a/py/vstr.c b/py/vstr.c
index 98cf02725..80841b24c 100644
--- a/py/vstr.c
+++ b/py/vstr.c
@@ -167,8 +167,12 @@ void vstr_vprintf(vstr_t *vstr, const char *fmt, va_list ap) {
 
     while (1) {
         // try to print in the allocated space
+        // need to make a copy of the va_list because we may call vsnprintf multiple times
         int size = vstr->alloc - vstr->len;
-        int n = vsnprintf(vstr->buf + vstr->len, size, fmt, ap);
+        va_list ap2;
+        va_copy(ap2, ap);
+        int n = vsnprintf(vstr->buf + vstr->len, size, fmt, ap2);
+        va_end(ap2);
 
         // if that worked, return
         if (n > -1 && n < size) {
diff --git a/stm/Makefile b/stm/Makefile
index 6868f85ba..be4ca8b3a 100644
--- a/stm/Makefile
+++ b/stm/Makefile
@@ -82,6 +82,7 @@ PY_O = \
 	objtuple.o \
 	objtype.o \
 	builtin.o \
+	builtinimport.o \
 	vm.o \
 	repl.o \
 
diff --git a/stm/lexerstm.c b/stm/lexerstm.c
index dfb84cca1..661dfb016 100644
--- a/stm/lexerstm.c
+++ b/stm/lexerstm.c
@@ -61,3 +61,8 @@ mp_lexer_t *mp_lexer_new_from_file(const char *filename, mp_lexer_file_buf_t *fb
     fb->pos = 0;
     return mp_lexer_new(filename, fb, (mp_lexer_stream_next_char_t)file_buf_next_char, (mp_lexer_stream_close_t)file_buf_close);
 }
+
+mp_lexer_t *mp_import_open_file(qstr mod_name) {
+    printf("import not implemented\n");
+    return NULL;
+}
diff --git a/tests/basics/run-tests b/tests/basics/run-tests
index 1b027c3e9..72e69c2d8 100755
--- a/tests/basics/run-tests
+++ b/tests/basics/run-tests
@@ -11,7 +11,7 @@ namefailed=
 
 for infile in tests/*.py
 do
-    basename=`basename $infile .c`
+    basename=`basename $infile .py`
     outfile=${basename}.out
     expfile=${basename}.exp
 
diff --git a/tests/basics/tests/import1a.py b/tests/basics/tests/import1a.py
new file mode 100644
index 000000000..16b2d4d30
--- /dev/null
+++ b/tests/basics/tests/import1a.py
@@ -0,0 +1,2 @@
+import import1b
+print(import1b.var)
diff --git a/tests/basics/tests/import1b.py b/tests/basics/tests/import1b.py
new file mode 100644
index 000000000..80479088f
--- /dev/null
+++ b/tests/basics/tests/import1b.py
@@ -0,0 +1 @@
+var = 123
diff --git a/unix-cpy/Makefile b/unix-cpy/Makefile
index 9399a765c..0f20fe31c 100644
--- a/unix-cpy/Makefile
+++ b/unix-cpy/Makefile
@@ -47,6 +47,7 @@ PY_O = \
 	objtuple.o \
 	objtype.o \
 	builtin.o \
+	builtinimport.o \
 	vm.o \
 	showbc.o \
 	repl.o \
diff --git a/unix/Makefile b/unix/Makefile
index b8955d11a..271cf2265 100644
--- a/unix/Makefile
+++ b/unix/Makefile
@@ -54,6 +54,7 @@ PY_O = \
 	objtuple.o \
 	objtype.o \
 	builtin.o \
+	builtinimport.o \
 	vm.o \
 	showbc.o \
 	repl.o \
diff --git a/unix/main.c b/unix/main.c
index 376dbc0c0..c23a8e54c 100644
--- a/unix/main.c
+++ b/unix/main.c
@@ -105,6 +105,18 @@ static void do_repl(void) {
 }
 
 void do_file(const char *file) {
+    // hack: set dir for import based on where this file is
+    {
+        const char * s = strrchr(file, '/');
+        if (s != NULL) {
+            int len = s - file;
+            char *dir = m_new(char, len + 1);
+            memcpy(dir, file, len);
+            dir[len] = '\0';
+            mp_import_set_directory(dir);
+        }
+    }
+
     mp_lexer_t *lex = mp_lexer_new_from_file(file);
     //const char *pysrc = "def f():\n  x=x+1\n  print(42)\n";
     //mp_lexer_t *lex = mp_lexer_from_str_len("<>", pysrc, strlen(pysrc), false);
-- 
GitLab