Skip to content

Commit 614afcf

Browse files
committed
Add softmax and log_softmax to math module
1 parent 333f7f3 commit 614afcf

File tree

3 files changed

+40
-0
lines changed

3 files changed

+40
-0
lines changed

RELEASE-NOTES.md

+1
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01
125125
- `pm.Data` now passes additional kwargs to `aesara.shared`/`at.as_tensor`. [#5098](https://github.com/pymc-devs/pymc/pull/5098).
126126
- Univariate censored distributions are now available via `pm.Censored`. [#5169](https://github.com/pymc-devs/pymc/pull/5169)
127127
- Nested models now inherit the parent model's coordinates. [#5344](https://github.com/pymc-devs/pymc/pull/5344)
128+
- `softmax` and `log_softmax` functions added to `math` module (see [#5279](https://github.com/pymc-devs/pymc/pull/5279)).
128129
- ...
129130

130131

pymc/math.py

+16
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,22 @@ def invlogit(x, eps=None):
211211
return at.sigmoid(x)
212212

213213

214+
def softmax(x, axis=None):
215+
# Ignore vector case UserWarning issued by Aesara. This can be removed once Aesara
216+
# drops that warning
217+
with warnings.catch_warnings():
218+
warnings.simplefilter("ignore", UserWarning)
219+
return at.nnet.softmax(x, axis=axis)
220+
221+
222+
def log_softmax(x, axis=None):
223+
# Ignore vector case UserWarning issued by Aesara. This can be removed once Aesara
224+
# drops that warning
225+
with warnings.catch_warnings():
226+
warnings.simplefilter("ignore", UserWarning)
227+
return at.nnet.logsoftmax(x, axis=axis)
228+
229+
214230
def logbern(log_p):
215231
if np.isnan(log_p):
216232
raise FloatingPointError("log_p can't be nan.")

pymc/tests/test_math.py

+23
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,12 @@
3030
kronecker,
3131
log1mexp,
3232
log1mexp_numpy,
33+
log_softmax,
3334
logdet,
3435
logdiffexp,
3536
logdiffexp_numpy,
3637
probit,
38+
softmax,
3739
)
3840
from pymc.tests.helpers import SeededTest, verify_grad
3941

@@ -265,3 +267,24 @@ def test_invlogit_deprecation_warning():
265267
assert not record
266268

267269
assert np.isclose(res, res_zero_eps)
270+
271+
272+
@pytest.mark.parametrize(
273+
"aesara_function, pymc_wrapper",
274+
[
275+
(at.nnet.softmax, softmax),
276+
(at.nnet.logsoftmax, log_softmax),
277+
],
278+
)
279+
def test_softmax_logsoftmax_no_warnings(aesara_function, pymc_wrapper):
280+
"""Test that wrappers for aesara functions do not issue Warnings"""
281+
282+
vector = at.vector("vector")
283+
with pytest.warns(None) as record:
284+
aesara_function(vector)
285+
warnings = {warning.category for warning in record.list}
286+
assert warnings == {UserWarning, FutureWarning}
287+
288+
with pytest.warns(None) as record:
289+
pymc_wrapper(vector)
290+
assert not record

0 commit comments

Comments
 (0)