From 92c06561a3e657924b77aa53725529daddafb343 Mon Sep 17 00:00:00 2001
From: Damien <damien.p.george@gmail.com>
Date: Tue, 22 Oct 2013 22:32:27 +0100
Subject: [PATCH] Improve REPL compount statement detection.

---
 py/lexer.c    |  3 +++
 py/repl.c     | 44 ++++++++++++++++++++++++++++++++++++++++++++
 py/repl.h     |  1 +
 unix/Makefile |  1 +
 unix/main.c   | 27 +++------------------------
 5 files changed, 52 insertions(+), 24 deletions(-)
 create mode 100644 py/repl.c
 create mode 100644 py/repl.h

diff --git a/py/lexer.c b/py/lexer.c
index 88bc0a1aae..cd2e05ece0 100644
--- a/py/lexer.c
+++ b/py/lexer.c
@@ -10,6 +10,9 @@
 
 #define TAB_SIZE (8)
 
+// TODO seems that CPython allows NULL byte in the input stream
+// don't know if that's intentional or not, but we don't allow it
+
 struct _py_lexer_t {
     const char *name;           // name of source
     void *stream_data;          // data for stream
diff --git a/py/repl.c b/py/repl.c
new file mode 100644
index 0000000000..f295aff23d
--- /dev/null
+++ b/py/repl.c
@@ -0,0 +1,44 @@
+#include "misc.h"
+#include "repl.h"
+
+bool str_startswith_word(const char *str, const char *head) {
+    int i;
+    for (i = 0; str[i] && head[i]; i++) {
+        if (str[i] != head[i]) {
+            return false;
+        }
+    }
+    return head[i] == '\0' && (str[i] == '\0' || !g_unichar_isalpha(str[i]));
+}
+
+bool py_repl_is_compound_stmt(const char *line) {
+    // compound if line starts with a certain keyword
+    if (
+           str_startswith_word(line, "if")
+        || str_startswith_word(line, "while")
+        || str_startswith_word(line, "for")
+        || str_startswith_word(line, "true")
+        || str_startswith_word(line, "with")
+        || str_startswith_word(line, "def")
+        || str_startswith_word(line, "class")
+        || str_startswith_word(line, "@")
+       ) {
+        return true;
+    }
+
+    // also "compound" if unmatched open bracket
+    int n_paren = 0;
+    int n_brack = 0;
+    int n_brace = 0;
+    for (const char *l = line; *l; l++) {
+        switch (*l) {
+            case '(': n_paren += 1; break;
+            case ')': n_paren -= 1; break;
+            case '[': n_brack += 1; break;
+            case ']': n_brack -= 1; break;
+            case '{': n_brace += 1; break;
+            case '}': n_brace -= 1; break;
+        }
+    }
+    return n_paren > 0 || n_brack > 0 || n_brace > 0;
+}
diff --git a/py/repl.h b/py/repl.h
new file mode 100644
index 0000000000..014e8609b4
--- /dev/null
+++ b/py/repl.h
@@ -0,0 +1 @@
+bool py_repl_is_compound_stmt(const char *line);
diff --git a/unix/Makefile b/unix/Makefile
index 7c8b5a2b9a..9dd17f8350 100644
--- a/unix/Makefile
+++ b/unix/Makefile
@@ -30,6 +30,7 @@ PY_O = \
 	emitinlinethumb.o \
 	runtime.o \
 	vm.o \
+	repl.o \
 
 OBJ = $(addprefix $(BUILD)/, $(SRC_C:.c=.o) $(PY_O))
 LIB = -lreadline
diff --git a/unix/main.c b/unix/main.c
index f7b06d4f83..12aca6ddf2 100644
--- a/unix/main.c
+++ b/unix/main.c
@@ -10,32 +10,10 @@
 #include "parse.h"
 #include "compile.h"
 #include "runtime.h"
+#include "repl.h"
 
 #include <readline/readline.h>
 
-bool str_startswith_word(const char *str, const char *head) {
-    int i;
-    for (i = 0; str[i] && head[i]; i++) {
-        if (str[i] != head[i]) {
-            return false;
-        }
-    }
-    return head[i] == '\0' && (str[i] == '\0' || !g_unichar_isalpha(str[i]));
-}
-
-bool is_compound_stmt(const char *line) {
-    // TODO also "compound" if unmatched open bracket
-    return
-           str_startswith_word(line, "if")
-        || str_startswith_word(line, "while")
-        || str_startswith_word(line, "for")
-        || str_startswith_word(line, "true")
-        || str_startswith_word(line, "with")
-        || str_startswith_word(line, "def")
-        || str_startswith_word(line, "class")
-        || str_startswith_word(line, "@");
-}
-
 char *str_join(const char *s1, int sep_char, const char *s2) {
     int l1 = strlen(s1);
     int l2 = strlen(s2);
@@ -46,6 +24,7 @@ char *str_join(const char *s1, int sep_char, const char *s2) {
         l1 += 1;
     }
     memcpy(s + l1, s2, l2);
+    s[l1 + l2] = 0;
     return s;
 }
 
@@ -56,7 +35,7 @@ void do_repl() {
             // EOF
             return;
         }
-        if (is_compound_stmt(line)) {
+        if (py_repl_is_compound_stmt(line)) {
             for (;;) {
                 char *line2 = readline("... ");
                 if (line2 == NULL || strlen(line2) == 0) {
-- 
GitLab