From 614afcf9fce5e0e6edb6ab8a92daa9d5c91a89c0 Mon Sep 17 00:00:00 2001
From: Ricardo
Date: Wed, 22 Dec 2021 11:08:20 +0100
Subject: [PATCH 1/2] Add softmax and log_softmax to math module

---
 RELEASE-NOTES.md        |  1 +
 pymc/math.py            | 16 ++++++++++++++++
 pymc/tests/test_math.py | 23 +++++++++++++++++++++++
 3 files changed, 40 insertions(+)

diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md
index c90480f5ea..d5d44a6a27 100644
--- a/RELEASE-NOTES.md
+++ b/RELEASE-NOTES.md
@@ -125,6 +125,7 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01
 - `pm.Data` now passes additional kwargs to `aesara.shared`/`at.as_tensor`. [#5098](https://github.com/pymc-devs/pymc/pull/5098).
 - Univariate censored distributions are now available via `pm.Censored`. [#5169](https://github.com/pymc-devs/pymc/pull/5169)
 - Nested models now inherit the parent model's coordinates. [#5344](https://github.com/pymc-devs/pymc/pull/5344)
+- `softmax` and `log_softmax` functions added to `math` module (see [#5279](https://github.com/pymc-devs/pymc/pull/5279)).
 - ...
 
 
diff --git a/pymc/math.py b/pymc/math.py
index ab39527c1e..6ba6fcab4e 100644
--- a/pymc/math.py
+++ b/pymc/math.py
@@ -211,6 +211,22 @@ def invlogit(x, eps=None):
     return at.sigmoid(x)
 
 
+def softmax(x, axis=None):
+    # Ignore vector case UserWarning issued by Aesara. This can be removed once Aesara
+    # drops that warning
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", UserWarning)
+        return at.nnet.softmax(x, axis=axis)
+
+
+def log_softmax(x, axis=None):
+    # Ignore vector case UserWarning issued by Aesara. This can be removed once Aesara
+    # drops that warning
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", UserWarning)
+        return at.nnet.logsoftmax(x, axis=axis)
+
+
 def logbern(log_p):
     if np.isnan(log_p):
         raise FloatingPointError("log_p can't be nan.")
diff --git a/pymc/tests/test_math.py b/pymc/tests/test_math.py
index dda0a09e03..2c27296d64 100644
--- a/pymc/tests/test_math.py
+++ b/pymc/tests/test_math.py
@@ -30,10 +30,12 @@
     kronecker,
     log1mexp,
     log1mexp_numpy,
+    log_softmax,
     logdet,
     logdiffexp,
     logdiffexp_numpy,
     probit,
+    softmax,
 )
 
 from pymc.tests.helpers import SeededTest, verify_grad
@@ -265,3 +267,24 @@ def test_invlogit_deprecation_warning():
     assert not record
 
     assert np.isclose(res, res_zero_eps)
+
+
+@pytest.mark.parametrize(
+    "aesara_function, pymc_wrapper",
+    [
+        (at.nnet.softmax, softmax),
+        (at.nnet.logsoftmax, log_softmax),
+    ],
+)
+def test_softmax_logsoftmax_no_warnings(aesara_function, pymc_wrapper):
+    """Test that wrappers for aesara functions do not issue Warnings"""
+
+    vector = at.vector("vector")
+    with pytest.warns(None) as record:
+        aesara_function(vector)
+    warnings = {warning.category for warning in record.list}
+    assert warnings == {UserWarning, FutureWarning}
+
+    with pytest.warns(None) as record:
+        pymc_wrapper(vector)
+    assert not record

From 8af7d80035349f2c482d808d566176d67f88e266 Mon Sep 17 00:00:00 2001
From: Ricardo
Date: Wed, 22 Dec 2021 11:08:50 +0100
Subject: [PATCH 2/2] Remove custom implementation of softmax in metropolis.py

---
 pymc/step_methods/metropolis.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py
index c4d7f23d8f..c0823947a5 100644
--- a/pymc/step_methods/metropolis.py
+++ b/pymc/step_methods/metropolis.py
@@ -16,6 +16,7 @@
 import numpy as np
 import numpy.random as nr
 import scipy.linalg
+import scipy.special
 
 from aesara.graph.fg import MissingInputError
 from aesara.tensor.random.basic import BernoulliRV, CategoricalRV
@@ -608,7 +609,7 @@ def metropolis_proportional(self, q, logp, logp_curr, dim, k):
             if candidate_cat != given_cat:
                 q.data[dim] = candidate_cat
                 log_probs[candidate_cat] = logp(q)
-        probs = softmax(log_probs)
+        probs = scipy.special.softmax(log_probs, axis=0)
         prob_curr, probs[given_cat] = probs[given_cat], 0.0
         probs /= 1.0 - prob_curr
         proposed_cat = nr.choice(candidates, p=probs)
@@ -995,11 +996,6 @@ def sample_except(limit, excluded):
         return candidate
 
 
-def softmax(x):
-    e_x = np.exp(x - np.max(x))
-    return e_x / np.sum(e_x, axis=0)
-
-
 def delta_logp(point, logp, vars, shared):
     [logp0], inarray0 = pm.join_nonshared_inputs(point, [logp], vars, shared)
 