File tree 3 files changed +40
-0
lines changed
3 files changed +40
-0
lines changed Original file line number Diff line number Diff line change @@ -125,6 +125,7 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01
125
125
- ` pm.Data ` now passes additional kwargs to ` aesara.shared ` /` at.as_tensor ` . [ #5098 ] ( https://github.com/pymc-devs/pymc/pull/5098 ) .
126
126
- Univariate censored distributions are now available via ` pm.Censored ` . [ #5169 ] ( https://github.com/pymc-devs/pymc/pull/5169 )
127
127
- Nested models now inherit the parent model's coordinates. [ #5344 ] ( https://github.com/pymc-devs/pymc/pull/5344 )
128
+ - ` softmax ` and ` log_softmax ` functions added to ` math ` module (see [ #5279 ] ( https://github.com/pymc-devs/pymc/pull/5279 ) ).
128
129
- ...
129
130
130
131
Original file line number Diff line number Diff line change @@ -211,6 +211,22 @@ def invlogit(x, eps=None):
211
211
return at .sigmoid (x )
212
212
213
213
214
+ def softmax (x , axis = None ):
215
+ # Ignore vector case UserWarning issued by Aesara. This can be removed once Aesara
216
+ # drops that warning
217
+ with warnings .catch_warnings ():
218
+ warnings .simplefilter ("ignore" , UserWarning )
219
+ return at .nnet .softmax (x , axis = axis )
220
+
221
+
222
+ def log_softmax (x , axis = None ):
223
+ # Ignore vector case UserWarning issued by Aesara. This can be removed once Aesara
224
+ # drops that warning
225
+ with warnings .catch_warnings ():
226
+ warnings .simplefilter ("ignore" , UserWarning )
227
+ return at .nnet .logsoftmax (x , axis = axis )
228
+
229
+
214
230
def logbern (log_p ):
215
231
if np .isnan (log_p ):
216
232
raise FloatingPointError ("log_p can't be nan." )
Original file line number Diff line number Diff line change 30
30
kronecker ,
31
31
log1mexp ,
32
32
log1mexp_numpy ,
33
+ log_softmax ,
33
34
logdet ,
34
35
logdiffexp ,
35
36
logdiffexp_numpy ,
36
37
probit ,
38
+ softmax ,
37
39
)
38
40
from pymc .tests .helpers import SeededTest , verify_grad
39
41
@@ -265,3 +267,24 @@ def test_invlogit_deprecation_warning():
265
267
assert not record
266
268
267
269
assert np .isclose (res , res_zero_eps )
270
+
271
+
272
+ @pytest .mark .parametrize (
273
+ "aesara_function, pymc_wrapper" ,
274
+ [
275
+ (at .nnet .softmax , softmax ),
276
+ (at .nnet .logsoftmax , log_softmax ),
277
+ ],
278
+ )
279
+ def test_softmax_logsoftmax_no_warnings (aesara_function , pymc_wrapper ):
280
+ """Test that wrappers for aesara functions do not issue Warnings"""
281
+
282
+ vector = at .vector ("vector" )
283
+ with pytest .warns (None ) as record :
284
+ aesara_function (vector )
285
+ warnings = {warning .category for warning in record .list }
286
+ assert warnings == {UserWarning , FutureWarning }
287
+
288
+ with pytest .warns (None ) as record :
289
+ pymc_wrapper (vector )
290
+ assert not record
You can’t perform that action at this time.
0 commit comments