From 271d18eb08ec488ee45f8e6cd852e8236074f082 Mon Sep 17 00:00:00 2001
From: Damien George <damien.p.george@gmail.com>
Date: Sat, 25 Apr 2015 23:16:39 +0100
Subject: [PATCH] py: Support conversion of bignum to bytes.

This gets int.to_bytes working for bignum, and also struct.pack with 'q'
and 'Q' args on 32-bit machines.

Addresses issue #1155.
---
 py/binary.c               | 10 +++++++---
 py/mpz.c                  | 34 ++++++++++++++++++++++++++++++++++
 py/mpz.h                  |  1 +
 py/objint.c               | 18 +++++++++---------
 py/objint.h               |  1 +
 py/objint_longlong.c      | 18 ++++++++++++++++++
 py/objint_mpz.c           |  6 ++++++
 tests/basics/int_bytes.py |  1 +
 tests/basics/struct1.py   | 15 ++++++++++++---
 9 files changed, 89 insertions(+), 15 deletions(-)

diff --git a/py/binary.c b/py/binary.c
index 927a42640..8b5c05ab3 100644
--- a/py/binary.c
+++ b/py/binary.c
@@ -32,6 +32,7 @@
 
 #include "py/binary.h"
 #include "py/smallint.h"
+#include "py/objint.h"
 
 // Helpers to work with binary-encoded data
 
@@ -282,10 +283,13 @@ void mp_binary_set_val(char struct_type, char val_type, mp_obj_t val_in, byte **
         }
 #endif
         default:
-            // we handle large ints here by calling the truncated accessor
+            #if MICROPY_LONGINT_IMPL != MICROPY_LONGINT_IMPL_NONE
             if (MP_OBJ_IS_TYPE(val_in, &mp_type_int)) {
-                val = mp_obj_int_get_truncated(val_in);
-            } else {
+                mp_obj_int_to_bytes_impl(val_in, struct_type == '>', size, p);
+                return;
+            } else
+            #endif
+            {
                 val = mp_obj_get_int(val_in);
             }
     }
diff --git a/py/mpz.c b/py/mpz.c
index 241fa79be..3c20023bc 100644
--- a/py/mpz.c
+++ b/py/mpz.c
@@ -1425,6 +1425,40 @@ bool mpz_as_uint_checked(const mpz_t *i, mp_uint_t *value) {
     return true;
 }
 
+// writes at most len bytes to buf (so buf should be zeroed before calling)
+void mpz_as_bytes(const mpz_t *z, bool big_endian, mp_uint_t len, byte *buf) {
+    byte *b = buf;
+    if (big_endian) {
+        b += len;
+    }
+    mpz_dig_t *zdig = z->dig;
+    int bits = 0;
+    mpz_dbl_dig_t d = 0;
+    mpz_dbl_dig_t carry = 1;
+    for (mp_uint_t zlen = z->len; zlen > 0; --zlen) {
+        bits += DIG_SIZE;
+        d = (d << DIG_SIZE) | *zdig++;
+        for (; bits >= 8; bits -= 8, d >>= 8) {
+            mpz_dig_t val = d;
+            if (z->neg) {
+                d = (~d & 0xff) + carry;
+                carry = d >> 8;
+            }
+            if (big_endian) {
+                *--b = val;
+                if (b == buf) {
+                    return;
+                }
+            } else {
+                *b++ = val;
+                if (b == buf + len) {
+                    return;
+                }
+            }
+        }
+    }
+}
+
 #if MICROPY_PY_BUILTINS_FLOAT
 mp_float_t mpz_as_float(const mpz_t *i) {
     mp_float_t val = 0;
diff --git a/py/mpz.h b/py/mpz.h
index 71649aa7f..b00d2b655 100644
--- a/py/mpz.h
+++ b/py/mpz.h
@@ -125,6 +125,7 @@ void mpz_divmod_inpl(mpz_t *dest_quo, mpz_t *dest_rem, const mpz_t *lhs, const m
 mp_int_t mpz_hash(const mpz_t *z);
 bool mpz_as_int_checked(const mpz_t *z, mp_int_t *value);
 bool mpz_as_uint_checked(const mpz_t *z, mp_uint_t *value);
+void mpz_as_bytes(const mpz_t *z, bool big_endian, mp_uint_t len, byte *buf);
 #if MICROPY_PY_BUILTINS_FLOAT
 mp_float_t mpz_as_float(const mpz_t *z);
 #endif
diff --git a/py/objint.c b/py/objint.c
index 64faed636..7c527d4ae 100644
--- a/py/objint.c
+++ b/py/objint.c
@@ -35,6 +35,7 @@
 #include "py/objstr.h"
 #include "py/runtime0.h"
 #include "py/runtime.h"
+#include "py/binary.h"
 
 #if MICROPY_PY_BUILTINS_FLOAT
 #include <math.h>
@@ -398,12 +399,10 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(int_from_bytes_fun_obj, 2, 3, int_fro
 STATIC MP_DEFINE_CONST_CLASSMETHOD_OBJ(int_from_bytes_obj, (const mp_obj_t)&int_from_bytes_fun_obj);
 
 STATIC mp_obj_t int_to_bytes(mp_uint_t n_args, const mp_obj_t *args) {
-    // TODO: Support long ints
     // TODO: Support byteorder param (assumes 'little')
     // TODO: Support signed param (assumes signed=False)
     (void)n_args;
 
-    mp_int_t val = mp_obj_int_get_checked(args[0]);
     mp_uint_t len = MP_OBJ_SMALL_INT_VALUE(args[1]);
 
     vstr_t vstr;
@@ -411,13 +410,14 @@ STATIC mp_obj_t int_to_bytes(mp_uint_t n_args, const mp_obj_t *args) {
     byte *data = (byte*)vstr.buf;
     memset(data, 0, len);
 
-    if (MP_ENDIANNESS_LITTLE) {
-        memcpy(data, &val, len < sizeof(mp_int_t) ? len : sizeof(mp_int_t));
-    } else {
-        while (len--) {
-            *data++ = val;
-            val >>= 8;
-        }
+    #if MICROPY_LONGINT_IMPL != MICROPY_LONGINT_IMPL_NONE
+    if (!MP_OBJ_IS_SMALL_INT(args[0])) {
+        mp_obj_int_to_bytes_impl(args[0], false, len, data);
+    } else
+    #endif
+    {
+        mp_int_t val = MP_OBJ_SMALL_INT_VALUE(args[0]);
+        mp_binary_set_int(MIN((size_t)len, sizeof(val)), false, data, val);
     }
 
     return mp_obj_new_str_from_vstr(&mp_type_bytes, &vstr);
diff --git a/py/objint.h b/py/objint.h
index daeb3c499..09cf7c86d 100644
--- a/py/objint.h
+++ b/py/objint.h
@@ -56,6 +56,7 @@ char *mp_obj_int_formatted(char **buf, mp_uint_t *buf_size, mp_uint_t *fmt_size,
 char *mp_obj_int_formatted_impl(char **buf, mp_uint_t *buf_size, mp_uint_t *fmt_size, mp_const_obj_t self_in,
                                 int base, const char *prefix, char base_char, char comma);
 mp_int_t mp_obj_int_hash(mp_obj_t self_in);
+void mp_obj_int_to_bytes_impl(mp_obj_t self_in, bool big_endian, mp_uint_t len, byte *buf);
 bool mp_obj_int_is_positive(mp_obj_t self_in);
 mp_obj_t mp_obj_int_abs(mp_obj_t self_in);
 mp_obj_t mp_obj_int_unary_op(mp_uint_t op, mp_obj_t o_in);
diff --git a/py/objint_longlong.c b/py/objint_longlong.c
index 837889704..5b2c6d3f5 100644
--- a/py/objint_longlong.c
+++ b/py/objint_longlong.c
@@ -63,6 +63,24 @@ mp_int_t mp_obj_int_hash(mp_obj_t self_in) {
     return self->val;
 }
 
+void mp_obj_int_to_bytes_impl(mp_obj_t self_in, bool big_endian, mp_uint_t len, byte *buf) {
+    assert(MP_OBJ_IS_TYPE(self_in, &mp_type_int));
+    mp_obj_int_t *self = self_in;
+    long long val = self->val;
+    if (big_endian) {
+        byte *b = buf + len;
+        while (b > buf) {
+            *--b = val;
+            val >>= 8;
+        }
+    } else {
+        for (; len > 0; --len) {
+            *buf++ = val;
+            val >>= 8;
+        }
+    }
+}
+
 bool mp_obj_int_is_positive(mp_obj_t self_in) {
     if (MP_OBJ_IS_SMALL_INT(self_in)) {
         return MP_OBJ_SMALL_INT_VALUE(self_in) >= 0;
diff --git a/py/objint_mpz.c b/py/objint_mpz.c
index 2746f4dff..369e5af32 100644
--- a/py/objint_mpz.c
+++ b/py/objint_mpz.c
@@ -104,6 +104,12 @@ mp_int_t mp_obj_int_hash(mp_obj_t self_in) {
     return mpz_hash(&self->mpz);
 }
 
+void mp_obj_int_to_bytes_impl(mp_obj_t self_in, bool big_endian, mp_uint_t len, byte *buf) {
+    assert(MP_OBJ_IS_TYPE(self_in, &mp_type_int));
+    mp_obj_int_t *self = self_in;
+    mpz_as_bytes(&self->mpz, big_endian, len, buf);
+}
+
 bool mp_obj_int_is_positive(mp_obj_t self_in) {
     if (MP_OBJ_IS_SMALL_INT(self_in)) {
         return MP_OBJ_SMALL_INT_VALUE(self_in) >= 0;
diff --git a/tests/basics/int_bytes.py b/tests/basics/int_bytes.py
index 45965ed46..2f468da44 100644
--- a/tests/basics/int_bytes.py
+++ b/tests/basics/int_bytes.py
@@ -1,6 +1,7 @@
 print((10).to_bytes(1, "little"))
 print((111111).to_bytes(4, "little"))
 print((100).to_bytes(10, "little"))
+print((2**64).to_bytes(9, "little"))
 print(int.from_bytes(b"\x00\x01\0\0\0\0\0\0", "little"))
 print(int.from_bytes(b"\x01\0\0\0\0\0\0\0", "little"))
 print(int.from_bytes(b"\x00\x01\0\0\0\0\0\0", "little"))
diff --git a/tests/basics/struct1.py b/tests/basics/struct1.py
index 09ecd20a6..c473fc0b0 100644
--- a/tests/basics/struct1.py
+++ b/tests/basics/struct1.py
@@ -30,9 +30,18 @@ print(v == (10, 100, 200, 300))
 print(struct.pack("<I", 2**32 - 1))
 print(struct.pack("<I", 0xffffffff))
 
-# fails on 32-bit machine
-#print(struct.pack("<Q", 2**64 - 1))
-#print(struct.pack("<Q", 0xffffffffffffffff))
+# long long ints
+print(struct.pack("<Q", 2**64 - 1))
+print(struct.pack("<Q", 0xffffffffffffffff))
+print(struct.pack("<q", -1))
+print(struct.pack("<Q", 1234567890123456789))
+print(struct.pack("<q", -1234567890123456789))
+print(struct.pack(">Q", 1234567890123456789))
+print(struct.pack(">q", -1234567890123456789))
+print(struct.unpack("<Q", b"\x12\x34\x56\x78\x90\x12\x34\x56"))
+print(struct.unpack(">Q", b"\x12\x34\x56\x78\x90\x12\x34\x56"))
+print(struct.unpack("<q", b"\x12\x34\x56\x78\x90\x12\x34\xf6"))
+print(struct.unpack(">q", b"\xf2\x34\x56\x78\x90\x12\x34\x56"))
 
 # check maximum unpack
 print(struct.unpack("<I", b"\xff\xff\xff\xff"))
-- 
GitLab