Skip to content

Commit d7d066c

Browse files
skirpichev, vstinner, picnixz
authored
gh-127936, PEP 757: Convert marshal module to use import/export API for ints (#128530)
Co-authored-by: Victor Stinner <[email protected]> Co-authored-by: Bénédikt Tran <[email protected]>
1 parent 1d485db commit d7d066c

File tree

1 file changed

+168
-75
lines changed

1 file changed

+168
-75
lines changed

Python/marshal.c

Lines changed: 168 additions & 75 deletions
Original file line number | Diff line number | Diff line change
@@ -240,10 +240,6 @@ w_short_pstring(const void *s, Py_ssize_t n, WFILE *p)
240240
/* Number of bits carried by one digit of the marshal encoding of an int. */
#define PyLong_MARSHAL_SHIFT 15
241241
/* Radix of the marshal encoding, i.e. 1 << PyLong_MARSHAL_SHIFT. */
#define PyLong_MARSHAL_BASE ((short)1 << PyLong_MARSHAL_SHIFT)
242242
/* Mask that keeps the low PyLong_MARSHAL_SHIFT bits of a value. */
#define PyLong_MARSHAL_MASK (PyLong_MARSHAL_BASE - 1)
243-
#if PyLong_SHIFT % PyLong_MARSHAL_SHIFT != 0
244-
#error "PyLong_SHIFT must be a multiple of PyLong_MARSHAL_SHIFT"
245-
#endif
246-
#define PyLong_MARSHAL_RATIO (PyLong_SHIFT / PyLong_MARSHAL_SHIFT)
247243

248244
#define W_TYPE(t, p) do { \
249245
w_byte((t) | flag, (p)); \
@@ -252,47 +248,106 @@ w_short_pstring(const void *s, Py_ssize_t n, WFILE *p)
252248
static PyObject *
253249
_PyMarshal_WriteObjectToString(PyObject *x, int version, int allow_code);
254250

251+
/* Generator for _r_digits16() and _r_digits32() (instantiated right after
   the macro body, which is then #undef'ed).

   Each generated function takes `digits`, an array of `n` unsigned digits of
   the given bit width, and writes it to `p` in the marshal int encoding:
   first the count of base-PyLong_MARSHAL_BASE digits via w_long() (negated
   when `negative` is nonzero), then the digits themselves via w_short(),
   walking digits[0] .. digits[n-1] and emitting the low bits of each one
   first.  `marshal_ratio` is how many marshal digits are emitted for every
   digit except the last; the last one is emitted without zero padding.

   If the digit count exceeds SIZE32_MAX nothing is written; instead
   p->depth is decremented and p->error is set to WFERR_UNMARSHALLABLE.

   Despite the "_r_" prefix these helpers sit on the write path: they are
   called from w_PyLong(). */
#define _r_digits(bitsize) \
252+
static void \
253+
_r_digits##bitsize(const uint ## bitsize ## _t *digits, Py_ssize_t n, \
254+
uint8_t negative, Py_ssize_t marshal_ratio, WFILE *p) \
255+
{ \
256+
/* set l to number of base PyLong_MARSHAL_BASE digits */ \
257+
Py_ssize_t l = (n - 1)*marshal_ratio; \
258+
uint ## bitsize ## _t d = digits[n - 1]; \
259+
\
260+
assert(marshal_ratio > 0); \
261+
assert(n >= 1); \
262+
assert(d != 0); /* a PyLong is always normalized */ \
263+
/* add the marshal digits needed by the last (top) digit alone */ \
do { \
264+
d >>= PyLong_MARSHAL_SHIFT; \
265+
l++; \
266+
} while (d != 0); \
267+
/* too many digits to marshal: flag the error on p and bail out */ \
if (l > SIZE32_MAX) { \
268+
p->depth--; \
269+
p->error = WFERR_UNMARSHALLABLE; \
270+
return; \
271+
} \
272+
w_long((long)(negative ? -l : l), p); \
273+
\
274+
/* every digit below the top one yields exactly marshal_ratio shorts */ \
for (Py_ssize_t i = 0; i < n - 1; i++) { \
275+
d = digits[i]; \
276+
for (Py_ssize_t j = 0; j < marshal_ratio; j++) { \
277+
w_short(d & PyLong_MARSHAL_MASK, p); \
278+
d >>= PyLong_MARSHAL_SHIFT; \
279+
} \
280+
assert(d == 0); \
281+
} \
282+
/* top digit: emit shorts until it is exhausted, without padding */ \
d = digits[n - 1]; \
283+
do { \
284+
w_short(d & PyLong_MARSHAL_MASK, p); \
285+
d >>= PyLong_MARSHAL_SHIFT; \
286+
} while (d != 0); \
287+
}
288+
_r_digits(16)
289+
_r_digits(32)
290+
#undef _r_digits
291+
255292
/* Write the int object `ob` to `p` as a TYPE_LONG record.

   (This span is a rendered diff: lines whose gutter marker ends in '-' are
   the implementation being removed; the description below covers the
   context and '+' lines.)

   The type byte is written with W_TYPE (which also folds in `flag`).  Zero
   is written as a digit count of 0 and nothing else.  Otherwise the value is
   obtained with PyLong_Export() and written in one of two ways:
     - when the export carries no digit array, from long_export.value;
     - otherwise from long_export.digits via _r_digits16()/_r_digits32(),
       picked by the native layout's digit_size.
   If PyLong_Export() fails, p->depth is decremented and p->error is set to
   WFERR_UNMARSHALLABLE. */
static void
256293
w_PyLong(const PyLongObject *ob, char flag, WFILE *p)
257294
{
258-
Py_ssize_t i, j, n, l;
259-
digit d;
260-
261295
W_TYPE(TYPE_LONG, p);
262296
if (_PyLong_IsZero(ob)) {
263297
w_long((long)0, p);
264298
return;
265299
}
266300

267-
/* set l to number of base PyLong_MARSHAL_BASE digits */
268-
n = _PyLong_DigitCount(ob);
269-
l = (n-1) * PyLong_MARSHAL_RATIO;
270-
d = ob->long_value.ob_digit[n-1];
271-
assert(d != 0); /* a PyLong is always normalized */
272-
do {
273-
d >>= PyLong_MARSHAL_SHIFT;
274-
l++;
275-
} while (d != 0);
276-
if (l > SIZE32_MAX) {
301+
PyLongExport long_export;
302+
303+
if (PyLong_Export((PyObject *)ob, &long_export) < 0) {
277304
p->depth--;
278305
p->error = WFERR_UNMARSHALLABLE;
279306
return;
280307
}
281-
w_long((long)(_PyLong_IsNegative(ob) ? -l : l), p);
308+
/* No digit array in the export: the number is in long_export.value.
   Split its absolute value into 15-bit marshal digits directly. */
if (!long_export.digits) {
309+
int8_t sign = long_export.value < 0 ? -1 : 1;
310+
uint64_t abs_value = Py_ABS(long_export.value);
311+
uint64_t d = abs_value;
312+
long l = 0;
282313

283-
for (i=0; i < n-1; i++) {
284-
d = ob->long_value.ob_digit[i];
285-
for (j=0; j < PyLong_MARSHAL_RATIO; j++) {
314+
/* set l to number of base PyLong_MARSHAL_BASE digits */
315+
/* l moves by `sign` per digit, so it ends up as the signed count */
do {
316+
d >>= PyLong_MARSHAL_SHIFT;
317+
l += sign;
318+
} while (d);
319+
w_long(l, p);
320+
321+
d = abs_value;
322+
do {
286323
w_short(d & PyLong_MARSHAL_MASK, p);
287324
d >>= PyLong_MARSHAL_SHIFT;
288-
}
289-
assert (d == 0);
325+
} while (d);
326+
/* NOTE(review): this path returns without PyLong_FreeExport();
   presumably fine because the call is documented as optional when
   digits is NULL -- confirm against the PyLong_Export docs. */
return;
290327
}
291-
d = ob->long_value.ob_digit[n-1];
292-
do {
293-
w_short(d & PyLong_MARSHAL_SHIFT, p);
294-
d >>= PyLong_MARSHAL_SHIFT;
295-
} while (d != 0);
328+
329+
/* Digit-array path: number of marshal digits per native digit. */
const PyLongLayout *layout = PyLong_GetNativeLayout();
330+
Py_ssize_t marshal_ratio = layout->bits_per_digit/PyLong_MARSHAL_SHIFT;
331+
332+
/* must be a multiple of PyLong_MARSHAL_SHIFT */
333+
assert(layout->bits_per_digit % PyLong_MARSHAL_SHIFT == 0);
334+
assert(layout->bits_per_digit >= PyLong_MARSHAL_SHIFT);
335+
336+
/* other assumptions on PyLongObject internals */
337+
assert(layout->bits_per_digit <= 32);
338+
assert(layout->digits_order == -1);
339+
assert(layout->digit_endianness == (PY_LITTLE_ENDIAN ? -1 : 1));
340+
assert(layout->digit_size == 2 || layout->digit_size == 4);
341+
342+
/* pick the helper whose digit type matches the native digit size */
if (layout->digit_size == 4) {
343+
_r_digits32(long_export.digits, long_export.ndigits,
344+
long_export.negative, marshal_ratio, p);
345+
}
346+
else {
347+
_r_digits16(long_export.digits, long_export.ndigits,
348+
long_export.negative, marshal_ratio, p);
349+
}
350+
/* release the export whether or not the helper flagged an error on p */
PyLong_FreeExport(&long_export);
296351
}
297352

298353
static void
@@ -875,17 +930,62 @@ r_long64(RFILE *p)
875930
1 /* signed */);
876931
}
877932

933+
/* Generator for _w_digits32() and _w_digits16() (instantiated right after
   the macro body, which is then #undef'ed).

   Each generated function fills `digits[0 .. size-1]` from 15-bit marshal
   digits read with r_short() from `p`.  Every digit except the last is built
   from `marshal_ratio` shorts, the last one from `shorts_in_top_digit`
   shorts; the j-th short of a digit lands at bit offset
   j*PyLong_MARSHAL_SHIFT.

   Returns 0 on success.  Returns -1 with a Python exception set when
     - a short fails the range check ("digit out of range in long"; an
       exception that is already pending is left untouched -- presumably one
       set by r_short() on a read failure), or
     - the final short of the last digit is zero ("unnormalized long data").

   Despite the "_w_" prefix these helpers sit on the read path: they are
   called from r_PyLong().

   NOTE(review): the range check is `md > PyLong_MARSHAL_BASE`, so a value
   equal to PyLong_MARSHAL_BASE would be accepted although it does not fit in
   PyLong_MARSHAL_SHIFT bits; `>=` looks intended.  Probably unreachable if
   r_short() yields a signed 16-bit value -- confirm. */
#define _w_digits(bitsize) \
934+
static int \
935+
_w_digits##bitsize(uint ## bitsize ## _t *digits, Py_ssize_t size, \
936+
Py_ssize_t marshal_ratio, \
937+
int shorts_in_top_digit, RFILE *p) \
938+
{ \
939+
uint ## bitsize ## _t d; \
940+
\
941+
assert(size >= 1); \
942+
/* all digits below the top one: exactly marshal_ratio shorts each */ \
for (Py_ssize_t i = 0; i < size - 1; i++) { \
943+
d = 0; \
944+
for (Py_ssize_t j = 0; j < marshal_ratio; j++) { \
945+
int md = r_short(p); \
946+
if (md < 0 || md > PyLong_MARSHAL_BASE) { \
947+
goto bad_digit; \
948+
} \
949+
d += (uint ## bitsize ## _t)md << j*PyLong_MARSHAL_SHIFT; \
950+
} \
951+
digits[i] = d; \
952+
} \
953+
\
954+
/* top digit: may be made of fewer than marshal_ratio shorts */ \
d = 0; \
955+
for (Py_ssize_t j = 0; j < shorts_in_top_digit; j++) { \
956+
int md = r_short(p); \
957+
if (md < 0 || md > PyLong_MARSHAL_BASE) { \
958+
goto bad_digit; \
959+
} \
960+
/* topmost marshal digit should be nonzero */ \
961+
if (md == 0 && j == shorts_in_top_digit - 1) { \
962+
PyErr_SetString(PyExc_ValueError, \
963+
"bad marshal data (unnormalized long data)"); \
964+
return -1; \
965+
} \
966+
d += (uint ## bitsize ## _t)md << j*PyLong_MARSHAL_SHIFT; \
967+
} \
968+
assert(!PyErr_Occurred()); \
969+
/* top digit should be nonzero, else the resulting PyLong won't be \
970+
normalized */ \
971+
digits[size - 1] = d; \
972+
return 0; \
973+
\
974+
bad_digit: \
975+
/* keep an exception that is already pending; otherwise report the digit */ \
if (!PyErr_Occurred()) { \
976+
PyErr_SetString(PyExc_ValueError, \
977+
"bad marshal data (digit out of range in long)"); \
978+
} \
979+
return -1; \
980+
}
981+
_w_digits(32)
982+
_w_digits(16)
983+
#undef _w_digits
984+
878985
/* Read the payload of a marshalled int from `p` and return it as a new
   object, or NULL with an exception set on failure.

   (This span is a rendered diff: lines whose gutter marker ends in '-' are
   the implementation being removed, and a few lines between the two hunks
   are not shown.  The description below covers the visible context and '+'
   lines only.)

   The signed digit count `n` is read with r_long(); its sign gives the sign
   of the result and |n| the number of 15-bit marshal digits that follow.
   The digits are read into a buffer obtained from PyLongWriter_Create() by
   _w_digits32()/_w_digits16(), picked by the native layout's digit_size.  On
   a helper failure the writer is discarded; otherwise the result comes from
   PyLongWriter_Finish().

   NOTE(review): the removed code returned early for n == 0; the added code
   has no such branch in the visible lines.  For n == 0 the expression
   (Py_ABS(n) - 1) / marshal_ratio is 0 when marshal_ratio is 2 but -1 when
   marshal_ratio is 1, which would make `size` 0 and trip assert(size >= 1).
   Confirm whether the elided lines or the callers rule this out. */
static PyObject *
879986
r_PyLong(RFILE *p)
880987
{
881-
PyLongObject *ob;
882-
long n, size, i;
883-
int j, md, shorts_in_top_digit;
884-
digit d;
885-
886-
n = r_long(p);
887-
if (n == 0)
888-
return (PyObject *)_PyLong_New(0);
988+
long n = r_long(p);
889989
/* -1 is only an error when an exception is actually pending */
if (n == -1 && PyErr_Occurred()) {
890990
return NULL;
891991
}
@@ -895,51 +995,44 @@ r_PyLong(RFILE *p)
895995
return NULL;
896996
}
897997

898-
size = 1 + (Py_ABS(n) - 1) / PyLong_MARSHAL_RATIO;
899-
shorts_in_top_digit = 1 + (Py_ABS(n) - 1) % PyLong_MARSHAL_RATIO;
900-
ob = _PyLong_New(size);
901-
if (ob == NULL)
902-
return NULL;
998+
/* number of marshal digits that fit in one native digit */
const PyLongLayout *layout = PyLong_GetNativeLayout();
999+
Py_ssize_t marshal_ratio = layout->bits_per_digit/PyLong_MARSHAL_SHIFT;
9031000

904-
_PyLong_SetSignAndDigitCount(ob, n < 0 ? -1 : 1, size);
1001+
/* must be a multiple of PyLong_MARSHAL_SHIFT */
1002+
assert(layout->bits_per_digit % PyLong_MARSHAL_SHIFT == 0);
1003+
assert(layout->bits_per_digit >= PyLong_MARSHAL_SHIFT);
9051004

906-
for (i = 0; i < size-1; i++) {
907-
d = 0;
908-
for (j=0; j < PyLong_MARSHAL_RATIO; j++) {
909-
md = r_short(p);
910-
if (md < 0 || md > PyLong_MARSHAL_BASE)
911-
goto bad_digit;
912-
d += (digit)md << j*PyLong_MARSHAL_SHIFT;
913-
}
914-
ob->long_value.ob_digit[i] = d;
1005+
/* other assumptions on PyLongObject internals */
1006+
assert(layout->bits_per_digit <= 32);
1007+
assert(layout->digits_order == -1);
1008+
assert(layout->digit_endianness == (PY_LITTLE_ENDIAN ? -1 : 1));
1009+
assert(layout->digit_size == 2 || layout->digit_size == 4);
1010+
1011+
/* native digits needed to hold |n| marshal digits (rounded up) */
Py_ssize_t size = 1 + (Py_ABS(n) - 1) / marshal_ratio;
1012+
1013+
assert(size >= 1);
1014+
1015+
/* how many of the |n| marshal digits belong to the last native digit */
int shorts_in_top_digit = 1 + (Py_ABS(n) - 1) % marshal_ratio;
1016+
void *digits;
1017+
PyLongWriter *writer = PyLongWriter_Create(n < 0, size, &digits);
1018+
1019+
if (writer == NULL) {
1020+
return NULL;
9151021
}
9161022

917-
d = 0;
918-
for (j=0; j < shorts_in_top_digit; j++) {
919-
md = r_short(p);
920-
if (md < 0 || md > PyLong_MARSHAL_BASE)
921-
goto bad_digit;
922-
/* topmost marshal digit should be nonzero */
923-
if (md == 0 && j == shorts_in_top_digit - 1) {
924-
Py_DECREF(ob);
925-
PyErr_SetString(PyExc_ValueError,
926-
"bad marshal data (unnormalized long data)");
927-
return NULL;
928-
}
929-
d += (digit)md << j*PyLong_MARSHAL_SHIFT;
1023+
int ret;
1024+
1025+
/* pick the helper whose digit type matches the native digit size */
if (layout->digit_size == 4) {
1026+
ret = _w_digits32(digits, size, marshal_ratio, shorts_in_top_digit, p);
9301027
}
931-
assert(!PyErr_Occurred());
932-
/* top digit should be nonzero, else the resulting PyLong won't be
933-
normalized */
934-
ob->long_value.ob_digit[size-1] = d;
935-
return (PyObject *)ob;
936-
bad_digit:
937-
Py_DECREF(ob);
938-
if (!PyErr_Occurred()) {
939-
PyErr_SetString(PyExc_ValueError,
940-
"bad marshal data (digit out of range in long)");
1028+
else {
1029+
ret = _w_digits16(digits, size, marshal_ratio, shorts_in_top_digit, p);
1030+
}
1031+
/* helper failed (exception already set): drop the half-built writer */
if (ret < 0) {
1032+
PyLongWriter_Discard(writer);
1033+
return NULL;
9411034
}
942-
return NULL;
1035+
return PyLongWriter_Finish(writer);
9431036
}
9441037

9451038
static double

0 commit comments

Comments
 (0)