From d4b75f6b6822885e331c69a74e56e23af40a6264 Mon Sep 17 00:00:00 2001
From: Damien George <damien.p.george@gmail.com>
Date: Mon, 4 Sep 2017 14:16:27 +1000
Subject: [PATCH] py/obj: Fix comparison of float/complex NaN with itself.

IEEE floating point is specified such that a comparison of NaN with itself
returns false, and Python respects these semantics.  This patch makes uPy
also have these semantics.  The fix has a minor impact on the speed of the
object-equality fast-path, but that seems to be unavoidable and it's much
more important to have correct behaviour (especially in this case where
the wrong answer for nan==nan is silently returned).
---
 py/obj.c                | 11 ++++++++++-
 tests/float/complex1.py |  5 +++++
 tests/float/float1.py   |  5 +++++
 3 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/py/obj.c b/py/obj.c
index 90ce47e8f..857fe373f 100644
--- a/py/obj.c
+++ b/py/obj.c
@@ -162,7 +162,16 @@ bool mp_obj_is_callable(mp_obj_t o_in) {
 // comparison returns NotImplemented, == and != are decided by comparing the object
 // pointer."
 bool mp_obj_equal(mp_obj_t o1, mp_obj_t o2) {
-    if (o1 == o2) {
+    // Float (and complex) NaN is never equal to anything, not even itself,
+    // so we must have a special check here to cover those cases.
+    if (o1 == o2
+        #if MICROPY_PY_BUILTINS_FLOAT
+        && !mp_obj_is_float(o1)
+        #endif
+        #if MICROPY_PY_BUILTINS_COMPLEX
+        && !MP_OBJ_IS_TYPE(o1, &mp_type_complex)
+        #endif
+        ) {
         return true;
     }
     if (o1 == mp_const_none || o2 == mp_const_none) {
diff --git a/tests/float/complex1.py b/tests/float/complex1.py
index a6038de04..7f0b317b3 100644
--- a/tests/float/complex1.py
+++ b/tests/float/complex1.py
@@ -37,6 +37,11 @@ ans = 1j ** 2.5j; print("%.5g %.5g" % (ans.real, ans.imag))
 print(1j == 1)
 print(1j == 1j)
 
+# comparison of nan is special
+nan = float('nan') * 1j
+print(nan == 1j)
+print(nan == nan)
+
 # builtin abs
 print(abs(1j))
 print("%.5g" % abs(1j + 2))
diff --git a/tests/float/float1.py b/tests/float/float1.py
index 93f6f014c..137dacc23 100644
--- a/tests/float/float1.py
+++ b/tests/float/float1.py
@@ -60,6 +60,11 @@ print(1.2 <= -3.4)
 print(1.2 >= 3.4)
 print(1.2 >= -3.4)
 
+# comparison of nan is special
+nan = float('nan')
+print(nan == 1.2)
+print(nan == nan)
+
 try:
     1.0 / 0
 except ZeroDivisionError:
-- 
GitLab