From 460b0863334fb143667ae6aa1b5f0bd2bbdf8282 Mon Sep 17 00:00:00 2001 From: Damien George <damien.p.george@gmail.com> Date: Mon, 9 May 2016 17:21:42 +0100 Subject: [PATCH] py/mpz: Fix mpn_div so that it doesn't modify memory of denominator. Previous to this patch bignum division and modulo would temporarily modify the RHS argument to the operation (eg x/y would modify y), but on return the RHS would be restored to its original value. This is not allowed because arguments to binary operations are const, and in particular might live in ROM. The modification was to normalise the arg (and then unnormalise before returning), and this patch makes it so the normalisation is done on the fly and the arg is now accessed as read-only. This change doesn't increase the order complexity of the operation, and actually reduces code size. --- py/mpz.c | 57 +++++++++++++++++++++++++++++--------------------------- 1 file changed, 30 insertions(+), 27 deletions(-) diff --git a/py/mpz.c b/py/mpz.c index 3fb2548c4..bb7647956 100644 --- a/py/mpz.c +++ b/py/mpz.c @@ -454,10 +454,8 @@ STATIC mp_uint_t mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, mp_uint_t jlen, mpz_d assumes num_dig has enough memory to be extended by 1 digit assumes quo_dig has enough memory (as many digits as num) assumes quo_dig is filled with zeros - modifies den_dig memory, but restors it to original state at end */ - -STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, mp_uint_t den_len, mpz_dig_t *quo_dig, mp_uint_t *quo_len) { +STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, const mpz_dig_t *den_dig, mp_uint_t den_len, mpz_dig_t *quo_dig, mp_uint_t *quo_len) { mpz_dig_t *orig_num_dig = num_dig; mpz_dig_t *orig_quo_dig = quo_dig; mpz_dig_t norm_shift = 0; @@ -478,6 +476,11 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, } } + // We need to normalise the denominator (leading bit of leading digit is 1) + // so that the division routine works. Since the denominator memory is + // read-only we do the normalisation on the fly, each time a digit of the + // denominator is needed. We need to know is how many bits to shift by. + // count number of leading zeros in leading digit of denominator { mpz_dig_t d = den_dig[den_len - 1]; @@ -487,13 +490,6 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, } } - // normalise denomenator (leading bit of leading digit is 1) - for (mpz_dig_t *den = den_dig, carry = 0; den < den_dig + den_len; ++den) { - mpz_dig_t d = *den; - *den = ((d << norm_shift) | carry) & DIG_MASK; - carry = (mpz_dbl_dig_t)d >> (DIG_SIZE - norm_shift); - } - // now need to shift numerator by same amount as denominator // first, increase length of numerator in case we need more room to shift num_dig[*num_len] = 0; @@ -505,7 +501,10 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, } // cache the leading digit of the denominator - lead_den_digit = den_dig[den_len - 1]; + lead_den_digit = (mpz_dbl_dig_t)den_dig[den_len - 1] << norm_shift; + if (den_len >= 2) { + lead_den_digit |= (mpz_dbl_dig_t)den_dig[den_len - 2] >> (DIG_SIZE - norm_shift); + } // point num_dig to last digit in numerator num_dig += *num_len - 1; @@ -540,10 +539,13 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, // round up). if (DIG_SIZE < 8 * sizeof(mpz_dbl_dig_t) / 2) { + const mpz_dig_t *d = den_dig; + mpz_dbl_dig_t d_norm = 0; mpz_dbl_dig_signed_t borrow = 0; - for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) { - borrow += (mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)quo * (mpz_dbl_dig_t)*d; // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2 + for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) { + d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE); + borrow += (mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK); // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2 *n = borrow & DIG_MASK; borrow >>= DIG_SIZE; } @@ -553,9 +555,12 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, // adjust quotient if it is too big for (; borrow != 0; --quo) { + d = den_dig; + d_norm = 0; mpz_dbl_dig_t carry = 0; - for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) { - carry += (mpz_dbl_dig_t)*n + (mpz_dbl_dig_t)*d; + for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) { + d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE); + carry += (mpz_dbl_dig_t)*n + (d_norm & DIG_MASK); *n = carry & DIG_MASK; carry >>= DIG_SIZE; } @@ -566,10 +571,13 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, borrow += carry; } } else { // DIG_SIZE == 8 * sizeof(mpz_dbl_dig_t) / 2 + const mpz_dig_t *d = den_dig; + mpz_dbl_dig_t d_norm = 0; mpz_dbl_dig_t borrow = 0; - for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) { - mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (mpz_dbl_dig_t)(*d); + for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) { + d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE); + mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK); if (x >= *n || *n - x <= borrow) { borrow += (mpz_dbl_dig_t)x - (mpz_dbl_dig_t)*n; *n = (-borrow) & DIG_MASK; @@ -590,9 +598,12 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, // adjust quotient if it is too big for (; borrow != 0; --quo) { + d = den_dig; + d_norm = 0; mpz_dbl_dig_t carry = 0; - for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) { - carry += (mpz_dbl_dig_t)*n + (mpz_dbl_dig_t)*d; + for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) { + d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE); + carry += (mpz_dbl_dig_t)*n + (d_norm & DIG_MASK); *n = carry & DIG_MASK; carry >>= DIG_SIZE; } @@ -614,13 +625,6 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, --(*num_len); } - // unnormalise denomenator - for (mpz_dig_t *den = den_dig + den_len - 1, carry = 0; den >= den_dig; --den) { - mpz_dig_t d = *den; - *den = ((d >> norm_shift) | carry) & DIG_MASK; - carry = (mpz_dbl_dig_t)d << (DIG_SIZE - norm_shift); - } - // unnormalise numerator (remainder now) for (mpz_dig_t *num = orig_num_dig + *num_len - 1, carry = 0; num >= orig_num_dig; --num) { mpz_dig_t n = *num; @@ -1506,7 +1510,6 @@ void mpz_divmod_inpl(mpz_t *dest_quo, mpz_t *dest_rem, const mpz_t *lhs, const m dest_quo->len = 0; mpz_need_dig(dest_rem, lhs->len + 1); // +1 necessary? mpz_set(dest_rem, lhs); - //rhs->dig[rhs->len] = 0; mpn_div(dest_rem->dig, &dest_rem->len, rhs->dig, rhs->len, dest_quo->dig, &dest_quo->len); // check signs and do Python style modulo -- GitLab