Skip to content

[3.7] Revert bpo-39576: Prevent memory error for overly optimistic precisions (GH-20748) #20748

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 0 additions & 35 deletions Lib/test/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5476,41 +5476,6 @@ def __abs__(self):
self.assertEqual(Decimal.from_float(cls(101.1)),
Decimal.from_float(101.1))

def test_maxcontext_exact_arith(self):

# Make sure that exact operations do not raise MemoryError due
# to huge intermediate values when the context precision is very
# large.

# The following functions fill the available precision and are
# therefore not suitable for large precisions (by design of the
# specification).
MaxContextSkip = ['logical_invert', 'next_minus', 'next_plus',
'logical_and', 'logical_or', 'logical_xor',
'next_toward', 'rotate', 'shift']

Decimal = C.Decimal
Context = C.Context
localcontext = C.localcontext

# Here only some functions that are likely candidates for triggering a
# MemoryError are tested. deccheck.py has an exhaustive test.
maxcontext = Context(prec=C.MAX_PREC, Emin=C.MIN_EMIN, Emax=C.MAX_EMAX)
with localcontext(maxcontext):
self.assertEqual(Decimal(0).exp(), 1)
self.assertEqual(Decimal(1).ln(), 0)
self.assertEqual(Decimal(1).log10(), 0)
self.assertEqual(Decimal(10**2).log10(), 2)
self.assertEqual(Decimal(10**223).log10(), 223)
self.assertEqual(Decimal(10**19).logb(), 19)
self.assertEqual(Decimal(4).sqrt(), 2)
self.assertEqual(Decimal("40E9").sqrt(), Decimal('2.0E+5'))
self.assertEqual(divmod(Decimal(10), 3), (3, 1))
self.assertEqual(Decimal(10) // 3, 3)
self.assertEqual(Decimal(4) / 2, 2)
self.assertEqual(Decimal(400) ** -1, Decimal('0.0025'))


@requires_docstrings
@unittest.skipUnless(C, "test requires C version")
class SignatureTest(unittest.TestCase):
Expand Down
77 changes: 3 additions & 74 deletions Modules/_decimal/libmpdec/mpdecimal.c
Original file line number Diff line number Diff line change
Expand Up @@ -3781,43 +3781,6 @@ mpd_qdiv(mpd_t *q, const mpd_t *a, const mpd_t *b,
const mpd_context_t *ctx, uint32_t *status)
{
_mpd_qdiv(SET_IDEAL_EXP, q, a, b, ctx, status);

if (*status & MPD_Malloc_error) {
/* Inexact quotients (the usual case) fill the entire context precision,
* which can lead to malloc() failures for very high precisions. Retry
* the operation with a lower precision in case the result is exact.
*
* We need an upper bound for the number of digits of a_coeff / b_coeff
* when the result is exact. If a_coeff' * 1 / b_coeff' is in lowest
* terms, then maxdigits(a_coeff') + maxdigits(1 / b_coeff') is a suitable
* bound.
*
* 1 / b_coeff' is exact iff b_coeff' exclusively has prime factors 2 or 5.
* The largest amount of digits is generated if b_coeff' is a power of 2 or
* a power of 5 and is less than or equal to log5(b_coeff') <= log2(b_coeff').
*
* We arrive at a total upper bound:
*
* maxdigits(a_coeff') + maxdigits(1 / b_coeff') <=
* a->digits + log2(b_coeff) =
* a->digits + log10(b_coeff) / log10(2) <=
* a->digits + b->digits * 4;
*/
uint32_t workstatus = 0;
mpd_context_t workctx = *ctx;
workctx.prec = a->digits + b->digits * 4;
if (workctx.prec >= ctx->prec) {
return; /* No point in retrying, keep the original error. */
}

_mpd_qdiv(SET_IDEAL_EXP, q, a, b, &workctx, &workstatus);
if (workstatus == 0) { /* The result is exact, unrounded, normal etc. */
*status = 0;
return;
}

mpd_seterror(q, *status, status);
}
}

/* Internal function. */
Expand Down Expand Up @@ -7739,9 +7702,9 @@ mpd_qinvroot(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
/* END LIBMPDEC_ONLY */

/* Algorithm from decimal.py */
static void
_mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
uint32_t *status)
void
mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
uint32_t *status)
{
mpd_context_t maxcontext;
MPD_NEW_STATIC(c,0,0,0,0);
Expand Down Expand Up @@ -7873,40 +7836,6 @@ _mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
goto out;
}

void
mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
uint32_t *status)
{
_mpd_qsqrt(result, a, ctx, status);

if (*status & (MPD_Malloc_error|MPD_Division_impossible)) {
/* The above conditions can occur at very high context precisions
* if intermediate values get too large. Retry the operation with
* a lower context precision in case the result is exact.
*
* If the result is exact, an upper bound for the number of digits
* is the number of digits in the input.
*
* NOTE: sqrt(40e9) = 2.0e+5 /\ digits(40e9) = digits(2.0e+5) = 2
*/
uint32_t workstatus = 0;
mpd_context_t workctx = *ctx;
workctx.prec = a->digits;

if (workctx.prec >= ctx->prec) {
return; /* No point in repeating this, keep the original error. */
}

_mpd_qsqrt(result, a, &workctx, &workstatus);
if (workstatus == 0) {
*status = 0;
return;
}

mpd_seterror(result, *status, status);
}
}


/******************************************************************************/
/* Base conversions */
Expand Down
Loading