From 9a21d2e070c9ee0ef2c003f3a668e635c6ae4401 Mon Sep 17 00:00:00 2001
From: Damien George <damien.p.george@gmail.com>
Date: Sat, 6 Sep 2014 17:15:34 +0100
Subject: [PATCH] py: Make mpz able to use 16 bits per digit; and 32 on 64-bit
 arch.

Previously, mpz was restricted to using at most 15 bits in each digit,
where a digit was a uint16_t.

With this patch, mpz can use all 16 bits in the uint16_t (improvement
to mpn_div was required).  This gives small inprovements in speed and
RAM usage.  It also yields savings in ROM code size because all of the
digit masking operations become no-ops.

Also, mpz can now use a uint32_t as the digit type, and hence use 32
bits per digit.  This will give decent improvements in mpz speed on
64-bit machines.

Test for big integer division added.
---
 py/mpz.c                    | 87 ++++++++++++++++++++++++++++++-------
 py/mpz.h                    | 29 +++++++++++--
 tests/basics/int_big_div.py |  3 ++
 3 files changed, 99 insertions(+), 20 deletions(-)
 create mode 100644 tests/basics/int_big_div.py

diff --git a/py/mpz.c b/py/mpz.c
index 8e6aecbca..186229569 100644
--- a/py/mpz.c
+++ b/py/mpz.c
@@ -37,7 +37,9 @@
 #if MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_MPZ
 
 #define DIG_SIZE (MPZ_DIG_SIZE)
-#define DIG_MASK ((1 << DIG_SIZE) - 1)
+#define DIG_MASK ((1L << DIG_SIZE) - 1)
+#define DIG_MSB  (1L << (DIG_SIZE - 1))
+#define DIG_BASE (1L << DIG_SIZE)
 
 /*
  mpz is an arbitrary precision integer type with a public API.
@@ -61,7 +63,7 @@ STATIC mp_int_t mpn_cmp(const mpz_dig_t *idig, mp_uint_t ilen, const mpz_dig_t *
     if (ilen > jlen) { return 1; }
 
     for (idig += ilen, jdig += ilen; ilen > 0; --ilen) {
-        mp_int_t cmp = *(--idig) - *(--jdig);
+        mpz_dbl_dig_signed_t cmp = (mpz_dbl_dig_t)*(--idig) - (mpz_dbl_dig_t)*(--jdig);
         if (cmp < 0) { return -1; }
         if (cmp > 0) { return 1; }
     }
@@ -127,7 +129,7 @@ STATIC mp_uint_t mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, mp_uint_t jlen, mp_ui
     for (mp_uint_t i = jlen; i > 0; i--, idig++, jdig++) {
         mpz_dbl_dig_t d = *jdig;
         if (i > 1) {
-            d |= jdig[1] << DIG_SIZE;
+            d |= (mpz_dbl_dig_t)jdig[1] << DIG_SIZE;
         }
         d >>= n_part;
         *idig = d & DIG_MASK;
@@ -152,7 +154,7 @@ STATIC mp_uint_t mpn_add(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,
     jlen -= klen;
 
     for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) {
-        carry += *jdig + *kdig;
+        carry += (mpz_dbl_dig_t)*jdig + (mpz_dbl_dig_t)*kdig;
         *idig = carry & DIG_MASK;
         carry >>= DIG_SIZE;
     }
@@ -182,7 +184,7 @@ STATIC mp_uint_t mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,
     jlen -= klen;
 
     for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) {
-        borrow += *jdig - *kdig;
+        borrow += (mpz_dbl_dig_t)*jdig - (mpz_dbl_dig_t)*kdig;
         *idig = borrow & DIG_MASK;
         borrow >>= DIG_SIZE;
     }
@@ -301,7 +303,7 @@ STATIC mp_uint_t mpn_mul_dig_add_dig(mpz_dig_t *idig, mp_uint_t ilen, mpz_dig_t
     mpz_dbl_dig_t carry = dadd;
 
     for (; ilen > 0; --ilen, ++idig) {
-        carry += *idig * dmul; // will never overflow so long as DIG_SIZE <= WORD_SIZE / 2
+        carry += (mpz_dbl_dig_t)*idig * (mpz_dbl_dig_t)dmul; // will never overflow so long as DIG_SIZE <= 8*sizeof(mpz_dbl_dig_t)/2
         *idig = carry & DIG_MASK;
         carry >>= DIG_SIZE;
     }
@@ -328,7 +330,7 @@ STATIC mp_uint_t mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, mp_uint_t jlen, mpz_d
 
         mp_uint_t jl = jlen;
         for (mpz_dig_t *jd = jdig; jl > 0; --jl, ++jd, ++id) {
-            carry += *id + *jd * *kdig; // will never overflow so long as DIG_SIZE <= WORD_SIZE / 2
+            carry += (mpz_dbl_dig_t)*id + (mpz_dbl_dig_t)*jd * (mpz_dbl_dig_t)*kdig; // will never overflow so long as DIG_SIZE <= 8*sizeof(mpz_dbl_dig_t)/2
             *id = carry & DIG_MASK;
             carry >>= DIG_SIZE;
         }
@@ -375,7 +377,7 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
     // count number of leading zeros in leading digit of denominator
     {
         mpz_dig_t d = den_dig[den_len - 1];
-        while ((d & (1 << (DIG_SIZE - 1))) == 0) {
+        while ((d & DIG_MSB) == 0) {
             d <<= 1;
             ++norm_shift;
         }
@@ -412,21 +414,36 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
 
     // keep going while we have enough digits to divide
     while (*num_len > den_len) {
-        mpz_dbl_dig_t quo = (*num_dig << DIG_SIZE) | num_dig[-1];
+        mpz_dbl_dig_t quo = ((mpz_dbl_dig_t)*num_dig << DIG_SIZE) | num_dig[-1];
 
         // get approximate quotient
         quo /= lead_den_digit;
 
-        // multiply quo by den and subtract from num get remainder
-        {
+        // Multiply quo by den and subtract from num to get remainder.
+        // We have different code here to handle different compile-time
+        // configurations of mpz:
+        //
+        //   1. DIG_SIZE is stricly less than half the number of bits
+        //      available in mpz_dbl_dig_t.  In this case we can use a
+        //      slightly more optimal (in time and space) routine that
+        //      uses the extra bits in mpz_dbl_dig_signed_t to store a
+        //      sign bit.
+        //
+        //   2. DIG_SIZE is exactly half the number of bits available in
+        //      mpz_dbl_dig_t.  In this (common) case we need to be careful
+        //      not to overflow the borrow variable.  And the shifting of
+        //      borrow needs some special logic (it's a shift right with
+        //      round up).
+
+        if (DIG_SIZE < 8 * sizeof(mpz_dbl_dig_t) / 2) {
             mpz_dbl_dig_signed_t borrow = 0;
 
             for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
-                borrow += *n - quo * *d; // will overflow if DIG_SIZE >= 16
+                borrow += (mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)quo * (mpz_dbl_dig_t)*d; // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2
                 *n = borrow & DIG_MASK;
                 borrow >>= DIG_SIZE;
             }
-            borrow += *num_dig; // will overflow if DIG_SIZE >= 16
+            borrow += *num_dig; // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2
             *num_dig = borrow & DIG_MASK;
             borrow >>= DIG_SIZE;
 
@@ -434,7 +451,7 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
             for (; borrow != 0; --quo) {
                 mpz_dbl_dig_t carry = 0;
                 for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
-                    carry += *n + *d;
+                    carry += (mpz_dbl_dig_t)*n + (mpz_dbl_dig_t)*d;
                     *n = carry & DIG_MASK;
                     carry >>= DIG_SIZE;
                 }
@@ -444,6 +461,44 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
 
                 borrow += carry;
             }
+        } else { // DIG_SIZE == 8 * sizeof(mpz_dbl_dig_t) / 2
+            mpz_dbl_dig_t borrow = 0;
+
+            for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
+                mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (mpz_dbl_dig_t)(*d);
+                if (x >= *n || *n - x <= borrow) {
+                    borrow += (mpz_dbl_dig_t)x - (mpz_dbl_dig_t)*n;
+                    *n = (-borrow) & DIG_MASK;
+                    borrow = (borrow >> DIG_SIZE) + ((borrow & DIG_MASK) == 0 ? 0 : 1); // shift-right with round-up
+                } else {
+                    *n = ((mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)x - (mpz_dbl_dig_t)borrow) & DIG_MASK;
+                    borrow = 0;
+                }
+            }
+            if (borrow >= *num_dig) {
+                borrow -= (mpz_dbl_dig_t)*num_dig;
+                *num_dig = (-borrow) & DIG_MASK;
+                borrow = (borrow >> DIG_SIZE) + ((borrow & DIG_MASK) == 0 ? 0 : 1); // shift-right with round-up
+            } else {
+                *num_dig = (*num_dig - borrow) & DIG_MASK;
+                borrow = 0;
+            }
+
+            // adjust quotient if it is too big
+            for (; borrow != 0; --quo) {
+                mpz_dbl_dig_t carry = 0;
+                for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
+                    carry += (mpz_dbl_dig_t)*n + (mpz_dbl_dig_t)*d;
+                    *n = carry & DIG_MASK;
+                    carry >>= DIG_SIZE;
+                }
+                carry += (mpz_dbl_dig_t)*num_dig;
+                *num_dig = carry & DIG_MASK;
+                carry >>= DIG_SIZE;
+
+                //assert(borrow >= carry); // enable this to check the logic
+                borrow -= carry;
+            }
         }
 
         // store this digit of the quotient
@@ -1256,7 +1311,7 @@ bool mpz_as_uint_checked(const mpz_t *i, mp_uint_t *value) {
     mpz_dig_t *d = i->dig + i->len;
 
     while (--d >= i->dig) {
-        if (val > ((~0) >> DIG_SIZE)) {
+        if (val > (~(WORD_MSBIT_HIGH) >> (DIG_SIZE - 1))) {
             // will overflow
             return false;
         }
@@ -1273,7 +1328,7 @@ mp_float_t mpz_as_float(const mpz_t *i) {
     mpz_dig_t *d = i->dig + i->len;
 
     while (--d >= i->dig) {
-        val = val * (1 << DIG_SIZE) + *d;
+        val = val * DIG_BASE + *d;
     }
 
     if (i->neg != 0) {
diff --git a/py/mpz.h b/py/mpz.h
index 6eaaf378a..79e5ea231 100644
--- a/py/mpz.h
+++ b/py/mpz.h
@@ -24,9 +24,34 @@
  * THE SOFTWARE.
  */
 
+// This mpz module implements arbitrary precision integers.
+//
+// The storage for each digit is defined by mpz_dig_t.  The actual number of
+// bits in mpz_dig_t that are used is defined by MPZ_DIG_SIZE.  The machine must
+// also provide a type that is twice as wide as mpz_dig_t, in both signed and
+// unsigned versions.
+//
+// MPZ_DIG_SIZE can be between 4 and 8*sizeof(mpz_dig_t), but it makes most
+// sense to have it as large as possible.  Below, the type is auto-detected
+// depending on the machine, but it (and MPZ_DIG_SIZE) can be freely changed so
+// long as the constraints mentioned above are met.
+
+#if defined(__x86_64__)
+// 64-bit machine, using 32-bit storage for digits
+typedef uint32_t mpz_dig_t;
+typedef uint64_t mpz_dbl_dig_t;
+typedef int64_t mpz_dbl_dig_signed_t;
+#define MPZ_DIG_SIZE (32)
+#else
+// 32-bit machine, using 16-bit storage for digits
 typedef uint16_t mpz_dig_t;
 typedef uint32_t mpz_dbl_dig_t;
 typedef int32_t mpz_dbl_dig_signed_t;
+#define MPZ_DIG_SIZE (16)
+#endif
+
+#define MPZ_NUM_DIG_FOR_INT (sizeof(mp_int_t) * 8 / MPZ_DIG_SIZE + 1)
+#define MPZ_NUM_DIG_FOR_LL (sizeof(long long) * 8 / MPZ_DIG_SIZE + 1)
 
 typedef struct _mpz_t {
     mp_uint_t neg : 1;
@@ -36,10 +61,6 @@ typedef struct _mpz_t {
     mpz_dig_t *dig;
 } mpz_t;
 
-#define MPZ_DIG_SIZE (15) // see mpn_div for why this needs to be at most 15
-#define MPZ_NUM_DIG_FOR_INT (sizeof(mp_int_t) * 8 / MPZ_DIG_SIZE + 1)
-#define MPZ_NUM_DIG_FOR_LL (sizeof(long long) * 8 / MPZ_DIG_SIZE + 1)
-
 // convenience macro to declare an mpz with a digit array from the stack, initialised by an integer
 #define MPZ_CONST_INT(z, val) mpz_t z; mpz_dig_t z ## _digits[MPZ_NUM_DIG_FOR_INT]; mpz_init_fixed_from_int(&z, z_digits, MPZ_NUM_DIG_FOR_INT, val);
 
diff --git a/tests/basics/int_big_div.py b/tests/basics/int_big_div.py
new file mode 100644
index 000000000..8dacf495d
--- /dev/null
+++ b/tests/basics/int_big_div.py
@@ -0,0 +1,3 @@
+for lhs in (1000000000000000000000000, 10000000000100000000000000, 10012003400000000000000007, 12349083434598210349871029923874109871234789):
+    for rhs in range(1, 555):
+        print(lhs // rhs)
-- 
GitLab