From 10830059c5d3651abdb2d3532b28a9bb0a9425ee Mon Sep 17 00:00:00 2001
From: Damien George <damien.p.george@gmail.com>
Date: Sat, 4 Aug 2018 22:03:49 +1000
Subject: [PATCH] py/emitnative: Fix x86 native zero checks by comparing full
 word.

On x86 archs (both 32 and 64 bit) a bool return value only sets the 8-bit
al register, and the higher bits of the full-width eax/rax register have
an undefined value.  When testing the return value of such cases it is
required to test only al for zero/non-zero.  On the other hand, checking for truth or
zero/non-zero on an integer return value requires checking all bits of the
register.  These two cases must be distinguished and handled correctly in
generated native code.  This patch makes sure of this.

For other supported native archs (ARM, Thumb2, Xtensa) there is no such
distinction and this patch does not change anything for them.
---
 py/asmarm.h     |  4 ++--
 py/asmthumb.h   |  4 ++--
 py/asmx64.c     |  5 +++++
 py/asmx64.h     | 17 +++++++++++++----
 py/asmx86.c     |  5 +++++
 py/asmx86.h     | 17 +++++++++++++----
 py/asmxtensa.h  |  4 ++--
 py/emitnative.c | 12 ++++++------
 8 files changed, 48 insertions(+), 20 deletions(-)

diff --git a/py/asmarm.h b/py/asmarm.h
index 871e35820..5c1e2ba58 100644
--- a/py/asmarm.h
+++ b/py/asmarm.h
@@ -150,12 +150,12 @@ void asm_arm_bl_ind(asm_arm_t *as, void *fun_ptr, uint fun_id, uint reg_temp);
 #define ASM_EXIT            asm_arm_exit
 
 #define ASM_JUMP            asm_arm_b_label
-#define ASM_JUMP_IF_REG_ZERO(as, reg, label) \
+#define ASM_JUMP_IF_REG_ZERO(as, reg, label, bool_test) \
     do { \
         asm_arm_cmp_reg_i8(as, reg, 0); \
         asm_arm_bcc_label(as, ASM_ARM_CC_EQ, label); \
     } while (0)
-#define ASM_JUMP_IF_REG_NONZERO(as, reg, label) \
+#define ASM_JUMP_IF_REG_NONZERO(as, reg, label, bool_test) \
     do { \
         asm_arm_cmp_reg_i8(as, reg, 0); \
         asm_arm_bcc_label(as, ASM_ARM_CC_NE, label); \
diff --git a/py/asmthumb.h b/py/asmthumb.h
index 8a7df5d50..9d25b973f 100644
--- a/py/asmthumb.h
+++ b/py/asmthumb.h
@@ -267,12 +267,12 @@ void asm_thumb_bl_ind(asm_thumb_t *as, void *fun_ptr, uint fun_id, uint reg_temp
 #define ASM_EXIT            asm_thumb_exit
 
 #define ASM_JUMP            asm_thumb_b_label
-#define ASM_JUMP_IF_REG_ZERO(as, reg, label) \
+#define ASM_JUMP_IF_REG_ZERO(as, reg, label, bool_test) \
     do { \
         asm_thumb_cmp_rlo_i8(as, reg, 0); \
         asm_thumb_bcc_label(as, ASM_THUMB_CC_EQ, label); \
     } while (0)
-#define ASM_JUMP_IF_REG_NONZERO(as, reg, label) \
+#define ASM_JUMP_IF_REG_NONZERO(as, reg, label, bool_test) \
     do { \
         asm_thumb_cmp_rlo_i8(as, reg, 0); \
         asm_thumb_bcc_label(as, ASM_THUMB_CC_NE, label); \
diff --git a/py/asmx64.c b/py/asmx64.c
index c900a08d1..2389aad44 100644
--- a/py/asmx64.c
+++ b/py/asmx64.c
@@ -73,6 +73,7 @@
 #define OPCODE_CMP_R64_WITH_RM64 (0x39) /* /r */
 //#define OPCODE_CMP_RM32_WITH_R32 (0x3b)
 #define OPCODE_TEST_R8_WITH_RM8  (0x84) /* /r */
+#define OPCODE_TEST_R64_WITH_RM64 (0x85) /* /r */
 #define OPCODE_JMP_REL8          (0xeb)
 #define OPCODE_JMP_REL32         (0xe9)
 #define OPCODE_JCC_REL8          (0x70) /* | jcc type */
@@ -471,6 +472,10 @@ void asm_x64_test_r8_with_r8(asm_x64_t *as, int src_r64_a, int src_r64_b) {
     asm_x64_write_byte_2(as, OPCODE_TEST_R8_WITH_RM8, MODRM_R64(src_r64_a) | MODRM_RM_REG | MODRM_RM_R64(src_r64_b));
 }
 
+void asm_x64_test_r64_with_r64(asm_x64_t *as, int src_r64_a, int src_r64_b) {
+    asm_x64_generic_r64_r64(as, src_r64_b, src_r64_a, OPCODE_TEST_R64_WITH_RM64);
+}
+
 void asm_x64_setcc_r8(asm_x64_t *as, int jcc_type, int dest_r8) {
     assert(dest_r8 < 8);
     asm_x64_write_byte_3(as, OPCODE_SETCC_RM8_A, OPCODE_SETCC_RM8_B | jcc_type, MODRM_R64(0) | MODRM_RM_REG | MODRM_RM_R64(dest_r8));
diff --git a/py/asmx64.h b/py/asmx64.h
index 2fbbfa9ff..4d7281d18 100644
--- a/py/asmx64.h
+++ b/py/asmx64.h
@@ -104,6 +104,7 @@ void asm_x64_sub_r64_r64(asm_x64_t* as, int dest_r64, int src_r64);
 void asm_x64_mul_r64_r64(asm_x64_t* as, int dest_r64, int src_r64);
 void asm_x64_cmp_r64_with_r64(asm_x64_t* as, int src_r64_a, int src_r64_b);
 void asm_x64_test_r8_with_r8(asm_x64_t* as, int src_r64_a, int src_r64_b);
+void asm_x64_test_r64_with_r64(asm_x64_t *as, int src_r64_a, int src_r64_b);
 void asm_x64_setcc_r8(asm_x64_t* as, int jcc_type, int dest_r8);
 void asm_x64_jmp_label(asm_x64_t* as, mp_uint_t label);
 void asm_x64_jcc_label(asm_x64_t* as, int jcc_type, mp_uint_t label);
@@ -145,14 +146,22 @@ void asm_x64_call_ind(asm_x64_t* as, void* ptr, int temp_r32);
 #define ASM_EXIT            asm_x64_exit
 
 #define ASM_JUMP            asm_x64_jmp_label
-#define ASM_JUMP_IF_REG_ZERO(as, reg, label) \
+#define ASM_JUMP_IF_REG_ZERO(as, reg, label, bool_test) \
     do { \
-        asm_x64_test_r8_with_r8(as, reg, reg); \
+        if (bool_test) { \
+            asm_x64_test_r8_with_r8((as), (reg), (reg)); \
+        } else { \
+            asm_x64_test_r64_with_r64((as), (reg), (reg)); \
+        } \
         asm_x64_jcc_label(as, ASM_X64_CC_JZ, label); \
     } while (0)
-#define ASM_JUMP_IF_REG_NONZERO(as, reg, label) \
+#define ASM_JUMP_IF_REG_NONZERO(as, reg, label, bool_test) \
     do { \
-        asm_x64_test_r8_with_r8(as, reg, reg); \
+        if (bool_test) { \
+            asm_x64_test_r8_with_r8((as), (reg), (reg)); \
+        } else { \
+            asm_x64_test_r64_with_r64((as), (reg), (reg)); \
+        } \
         asm_x64_jcc_label(as, ASM_X64_CC_JNZ, label); \
     } while (0)
 #define ASM_JUMP_IF_REG_EQ(as, reg1, reg2, label) \
diff --git a/py/asmx86.c b/py/asmx86.c
index 3938baaac..d0d4140ab 100644
--- a/py/asmx86.c
+++ b/py/asmx86.c
@@ -73,6 +73,7 @@
 #define OPCODE_CMP_R32_WITH_RM32 (0x39)
 //#define OPCODE_CMP_RM32_WITH_R32 (0x3b)
 #define OPCODE_TEST_R8_WITH_RM8  (0x84) /* /r */
+#define OPCODE_TEST_R32_WITH_RM32 (0x85) /* /r */
 #define OPCODE_JMP_REL8          (0xeb)
 #define OPCODE_JMP_REL32         (0xe9)
 #define OPCODE_JCC_REL8          (0x70) /* | jcc type */
@@ -334,6 +335,10 @@ void asm_x86_test_r8_with_r8(asm_x86_t *as, int src_r32_a, int src_r32_b) {
     asm_x86_write_byte_2(as, OPCODE_TEST_R8_WITH_RM8, MODRM_R32(src_r32_a) | MODRM_RM_REG | MODRM_RM_R32(src_r32_b));
 }
 
+void asm_x86_test_r32_with_r32(asm_x86_t *as, int src_r32_a, int src_r32_b) {
+    asm_x86_generic_r32_r32(as, src_r32_b, src_r32_a, OPCODE_TEST_R32_WITH_RM32);
+}
+
 void asm_x86_setcc_r8(asm_x86_t *as, mp_uint_t jcc_type, int dest_r8) {
     asm_x86_write_byte_3(as, OPCODE_SETCC_RM8_A, OPCODE_SETCC_RM8_B | jcc_type, MODRM_R32(0) | MODRM_RM_REG | MODRM_RM_R32(dest_r8));
 }
diff --git a/py/asmx86.h b/py/asmx86.h
index 09559850c..72b122ad0 100644
--- a/py/asmx86.h
+++ b/py/asmx86.h
@@ -101,6 +101,7 @@ void asm_x86_sub_r32_r32(asm_x86_t* as, int dest_r32, int src_r32);
 void asm_x86_mul_r32_r32(asm_x86_t* as, int dest_r32, int src_r32);
 void asm_x86_cmp_r32_with_r32(asm_x86_t* as, int src_r32_a, int src_r32_b);
 void asm_x86_test_r8_with_r8(asm_x86_t* as, int src_r32_a, int src_r32_b);
+void asm_x86_test_r32_with_r32(asm_x86_t* as, int src_r32_a, int src_r32_b);
 void asm_x86_setcc_r8(asm_x86_t* as, mp_uint_t jcc_type, int dest_r8);
 void asm_x86_jmp_label(asm_x86_t* as, mp_uint_t label);
 void asm_x86_jcc_label(asm_x86_t* as, mp_uint_t jcc_type, mp_uint_t label);
@@ -143,14 +144,22 @@ void asm_x86_call_ind(asm_x86_t* as, void* ptr, mp_uint_t n_args, int temp_r32);
 #define ASM_EXIT            asm_x86_exit
 
 #define ASM_JUMP            asm_x86_jmp_label
-#define ASM_JUMP_IF_REG_ZERO(as, reg, label) \
+#define ASM_JUMP_IF_REG_ZERO(as, reg, label, bool_test) \
     do { \
-        asm_x86_test_r8_with_r8(as, reg, reg); \
+        if (bool_test) { \
+            asm_x86_test_r8_with_r8(as, reg, reg); \
+        } else { \
+            asm_x86_test_r32_with_r32(as, reg, reg); \
+        } \
         asm_x86_jcc_label(as, ASM_X86_CC_JZ, label); \
     } while (0)
-#define ASM_JUMP_IF_REG_NONZERO(as, reg, label) \
+#define ASM_JUMP_IF_REG_NONZERO(as, reg, label, bool_test) \
     do { \
-        asm_x86_test_r8_with_r8(as, reg, reg); \
+        if (bool_test) { \
+            asm_x86_test_r8_with_r8(as, reg, reg); \
+        } else { \
+            asm_x86_test_r32_with_r32(as, reg, reg); \
+        } \
         asm_x86_jcc_label(as, ASM_X86_CC_JNZ, label); \
     } while (0)
 #define ASM_JUMP_IF_REG_EQ(as, reg1, reg2, label) \
diff --git a/py/asmxtensa.h b/py/asmxtensa.h
index e6d4158cb..041844e6d 100644
--- a/py/asmxtensa.h
+++ b/py/asmxtensa.h
@@ -268,9 +268,9 @@ void asm_xtensa_mov_reg_local_addr(asm_xtensa_t *as, uint reg_dest, int local_nu
 #define ASM_EXIT            asm_xtensa_exit
 
 #define ASM_JUMP            asm_xtensa_j_label
-#define ASM_JUMP_IF_REG_ZERO(as, reg, label) \
+#define ASM_JUMP_IF_REG_ZERO(as, reg, label, bool_test) \
     asm_xtensa_bccz_reg_label(as, ASM_XTENSA_CCZ_EQ, reg, label)
-#define ASM_JUMP_IF_REG_NONZERO(as, reg, label) \
+#define ASM_JUMP_IF_REG_NONZERO(as, reg, label, bool_test) \
     asm_xtensa_bccz_reg_label(as, ASM_XTENSA_CCZ_NE, reg, label)
 #define ASM_JUMP_IF_REG_EQ(as, reg1, reg2, label) \
     asm_xtensa_bcc_reg_reg_label(as, ASM_XTENSA_CC_EQ, reg1, reg2, label)
diff --git a/py/emitnative.c b/py/emitnative.c
index 00d322d75..7071062a7 100644
--- a/py/emitnative.c
+++ b/py/emitnative.c
@@ -1547,9 +1547,9 @@ STATIC void emit_native_jump_helper(emit_t *emit, bool cond, mp_uint_t label, bo
     need_stack_settled(emit);
     // Emit the jump
     if (cond) {
-        ASM_JUMP_IF_REG_NONZERO(emit->as, REG_RET, label);
+        ASM_JUMP_IF_REG_NONZERO(emit->as, REG_RET, label, vtype == VTYPE_PYOBJ);
     } else {
-        ASM_JUMP_IF_REG_ZERO(emit->as, REG_RET, label);
+        ASM_JUMP_IF_REG_ZERO(emit->as, REG_RET, label, vtype == VTYPE_PYOBJ);
     }
     if (!pop) {
         adjust_stack(emit, -1);
@@ -1607,7 +1607,7 @@ STATIC void emit_native_setup_with(emit_t *emit, mp_uint_t label) {
     need_stack_settled(emit);
     emit_get_stack_pointer_to_reg_for_push(emit, REG_ARG_1, sizeof(nlr_buf_t) / sizeof(mp_uint_t)); // arg1 = pointer to nlr buf
     emit_call(emit, MP_F_NLR_PUSH);
-    ASM_JUMP_IF_REG_NONZERO(emit->as, REG_RET, label);
+    ASM_JUMP_IF_REG_NONZERO(emit->as, REG_RET, label, true);
 
     emit_access_stack(emit, sizeof(nlr_buf_t) / sizeof(mp_uint_t) + 1, &vtype, REG_RET); // access return value of __enter__
     emit_post_push_reg(emit, VTYPE_PYOBJ, REG_RET); // push return value of __enter__
@@ -1624,7 +1624,7 @@ STATIC void emit_native_setup_block(emit_t *emit, mp_uint_t label, int kind) {
         need_stack_settled(emit);
         emit_get_stack_pointer_to_reg_for_push(emit, REG_ARG_1, sizeof(nlr_buf_t) / sizeof(mp_uint_t)); // arg1 = pointer to nlr buf
         emit_call(emit, MP_F_NLR_PUSH);
-        ASM_JUMP_IF_REG_NONZERO(emit->as, REG_RET, label);
+        ASM_JUMP_IF_REG_NONZERO(emit->as, REG_RET, label, true);
         emit_post(emit);
     }
 }
@@ -1688,7 +1688,7 @@ STATIC void emit_native_with_cleanup(emit_t *emit, mp_uint_t label) {
         ASM_MOV_REG_REG(emit->as, REG_ARG_1, REG_RET);
     }
     emit_call(emit, MP_F_OBJ_IS_TRUE);
-    ASM_JUMP_IF_REG_ZERO(emit->as, REG_RET, label + 1);
+    ASM_JUMP_IF_REG_ZERO(emit->as, REG_RET, label + 1, true);
 
     // replace exc with None
     emit_pre_pop_discard(emit);
@@ -1736,7 +1736,7 @@ STATIC void emit_native_for_iter(emit_t *emit, mp_uint_t label) {
     emit_call(emit, MP_F_NATIVE_ITERNEXT);
     #ifdef NDEBUG
     MP_STATIC_ASSERT(MP_OBJ_STOP_ITERATION == 0);
-    ASM_JUMP_IF_REG_ZERO(emit->as, REG_RET, label);
+    ASM_JUMP_IF_REG_ZERO(emit->as, REG_RET, label, false);
     #else
     ASM_MOV_REG_IMM(emit->as, REG_TEMP1, (mp_uint_t)MP_OBJ_STOP_ITERATION);
     ASM_JUMP_IF_REG_EQ(emit->as, REG_RET, REG_TEMP1, label);
-- 
GitLab