From 9a21d2e070c9ee0ef2c003f3a668e635c6ae4401 Mon Sep 17 00:00:00 2001 From: Damien George <damien.p.george@gmail.com> Date: Sat, 6 Sep 2014 17:15:34 +0100 Subject: [PATCH] py: Make mpz able to use 16 bits per digit; and 32 on 64-bit arch. Previously, mpz was restricted to using at most 15 bits in each digit, where a digit was a uint16_t. With this patch, mpz can use all 16 bits in the uint16_t (improvement to mpn_div was required). This gives small inprovements in speed and RAM usage. It also yields savings in ROM code size because all of the digit masking operations become no-ops. Also, mpz can now use a uint32_t as the digit type, and hence use 32 bits per digit. This will give decent improvements in mpz speed on 64-bit machines. Test for big integer division added. --- py/mpz.c | 87 ++++++++++++++++++++++++++++++------- py/mpz.h | 29 +++++++++++-- tests/basics/int_big_div.py | 3 ++ 3 files changed, 99 insertions(+), 20 deletions(-) create mode 100644 tests/basics/int_big_div.py diff --git a/py/mpz.c b/py/mpz.c index 8e6aecbca..186229569 100644 --- a/py/mpz.c +++ b/py/mpz.c @@ -37,7 +37,9 @@ #if MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_MPZ #define DIG_SIZE (MPZ_DIG_SIZE) -#define DIG_MASK ((1 << DIG_SIZE) - 1) +#define DIG_MASK ((1L << DIG_SIZE) - 1) +#define DIG_MSB (1L << (DIG_SIZE - 1)) +#define DIG_BASE (1L << DIG_SIZE) /* mpz is an arbitrary precision integer type with a public API. @@ -61,7 +63,7 @@ STATIC mp_int_t mpn_cmp(const mpz_dig_t *idig, mp_uint_t ilen, const mpz_dig_t * if (ilen > jlen) { return 1; } for (idig += ilen, jdig += ilen; ilen > 0; --ilen) { - mp_int_t cmp = *(--idig) - *(--jdig); + mpz_dbl_dig_signed_t cmp = (mpz_dbl_dig_t)*(--idig) - (mpz_dbl_dig_t)*(--jdig); if (cmp < 0) { return -1; } if (cmp > 0) { return 1; } } @@ -127,7 +129,7 @@ STATIC mp_uint_t mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, mp_uint_t jlen, mp_ui for (mp_uint_t i = jlen; i > 0; i--, idig++, jdig++) { mpz_dbl_dig_t d = *jdig; if (i > 1) { - d |= jdig[1] << DIG_SIZE; + d |= (mpz_dbl_dig_t)jdig[1] << DIG_SIZE; } d >>= n_part; *idig = d & DIG_MASK; @@ -152,7 +154,7 @@ STATIC mp_uint_t mpn_add(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen, jlen -= klen; for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) { - carry += *jdig + *kdig; + carry += (mpz_dbl_dig_t)*jdig + (mpz_dbl_dig_t)*kdig; *idig = carry & DIG_MASK; carry >>= DIG_SIZE; } @@ -182,7 +184,7 @@ STATIC mp_uint_t mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen, jlen -= klen; for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) { - borrow += *jdig - *kdig; + borrow += (mpz_dbl_dig_t)*jdig - (mpz_dbl_dig_t)*kdig; *idig = borrow & DIG_MASK; borrow >>= DIG_SIZE; } @@ -301,7 +303,7 @@ STATIC mp_uint_t mpn_mul_dig_add_dig(mpz_dig_t *idig, mp_uint_t ilen, mpz_dig_t mpz_dbl_dig_t carry = dadd; for (; ilen > 0; --ilen, ++idig) { - carry += *idig * dmul; // will never overflow so long as DIG_SIZE <= WORD_SIZE / 2 + carry += (mpz_dbl_dig_t)*idig * (mpz_dbl_dig_t)dmul; // will never overflow so long as DIG_SIZE <= 8*sizeof(mpz_dbl_dig_t)/2 *idig = carry & DIG_MASK; carry >>= DIG_SIZE; } @@ -328,7 +330,7 @@ STATIC mp_uint_t mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, mp_uint_t jlen, mpz_d mp_uint_t jl = jlen; for (mpz_dig_t *jd = jdig; jl > 0; --jl, ++jd, ++id) { - carry += *id + *jd * *kdig; // will never overflow so long as DIG_SIZE <= WORD_SIZE / 2 + carry += (mpz_dbl_dig_t)*id + (mpz_dbl_dig_t)*jd * (mpz_dbl_dig_t)*kdig; // will never overflow so long as DIG_SIZE <= 8*sizeof(mpz_dbl_dig_t)/2 *id = carry & DIG_MASK; carry >>= DIG_SIZE; } @@ -375,7 +377,7 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, // count number of leading zeros in leading digit of denominator { mpz_dig_t d = den_dig[den_len - 1]; - while ((d & (1 << (DIG_SIZE - 1))) == 0) { + while ((d & DIG_MSB) == 0) { d <<= 1; ++norm_shift; } @@ -412,21 +414,36 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, // keep going while we have enough digits to divide while (*num_len > den_len) { - mpz_dbl_dig_t quo = (*num_dig << DIG_SIZE) | num_dig[-1]; + mpz_dbl_dig_t quo = ((mpz_dbl_dig_t)*num_dig << DIG_SIZE) | num_dig[-1]; // get approximate quotient quo /= lead_den_digit; - // multiply quo by den and subtract from num get remainder - { + // Multiply quo by den and subtract from num to get remainder. + // We have different code here to handle different compile-time + // configurations of mpz: + // + // 1. DIG_SIZE is stricly less than half the number of bits + // available in mpz_dbl_dig_t. In this case we can use a + // slightly more optimal (in time and space) routine that + // uses the extra bits in mpz_dbl_dig_signed_t to store a + // sign bit. + // + // 2. DIG_SIZE is exactly half the number of bits available in + // mpz_dbl_dig_t. In this (common) case we need to be careful + // not to overflow the borrow variable. And the shifting of + // borrow needs some special logic (it's a shift right with + // round up). + + if (DIG_SIZE < 8 * sizeof(mpz_dbl_dig_t) / 2) { mpz_dbl_dig_signed_t borrow = 0; for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) { - borrow += *n - quo * *d; // will overflow if DIG_SIZE >= 16 + borrow += (mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)quo * (mpz_dbl_dig_t)*d; // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2 *n = borrow & DIG_MASK; borrow >>= DIG_SIZE; } - borrow += *num_dig; // will overflow if DIG_SIZE >= 16 + borrow += *num_dig; // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2 *num_dig = borrow & DIG_MASK; borrow >>= DIG_SIZE; @@ -434,7 +451,7 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, for (; borrow != 0; --quo) { mpz_dbl_dig_t carry = 0; for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) { - carry += *n + *d; + carry += (mpz_dbl_dig_t)*n + (mpz_dbl_dig_t)*d; *n = carry & DIG_MASK; carry >>= DIG_SIZE; } @@ -444,6 +461,44 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, borrow += carry; } + } else { // DIG_SIZE == 8 * sizeof(mpz_dbl_dig_t) / 2 + mpz_dbl_dig_t borrow = 0; + + for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) { + mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (mpz_dbl_dig_t)(*d); + if (x >= *n || *n - x <= borrow) { + borrow += (mpz_dbl_dig_t)x - (mpz_dbl_dig_t)*n; + *n = (-borrow) & DIG_MASK; + borrow = (borrow >> DIG_SIZE) + ((borrow & DIG_MASK) == 0 ? 0 : 1); // shift-right with round-up + } else { + *n = ((mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)x - (mpz_dbl_dig_t)borrow) & DIG_MASK; + borrow = 0; + } + } + if (borrow >= *num_dig) { + borrow -= (mpz_dbl_dig_t)*num_dig; + *num_dig = (-borrow) & DIG_MASK; + borrow = (borrow >> DIG_SIZE) + ((borrow & DIG_MASK) == 0 ? 0 : 1); // shift-right with round-up + } else { + *num_dig = (*num_dig - borrow) & DIG_MASK; + borrow = 0; + } + + // adjust quotient if it is too big + for (; borrow != 0; --quo) { + mpz_dbl_dig_t carry = 0; + for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) { + carry += (mpz_dbl_dig_t)*n + (mpz_dbl_dig_t)*d; + *n = carry & DIG_MASK; + carry >>= DIG_SIZE; + } + carry += (mpz_dbl_dig_t)*num_dig; + *num_dig = carry & DIG_MASK; + carry >>= DIG_SIZE; + + //assert(borrow >= carry); // enable this to check the logic + borrow -= carry; + } } // store this digit of the quotient @@ -1256,7 +1311,7 @@ bool mpz_as_uint_checked(const mpz_t *i, mp_uint_t *value) { mpz_dig_t *d = i->dig + i->len; while (--d >= i->dig) { - if (val > ((~0) >> DIG_SIZE)) { + if (val > (~(WORD_MSBIT_HIGH) >> (DIG_SIZE - 1))) { // will overflow return false; } @@ -1273,7 +1328,7 @@ mp_float_t mpz_as_float(const mpz_t *i) { mpz_dig_t *d = i->dig + i->len; while (--d >= i->dig) { - val = val * (1 << DIG_SIZE) + *d; + val = val * DIG_BASE + *d; } if (i->neg != 0) { diff --git a/py/mpz.h b/py/mpz.h index 6eaaf378a..79e5ea231 100644 --- a/py/mpz.h +++ b/py/mpz.h @@ -24,9 +24,34 @@ * THE SOFTWARE. */ +// This mpz module implements arbitrary precision integers. +// +// The storage for each digit is defined by mpz_dig_t. The actual number of +// bits in mpz_dig_t that are used is defined by MPZ_DIG_SIZE. The machine must +// also provide a type that is twice as wide as mpz_dig_t, in both signed and +// unsigned versions. +// +// MPZ_DIG_SIZE can be between 4 and 8*sizeof(mpz_dig_t), but it makes most +// sense to have it as large as possible. Below, the type is auto-detected +// depending on the machine, but it (and MPZ_DIG_SIZE) can be freely changed so +// long as the constraints mentioned above are met. + +#if defined(__x86_64__) +// 64-bit machine, using 32-bit storage for digits +typedef uint32_t mpz_dig_t; +typedef uint64_t mpz_dbl_dig_t; +typedef int64_t mpz_dbl_dig_signed_t; +#define MPZ_DIG_SIZE (32) +#else +// 32-bit machine, using 16-bit storage for digits typedef uint16_t mpz_dig_t; typedef uint32_t mpz_dbl_dig_t; typedef int32_t mpz_dbl_dig_signed_t; +#define MPZ_DIG_SIZE (16) +#endif + +#define MPZ_NUM_DIG_FOR_INT (sizeof(mp_int_t) * 8 / MPZ_DIG_SIZE + 1) +#define MPZ_NUM_DIG_FOR_LL (sizeof(long long) * 8 / MPZ_DIG_SIZE + 1) typedef struct _mpz_t { mp_uint_t neg : 1; @@ -36,10 +61,6 @@ typedef struct _mpz_t { mpz_dig_t *dig; } mpz_t; -#define MPZ_DIG_SIZE (15) // see mpn_div for why this needs to be at most 15 -#define MPZ_NUM_DIG_FOR_INT (sizeof(mp_int_t) * 8 / MPZ_DIG_SIZE + 1) -#define MPZ_NUM_DIG_FOR_LL (sizeof(long long) * 8 / MPZ_DIG_SIZE + 1) - // convenience macro to declare an mpz with a digit array from the stack, initialised by an integer #define MPZ_CONST_INT(z, val) mpz_t z; mpz_dig_t z ## _digits[MPZ_NUM_DIG_FOR_INT]; mpz_init_fixed_from_int(&z, z_digits, MPZ_NUM_DIG_FOR_INT, val); diff --git a/tests/basics/int_big_div.py b/tests/basics/int_big_div.py new file mode 100644 index 000000000..8dacf495d --- /dev/null +++ b/tests/basics/int_big_div.py @@ -0,0 +1,3 @@ +for lhs in (1000000000000000000000000, 10000000000100000000000000, 10012003400000000000000007, 12349083434598210349871029923874109871234789): + for rhs in range(1, 555): + print(lhs // rhs) -- GitLab