Remove automatic normalization in Multinomial and Categorical #5331 #5370

Merged 1 commit on Jan 26, 2022

1 change: 1 addition & 0 deletions RELEASE-NOTES.md
@@ -145,6 +145,7 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01
- `math.log1mexp` and `math.log1mexp_numpy` will expect negative inputs in the future. A `FutureWarning` is now raised unless `negative_input=True` is set (see [#4860](https://github.com/pymc-devs/pymc/pull/4860)).
- Changed name of `Lognormal` distribution to `LogNormal` to harmonize CamelCase usage for distribution names.
- Attempt to iterate over MultiTrace will raise NotImplementedError.
- Removed silent normalization of `p` parameters in `Categorical` and `Multinomial` distributions (see [#5370](https://github.com/pymc-devs/pymc/pull/5370)).
- ...


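A minimal sketch of the behavior this entry describes, assuming a PyMC build that includes this PR (`pm` for `pymc`, `at` for `aesara.tensor`):

```python
import aesara.tensor as at
import numpy as np
import pymc as pm

# Concrete p that does not sum to 1: UserWarning, then automatic rescaling
x = pm.Categorical.dist(p=[1, 1, 1, 1, 1])  # warns; p is rescaled to sum to 1

# Concrete negative p: immediate ValueError
try:
    pm.Multinomial.dist(n=5, p=[-1, 1, 1])
except ValueError as err:
    print(err)  # Negative `p` parameters are not valid, got: [-1, 1, 1]

# Symbolic p is not checked eagerly; the error surfaces when logp is evaluated
bad = pm.Multinomial.dist(n=1, p=at.as_tensor_variable([-1.0, 0.5, 0.5]))
# pm.logp(bad, np.array([[1, 1, 1]])).eval()  # raises ParameterValueError
```
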
14 changes: 12 additions & 2 deletions pymc/distributions/discrete.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings

import aesara.tensor as at
import numpy as np

@@ -1233,7 +1235,16 @@ class Categorical(Discrete):

@classmethod
def dist(cls, p, **kwargs):

if isinstance(p, np.ndarray) or isinstance(p, list):
if (np.asarray(p) < 0).any():
raise ValueError(f"Negative `p` parameters are not valid, got: {p}")
p_sum = np.sum([p], axis=-1)
if not np.all(np.isclose(p_sum, 1.0)):
warnings.warn(
f"`p` parameters sum to {p_sum}, instead of 1.0. They will be automatically rescaled. You can rescale them directly to get rid of this warning.",
UserWarning,
)
p = p / at.sum(p, axis=-1, keepdims=True)
p = at.as_tensor_variable(floatX(p))
return super().dist([p], **kwargs)

@@ -1256,7 +1267,6 @@ def logp(value, p):
"""
k = at.shape(p)[-1]
p_ = p
p = p_ / at.sum(p_, axis=-1, keepdims=True)
value_clip = at.clip(value, 0, k - 1)

if p.ndim > 1:
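
The eager checks above only run for concrete (list or ndarray) `p`. A NumPy-only sketch of that logic, using a hypothetical `check_p` helper to show why `p` is wrapped in a list before summing (it keeps the sum an array for both vector and batched `p`):

```python
import numpy as np

def check_p(p):
    # Hypothetical helper mirroring the eager checks in Categorical.dist
    if (np.asarray(p) < 0).any():
        raise ValueError(f"Negative `p` parameters are not valid, got: {p}")
    p_sum = np.sum([p], axis=-1)  # [p] keeps batched p intact; sums stay an array
    return np.all(np.isclose(p_sum, 1.0))

print(check_p([0.2, 0.3, 0.5]))           # True: already on the simplex
print(check_p([[1.0, 1.0], [2.0, 2.0]]))  # False: each row would be rescaled
```
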
14 changes: 11 additions & 3 deletions pymc/distributions/multivariate.py
@@ -537,10 +537,18 @@ class Multinomial(Discrete):

@classmethod
def dist(cls, n, p, *args, **kwargs):
p = p / at.sum(p, axis=-1, keepdims=True)
if isinstance(p, np.ndarray) or isinstance(p, list):
if (np.asarray(p) < 0).any():
raise ValueError(f"Negative `p` parameters are not valid, got: {p}")
p_sum = np.sum([p], axis=-1)
if not np.all(np.isclose(p_sum, 1.0)):
warnings.warn(
f"`p` parameters sum up to {p_sum}, instead of 1.0. They will be automatically rescaled. You can rescale them directly to get rid of this warning.",
UserWarning,
)
p = p / at.sum(p, axis=-1, keepdims=True)
n = at.as_tensor_variable(n)
p = at.as_tensor_variable(p)

return super().dist([n, p], *args, **kwargs)

def get_moment(rv, size, n, p):
@@ -582,7 +590,7 @@ def logp(value, n, p):
return check_parameters(
res,
p <= 1,
at.eq(at.sum(p, axis=-1), 1),
at.isclose(at.sum(p, axis=-1), 1),
at.ge(n, 0),
msg="p <= 1, sum(p) = 1, n >= 0",
)
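
The `at.eq` to `at.isclose` change in the logp check matters because valid probabilities need not sum to exactly 1.0 in floating point. A sketch of the difference, assuming Aesara's `at.isclose` follows NumPy's default tolerances:

```python
import aesara.tensor as at
import numpy as np

# Valid probabilities whose left-to-right floating-point sum
# falls just short of 1.0 (0.9999999999999999)
p = at.as_tensor_variable(np.array([0.3, 0.3, 0.3, 0.1]))

print(at.eq(at.sum(p, axis=-1), 1).eval())       # False: exact equality fails
print(at.isclose(at.sum(p, axis=-1), 1).eval())  # True: tolerance-based check passes
```
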
72 changes: 57 additions & 15 deletions pymc/tests/test_distributions.py
@@ -2189,19 +2189,43 @@ def test_multinomial(self, n):
lambda value, n, p: scipy.stats.multinomial.logpmf(value, n, p),
)

def test_multinomial_invalid(self):
# Test non-scalar invalid parameters/values
value = np.array([[1, 2, 2], [4, 0, 1]])
def test_multinomial_invalid_value(self):
# Test that passing non-scalar invalid values to an otherwise valid Multinomial
# evaluates to -inf
value = np.array([[1, 2, 2], [3, -1, 0]])
valid_dist = Multinomial.dist(n=5, p=np.ones(3) / 3)
assert np.all(np.isfinite(pm.logp(valid_dist, value).eval()) == np.array([True, False]))

invalid_dist = Multinomial.dist(n=5, p=[-1, 1, 1], size=2)
# TODO: Multinomial normalizes p, so it is impossible to trigger p checks
# with pytest.raises(ParameterValueError):
with does_not_raise():
def test_multinomial_negative_p(self):
# Test that passing a list/numpy array with negative p raises an immediate error
with pytest.raises(ValueError, match="[-1, 1, 1]"):
with Model() as model:
x = Multinomial("x", n=5, p=[-1, 1, 1])

def test_multinomial_p_not_normalized(self):
# Test that a UserWarning is raised for p values that sum to more than 1
# and normalization is triggered
with pytest.warns(UserWarning, match="[5]"):
with pm.Model() as m:
x = pm.Multinomial("x", n=5, p=[1, 1, 1, 1, 1])
# Test that the stored p values have been normalized
assert np.isclose(m.x.owner.inputs[4].sum().eval(), 1.0)

def test_multinomial_negative_p_symbolic(self):
# Passing symbolic negative p does not raise an immediate error, but evaluating
# logp raises a ParameterValueError
with pytest.raises(ParameterValueError):
value = np.array([[1, 1, 1]])
invalid_dist = pm.Multinomial.dist(n=1, p=at.as_tensor_variable([-1, 0.5, 0.5]))
pm.logp(invalid_dist, value).eval()

value[1] -= 1
valid_dist = Multinomial.dist(n=5, p=np.ones(3) / 3)
assert np.all(np.isfinite(pm.logp(valid_dist, value).eval()) == np.array([True, False]))
def test_multinomial_p_not_normalized_symbolic(self):
# Passing symbolic p that does not add up to one does not raise any warning, but evaluating
# logp raises a ParameterValueError
with pytest.raises(ParameterValueError):
value = np.array([[1, 1, 1]])
invalid_dist = pm.Multinomial.dist(n=1, p=at.as_tensor_variable([1, 0.5, 0.5]))
pm.logp(invalid_dist, value).eval()

@pytest.mark.parametrize("n", [(10), ([10, 11]), ([[5, 6], [10, 11]])])
@pytest.mark.parametrize(
@@ -2317,12 +2341,22 @@ def test_categorical_bounds(self):
np.array([-1, -1, 0, 0]),
],
)
def test_categorical_valid_p(self, p):
with Model():
x = Categorical("x", p=p)
def test_categorical_negative_p(self, p):
with pytest.raises(ValueError, match=f"{p}"):
with Model():
x = Categorical("x", p=p)

with pytest.raises(ParameterValueError):
logp(x, 2).eval()
def test_categorical_negative_p_symbolic(self):
with pytest.raises(ParameterValueError):
value = np.array([[1, 1, 1]])
invalid_dist = pm.Categorical.dist(p=at.as_tensor_variable([-1, 0.5, 0.5]))
pm.logp(invalid_dist, value).eval()

def test_categorical_p_not_normalized_symbolic(self):
with pytest.raises(ParameterValueError):
value = np.array([[1, 1, 1]])
invalid_dist = pm.Categorical.dist(p=at.as_tensor_variable([2, 2, 2]))
pm.logp(invalid_dist, value).eval()

@pytest.mark.parametrize("n", [2, 3, 4])
def test_categorical(self, n):
@@ -2333,6 +2367,14 @@ def test_categorical(self, n):
lambda value, p: categorical_logpdf(value, p),
)

def test_categorical_p_not_normalized(self):
# Test that a UserWarning is raised for p values that sum to more than 1
# and normalization is triggered
with pytest.warns(UserWarning, match="[5]"):
with pm.Model() as m:
x = pm.Categorical("x", p=[1, 1, 1, 1, 1])
assert np.isclose(m.x.owner.inputs[3].sum().eval(), 1.0)

@pytest.mark.parametrize("n", [2, 3, 4])
def test_orderedlogistic(self, n):
self.check_logp(
1 change: 1 addition & 0 deletions pymc/tests/test_idata_conversion.py
@@ -531,6 +531,7 @@ def test_multivariate_observations(self):
data = np.random.multinomial(20, [0.2, 0.3, 0.5], size=20)
with pm.Model(coords=coords):
p = pm.Beta("p", 1, 1, size=(3,))
p = p / p.sum()
pm.Multinomial("y", 20, p, dims=("experiment", "direction"), observed=data)
idata = pm.sample(draws=50, chains=2, tune=100, return_inferencedata=True)
test_dict = {
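
With silent rescaling gone, symbolic `p` must already lie on the simplex, which is why this test now normalizes the independent Beta draws by hand. A Dirichlet prior would be the more idiomatic alternative; a sketch, not part of the diff:

```python
import pymc as pm

with pm.Model() as model:
    # Dirichlet draws live on the simplex, so no manual rescaling is needed
    p = pm.Dirichlet("p", a=[1.0, 1.0, 1.0])
    y = pm.Multinomial("y", n=20, p=p)
```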