diff --git a/py/modstruct.c b/py/modstruct.c
index eabc951aefbf8f20fa447a5a0e252b742de234f9..2016add17e174531c71de69a52e57855abf8b512 100644
--- a/py/modstruct.c
+++ b/py/modstruct.c
@@ -103,30 +103,24 @@ STATIC mp_obj_t struct_calcsize(mp_obj_t fmt_in) {
     char fmt_type = get_fmt_type(&fmt);
     mp_uint_t size;
     for (size = 0; *fmt; fmt++) {
-        mp_uint_t align = 1;
         mp_uint_t cnt = 1;
         if (unichar_isdigit(*fmt)) {
             cnt = get_fmt_num(&fmt);
         }
 
-        mp_uint_t sz = 0;
         if (*fmt == 's') {
-            sz = cnt;
-            cnt = 1;
-        }
-
-        while (cnt--) {
-            // If we already have size for 's' case, don't set it again
-            if (sz == 0) {
-                sz = (mp_uint_t)mp_binary_get_size(fmt_type, *fmt, &align);
-            }
+            size += cnt;
+        } else {
+            mp_uint_t align;
+            size_t sz = mp_binary_get_size(fmt_type, *fmt, &align);
             if (sz == 0) {
                 nlr_raise(mp_obj_new_exception_msg(&mp_type_ValueError, "unsupported format"));
             }
-            // Apply alignment
-            size = (size + align - 1) & ~(align - 1);
-            size += sz;
-            sz = 0;
+            while (cnt--) {
+                // Apply alignment
+                size = (size + align - 1) & ~(align - 1);
+                size += sz;
+            }
         }
     }
     return MP_OBJ_NEW_SMALL_INT(size);
diff --git a/tests/basics/struct2.py b/tests/basics/struct2.py
new file mode 100644
index 0000000000000000000000000000000000000000..f438bb55d235f4b091861d281e72817210b60e95
--- /dev/null
+++ b/tests/basics/struct2.py
@@ -0,0 +1,28 @@
+# test ustruct with a count specified before the type
+
+try:
+    import ustruct as struct
+except:
+    import struct
+
+print(struct.calcsize('0s'))
+print(struct.unpack('0s', b''))
+print(struct.pack('0s', b'123'))
+
+print(struct.calcsize('2s'))
+print(struct.unpack('2s', b'12'))
+print(struct.pack('2s', b'123'))
+
+print(struct.calcsize('2H'))
+print(struct.unpack('<2H', b'1234'))
+print(struct.pack('<2H', 258, 515))
+
+print(struct.calcsize('0s1s0H2H'))
+print(struct.unpack('<0s1s0H2H', b'01234'))
+print(struct.pack('<0s1s0H2H', b'abc', b'abc', 258, 515))
+
+# check that zero of an unknown type raises an exception
+try:
+    struct.calcsize('0z')
+except:
+    print('Exception')