Skip to content

Commit c26db8b

Browse files
Luke LB and RicardoV94
Luke LB
committed
Remove automatic normalization in Categorical and Multinomial distributions
Co-authored-by: Ricardo Vieira <[email protected]>
1 parent eb5177a commit c26db8b

File tree

5 files changed

+82
-20
lines changed

5 files changed

+82
-20
lines changed

Diff for: RELEASE-NOTES.md

+1
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01
145145
- `math.log1mexp` and `math.log1mexp_numpy` will expect negative inputs in the future. A `FutureWarning` is now raised unless `negative_input=True` is set (see [#4860](https://github.com/pymc-devs/pymc/pull/4860)).
146146
- Changed name of `Lognormal` distribution to `LogNormal` to harmonize CamelCase usage for distribution names.
147147
- Attempt to iterate over MultiTrace will raise NotImplementedError.
148+
- Removed silent normalisation of `p` parameters in Categorical and Multinomial distributions (see [#5370](https://github.com/pymc-devs/pymc/pull/5370)).
148149
- ...
149150

150151

Diff for: pymc/distributions/discrete.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import warnings
15+
1416
import aesara.tensor as at
1517
import numpy as np
1618

@@ -1233,7 +1235,16 @@ class Categorical(Discrete):
12331235

12341236
@classmethod
12351237
def dist(cls, p, **kwargs):
1236-
1238+
if isinstance(p, np.ndarray) or isinstance(p, list):
1239+
if (np.asarray(p) < 0).any():
1240+
raise ValueError(f"Negative `p` parameters are not valid, got: {p}")
1241+
p_sum = np.sum([p], axis=-1)
1242+
if not np.all(np.isclose(p_sum, 1.0)):
1243+
warnings.warn(
1244+
f"`p` parameters sum to {p_sum}, instead of 1.0. They will be automatically rescaled. You can rescale them directly to get rid of this warning.",
1245+
UserWarning,
1246+
)
1247+
p = p / at.sum(p, axis=-1, keepdims=True)
12371248
p = at.as_tensor_variable(floatX(p))
12381249
return super().dist([p], **kwargs)
12391250

@@ -1256,7 +1267,6 @@ def logp(value, p):
12561267
"""
12571268
k = at.shape(p)[-1]
12581269
p_ = p
1259-
p = p_ / at.sum(p_, axis=-1, keepdims=True)
12601270
value_clip = at.clip(value, 0, k - 1)
12611271

12621272
if p.ndim > 1:

Diff for: pymc/distributions/multivariate.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -537,10 +537,18 @@ class Multinomial(Discrete):
537537

538538
@classmethod
539539
def dist(cls, n, p, *args, **kwargs):
540-
p = p / at.sum(p, axis=-1, keepdims=True)
540+
if isinstance(p, np.ndarray) or isinstance(p, list):
541+
if (np.asarray(p) < 0).any():
542+
raise ValueError(f"Negative `p` parameters are not valid, got: {p}")
543+
p_sum = np.sum([p], axis=-1)
544+
if not np.all(np.isclose(p_sum, 1.0)):
545+
warnings.warn(
546+
f"`p` parameters sum up to {p_sum}, instead of 1.0. They will be automatically rescaled. You can rescale them directly to get rid of this warning.",
547+
UserWarning,
548+
)
549+
p = p / at.sum(p, axis=-1, keepdims=True)
541550
n = at.as_tensor_variable(n)
542551
p = at.as_tensor_variable(p)
543-
544552
return super().dist([n, p], *args, **kwargs)
545553

546554
def get_moment(rv, size, n, p):
@@ -582,7 +590,7 @@ def logp(value, n, p):
582590
return check_parameters(
583591
res,
584592
p <= 1,
585-
at.eq(at.sum(p, axis=-1), 1),
593+
at.isclose(at.sum(p, axis=-1), 1),
586594
at.ge(n, 0),
587595
msg="p <= 1, sum(p) = 1, n >= 0",
588596
)

Diff for: pymc/tests/test_distributions.py

+57-15
Original file line numberDiff line numberDiff line change
@@ -2189,19 +2189,43 @@ def test_multinomial(self, n):
21892189
lambda value, n, p: scipy.stats.multinomial.logpmf(value, n, p),
21902190
)
21912191

2192-
def test_multinomial_invalid(self):
2193-
# Test non-scalar invalid parameters/values
2194-
value = np.array([[1, 2, 2], [4, 0, 1]])
2192+
def test_multinomial_invalid_value(self):
2193+
# Test passing non-scalar invalid parameters/values to an otherwise valid Multinomial,
2194+
# evaluates to -inf
2195+
value = np.array([[1, 2, 2], [3, -1, 0]])
2196+
valid_dist = Multinomial.dist(n=5, p=np.ones(3) / 3)
2197+
assert np.all(np.isfinite(pm.logp(valid_dist, value).eval()) == np.array([True, False]))
21952198

2196-
invalid_dist = Multinomial.dist(n=5, p=[-1, 1, 1], size=2)
2197-
# TODO: Multinomial normalizes p, so it is impossible to trigger p checks
2198-
# with pytest.raises(ParameterValueError):
2199-
with does_not_raise():
2199+
def test_multinomial_negative_p(self):
2200+
# test passing a list/numpy with negative p raises an immediate error
2201+
with pytest.raises(ValueError, match="[-1, 1, 1]"):
2202+
with Model() as model:
2203+
x = Multinomial("x", n=5, p=[-1, 1, 1])
2204+
2205+
def test_multinomial_p_not_normalized(self):
2206+
# test UserWarning is raised for p vals that sum to more than 1
2207+
# and normalisation is triggered
2208+
with pytest.warns(UserWarning, match="[5]"):
2209+
with pm.Model() as m:
2210+
x = pm.Multinomial("x", n=5, p=[1, 1, 1, 1, 1])
2211+
# test stored p-vals have been normalised
2212+
assert np.isclose(m.x.owner.inputs[4].sum().eval(), 1.0)
2213+
2214+
def test_multinomial_negative_p_symbolic(self):
2215+
# Passing symbolic negative p does not raise an immediate error, but evaluating
2216+
# logp raises a ParameterValueError
2217+
with pytest.raises(ParameterValueError):
2218+
value = np.array([[1, 1, 1]])
2219+
invalid_dist = pm.Multinomial.dist(n=1, p=at.as_tensor_variable([-1, 0.5, 0.5]))
22002220
pm.logp(invalid_dist, value).eval()
22012221

2202-
value[1] -= 1
2203-
valid_dist = Multinomial.dist(n=5, p=np.ones(3) / 3)
2204-
assert np.all(np.isfinite(pm.logp(valid_dist, value).eval()) == np.array([True, False]))
2222+
def test_multinomial_p_not_normalized_symbolic(self):
2223+
# Passing symbolic p that does not add up to one does not raise any warning, but evaluating
2224+
# logp raises a ParameterValueError
2225+
with pytest.raises(ParameterValueError):
2226+
value = np.array([[1, 1, 1]])
2227+
invalid_dist = pm.Multinomial.dist(n=1, p=at.as_tensor_variable([1, 0.5, 0.5]))
2228+
pm.logp(invalid_dist, value).eval()
22052229

22062230
@pytest.mark.parametrize("n", [(10), ([10, 11]), ([[5, 6], [10, 11]])])
22072231
@pytest.mark.parametrize(
@@ -2317,12 +2341,22 @@ def test_categorical_bounds(self):
23172341
np.array([-1, -1, 0, 0]),
23182342
],
23192343
)
2320-
def test_categorical_valid_p(self, p):
2321-
with Model():
2322-
x = Categorical("x", p=p)
2344+
def test_categorical_negative_p(self, p):
2345+
with pytest.raises(ValueError, match=f"{p}"):
2346+
with Model():
2347+
x = Categorical("x", p=p)
23232348

2324-
with pytest.raises(ParameterValueError):
2325-
logp(x, 2).eval()
2349+
def test_categorical_negative_p_symbolic(self):
2350+
with pytest.raises(ParameterValueError):
2351+
value = np.array([[1, 1, 1]])
2352+
invalid_dist = pm.Categorical.dist(p=at.as_tensor_variable([-1, 0.5, 0.5]))
2353+
pm.logp(invalid_dist, value).eval()
2354+
2355+
def test_categorical_p_not_normalized_symbolic(self):
2356+
with pytest.raises(ParameterValueError):
2357+
value = np.array([[1, 1, 1]])
2358+
invalid_dist = pm.Categorical.dist(p=at.as_tensor_variable([2, 2, 2]))
2359+
pm.logp(invalid_dist, value).eval()
23262360

23272361
@pytest.mark.parametrize("n", [2, 3, 4])
23282362
def test_categorical(self, n):
@@ -2333,6 +2367,14 @@ def test_categorical(self, n):
23332367
lambda value, p: categorical_logpdf(value, p),
23342368
)
23352369

2370+
def test_categorical_p_not_normalized(self):
2371+
# test UserWarning is raised for p vals that sum to more than 1
2372+
# and normalisation is triggered
2373+
with pytest.warns(UserWarning, match="[5]"):
2374+
with pm.Model() as m:
2375+
x = pm.Categorical("x", p=[1, 1, 1, 1, 1])
2376+
assert np.isclose(m.x.owner.inputs[3].sum().eval(), 1.0)
2377+
23362378
@pytest.mark.parametrize("n", [2, 3, 4])
23372379
def test_orderedlogistic(self, n):
23382380
self.check_logp(

Diff for: pymc/tests/test_idata_conversion.py

+1
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,7 @@ def test_multivariate_observations(self):
531531
data = np.random.multinomial(20, [0.2, 0.3, 0.5], size=20)
532532
with pm.Model(coords=coords):
533533
p = pm.Beta("p", 1, 1, size=(3,))
534+
p = p / p.sum()
534535
pm.Multinomial("y", 20, p, dims=("experiment", "direction"), observed=data)
535536
idata = pm.sample(draws=50, chains=2, tune=100, return_inferencedata=True)
536537
test_dict = {

0 commit comments

Comments
 (0)