From 1e8ca3a3cf442733f9210be4d5f6a5b63135a16d Mon Sep 17 00:00:00 2001
From: Sebastian Plamauer <oeplse@gmail.com>
Date: Tue, 14 Jul 2015 14:44:31 +0200
Subject: [PATCH] modbuiltins: Implement round() to precision.

---
 py/modbuiltins.c                   | 9 +++++----
 tests/basics/builtin_round.py      | 2 +-
 tests/float/builtin_float_round.py | 5 +++--
 3 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/py/modbuiltins.c b/py/modbuiltins.c
index 16b0e320c..d0c6130d5 100644
--- a/py/modbuiltins.c
+++ b/py/modbuiltins.c
@@ -431,7 +431,6 @@ STATIC mp_obj_t mp_builtin_repr(mp_obj_t o_in) {
 MP_DEFINE_CONST_FUN_OBJ_1(mp_builtin_repr_obj, mp_builtin_repr);
 
 STATIC mp_obj_t mp_builtin_round(mp_uint_t n_args, const mp_obj_t *args) {
-    // TODO really support second arg
     mp_obj_t o_in = args[0];
     if (MP_OBJ_IS_INT(o_in)) {
         return o_in;
@@ -440,9 +439,11 @@ STATIC mp_obj_t mp_builtin_round(mp_uint_t n_args, const mp_obj_t *args) {
     mp_int_t num_dig = 0;
     if (n_args > 1) {
         num_dig = mp_obj_get_int(args[1]);
-        if (num_dig > 0) {
-            mp_not_implemented("round(..., N>0)");
-        }
+        mp_float_t val = mp_obj_get_float(o_in);
+        mp_float_t mult = MICROPY_FLOAT_C_FUN(pow)(10, num_dig);
+        // TODO may lead to overflow
+        mp_float_t rounded = MICROPY_FLOAT_C_FUN(round)(val * mult) / mult;
+        return mp_obj_new_float(rounded);
     }
     mp_float_t val = mp_obj_get_float(o_in);
     mp_float_t rounded = MICROPY_FLOAT_C_FUN(round)(val);
diff --git a/tests/basics/builtin_round.py b/tests/basics/builtin_round.py
index 7f0edfe84..579bae39d 100644
--- a/tests/basics/builtin_round.py
+++ b/tests/basics/builtin_round.py
@@ -2,7 +2,7 @@
 
 tests = [
     False, True,
-    0, 1, -1, 10,
+    0, 1, -1, 10
 ]
 for t in tests:
     print(round(t))
diff --git a/tests/float/builtin_float_round.py b/tests/float/builtin_float_round.py
index 6759d0fd5..4419b744b 100644
--- a/tests/float/builtin_float_round.py
+++ b/tests/float/builtin_float_round.py
@@ -2,10 +2,11 @@
 
 # check basic cases
 tests = [
-    0.0, 1.0, 0.1, -0.1, 123.4, 123.6, -123.4, -123.6
+    [0.0], [1.0], [0.1], [-0.1], [123.4], [123.6], [-123.4], [-123.6],
+    [1.234567, 5], [1.23456, 1], [1.23456, 0], [1234.56, -2]
 ]
 for t in tests:
-    print(round(t))
+    print(round(*t))
 
 # check .5 cases
 for i in range(11):
-- 
GitLab