Skip to content

Commit c6f9554

Browse files
bpo-39576: Prevent memory error for overly optimistic precisions (GH-18581) (#18585)
(cherry picked from commit 90930e6) Authored-by: Stefan Krah <[email protected]>
1 parent 736e0ea commit c6f9554

File tree

3 files changed

+245
-6
lines changed

3 files changed

+245
-6
lines changed

Lib/test/test_decimal.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5476,6 +5476,41 @@ def __abs__(self):
54765476
self.assertEqual(Decimal.from_float(cls(101.1)),
54775477
Decimal.from_float(101.1))
54785478

5479+
def test_maxcontext_exact_arith(self):
5480+
5481+
# Make sure that exact operations do not raise MemoryError due
5482+
# to huge intermediate values when the context precision is very
5483+
# large.
5484+
5485+
# The following functions fill the available precision and are
5486+
# therefore not suitable for large precisions (by design of the
5487+
# specification).
5488+
MaxContextSkip = ['logical_invert', 'next_minus', 'next_plus',
5489+
'logical_and', 'logical_or', 'logical_xor',
5490+
'next_toward', 'rotate', 'shift']
5491+
5492+
Decimal = C.Decimal
5493+
Context = C.Context
5494+
localcontext = C.localcontext
5495+
5496+
# Here only some functions that are likely candidates for triggering a
5497+
# MemoryError are tested. deccheck.py has an exhaustive test.
5498+
maxcontext = Context(prec=C.MAX_PREC, Emin=C.MIN_EMIN, Emax=C.MAX_EMAX)
5499+
with localcontext(maxcontext):
5500+
self.assertEqual(Decimal(0).exp(), 1)
5501+
self.assertEqual(Decimal(1).ln(), 0)
5502+
self.assertEqual(Decimal(1).log10(), 0)
5503+
self.assertEqual(Decimal(10**2).log10(), 2)
5504+
self.assertEqual(Decimal(10**223).log10(), 223)
5505+
self.assertEqual(Decimal(10**19).logb(), 19)
5506+
self.assertEqual(Decimal(4).sqrt(), 2)
5507+
self.assertEqual(Decimal("40E9").sqrt(), Decimal('2.0E+5'))
5508+
self.assertEqual(divmod(Decimal(10), 3), (3, 1))
5509+
self.assertEqual(Decimal(10) // 3, 3)
5510+
self.assertEqual(Decimal(4) / 2, 2)
5511+
self.assertEqual(Decimal(400) ** -1, Decimal('0.0025'))
5512+
5513+
54795514
@requires_docstrings
54805515
@unittest.skipUnless(C, "test requires C version")
54815516
class SignatureTest(unittest.TestCase):

Modules/_decimal/libmpdec/mpdecimal.c

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3781,6 +3781,43 @@ mpd_qdiv(mpd_t *q, const mpd_t *a, const mpd_t *b,
37813781
const mpd_context_t *ctx, uint32_t *status)
37823782
{
37833783
_mpd_qdiv(SET_IDEAL_EXP, q, a, b, ctx, status);
3784+
3785+
if (*status & MPD_Malloc_error) {
3786+
/* Inexact quotients (the usual case) fill the entire context precision,
3787+
* which can lead to malloc() failures for very high precisions. Retry
3788+
* the operation with a lower precision in case the result is exact.
3789+
*
3790+
* We need an upper bound for the number of digits of a_coeff / b_coeff
3791+
* when the result is exact. If a_coeff' * 1 / b_coeff' is in lowest
3792+
* terms, then maxdigits(a_coeff') + maxdigits(1 / b_coeff') is a suitable
3793+
* bound.
3794+
*
3795+
* 1 / b_coeff' is exact iff b_coeff' exclusively has prime factors 2 or 5.
3796+
* The largest amount of digits is generated if b_coeff' is a power of 2 or
3797+
* a power of 5 and is less than or equal to log5(b_coeff') <= log2(b_coeff').
3798+
*
3799+
* We arrive at a total upper bound:
3800+
*
3801+
* maxdigits(a_coeff') + maxdigits(1 / b_coeff') <=
3802+
* a->digits + log2(b_coeff) =
3803+
* a->digits + log10(b_coeff) / log10(2) <=
3804+
* a->digits + b->digits * 4;
3805+
*/
3806+
uint32_t workstatus = 0;
3807+
mpd_context_t workctx = *ctx;
3808+
workctx.prec = a->digits + b->digits * 4;
3809+
if (workctx.prec >= ctx->prec) {
3810+
return; /* No point in retrying, keep the original error. */
3811+
}
3812+
3813+
_mpd_qdiv(SET_IDEAL_EXP, q, a, b, &workctx, &workstatus);
3814+
if (workstatus == 0) { /* The result is exact, unrounded, normal etc. */
3815+
*status = 0;
3816+
return;
3817+
}
3818+
3819+
mpd_seterror(q, *status, status);
3820+
}
37843821
}
37853822

37863823
/* Internal function. */
@@ -7702,9 +7739,9 @@ mpd_qinvroot(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
77027739
/* END LIBMPDEC_ONLY */
77037740

77047741
/* Algorithm from decimal.py */
7705-
void
7706-
mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
7707-
uint32_t *status)
7742+
static void
7743+
_mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
7744+
uint32_t *status)
77087745
{
77097746
mpd_context_t maxcontext;
77107747
MPD_NEW_STATIC(c,0,0,0,0);
@@ -7836,6 +7873,40 @@ mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
78367873
goto out;
78377874
}
78387875

7876+
void
7877+
mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
7878+
uint32_t *status)
7879+
{
7880+
_mpd_qsqrt(result, a, ctx, status);
7881+
7882+
if (*status & (MPD_Malloc_error|MPD_Division_impossible)) {
7883+
/* The above conditions can occur at very high context precisions
7884+
* if intermediate values get too large. Retry the operation with
7885+
* a lower context precision in case the result is exact.
7886+
*
7887+
* If the result is exact, an upper bound for the number of digits
7888+
* is the number of digits in the input.
7889+
*
7890+
* NOTE: sqrt(40e9) = 2.0e+5 /\ digits(40e9) = digits(2.0e+5) = 2
7891+
*/
7892+
uint32_t workstatus = 0;
7893+
mpd_context_t workctx = *ctx;
7894+
workctx.prec = a->digits;
7895+
7896+
if (workctx.prec >= ctx->prec) {
7897+
return; /* No point in repeating this, keep the original error. */
7898+
}
7899+
7900+
_mpd_qsqrt(result, a, &workctx, &workstatus);
7901+
if (workstatus == 0) {
7902+
*status = 0;
7903+
return;
7904+
}
7905+
7906+
mpd_seterror(result, *status, status);
7907+
}
7908+
}
7909+
78397910

78407911
/******************************************************************************/
78417912
/* Base conversions */

0 commit comments

Comments
 (0)