Skip to content

Fix obs broadcast mismatch #4700

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pymc3/aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ def change_rv_size(
size = rv_node.op._infer_shape(size, dist_params)
new_size = tuple(np.atleast_1d(new_size)) + tuple(size)

# Make sure the new size is a tensor. This helps to not unnecessarily pick
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Make sure the new size is a tensor. This helps to not unnecessarily pick
# Make sure the new size is an int64 tensor. This helps to not unnecessarily pick

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also link to the issue here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's discoverable through the commits.
But in #4696 we'll make a few more robustness changes to `change_rv_size`. I can link the issue when rebasing the `reintro_shape` branch.

# up a `Cast` in some cases
new_size = at.as_tensor(new_size, ndim=1, dtype="int64")

new_rv_node = rv_node.op.make_node(rng, new_size, dtype, *dist_params)
rv_var = new_rv_node.outputs[-1]
rv_var.name = name
Expand Down
2 changes: 0 additions & 2 deletions pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,8 +1116,6 @@ def register_rv(self, rv_var, name, data=None, total_size=None, dims=None, trans
):
raise TypeError("Observed data cannot consist of symbolic variables.")

data = pandas_to_array(data)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't this cause problems if a DataFrame is passed?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One of the first things `make_obs_var` does is call `pandas_to_array` again; that's why this call can go away.


# `rv_var` is potentially a new variable (e.g. the original
# variable could have its size changed to match the data, or be a
# new graph that accounts for missing data)
Expand Down
11 changes: 10 additions & 1 deletion pymc3/tests/test_aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import pytest
import scipy.sparse as sps

from aesara.graph.basic import Variable, ancestors
from aesara.graph.basic import Constant, Variable, ancestors
from aesara.tensor.random.basic import normal, uniform
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
Expand Down Expand Up @@ -67,6 +67,15 @@ def test_change_rv_size():
assert rv_newer.ndim == 3
assert rv_newer.eval().shape == (4, 3, 2)

# Make sure we avoid introducing a `Cast` by converting the new size before
# constructing the new `RandomVariable`
rv = normal(0, 1)
new_size = np.array([4, 3], dtype="int32")
rv_newer = change_rv_size(rv, new_size=new_size, expand=False)
assert rv_newer.ndim == 2
assert isinstance(rv_newer.owner.inputs[1], Constant)
assert rv_newer.eval().shape == (4, 3)


class TestBroadcasting:
def test_make_shared_replacements(self):
Expand Down
11 changes: 4 additions & 7 deletions pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,14 +379,10 @@ def check_rv_size(self):
sizes_to_check = self.sizes_to_check or [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)]
sizes_expected = self.sizes_expected or [(), (), (1,), (1,), (5,), (4, 5), (2, 4, 2)]
for size, expected in zip(sizes_to_check, sizes_expected):
actual = change_rv_size(self.pymc_rv, size).eval().shape
pymc_rv = self.pymc_dist.dist(**self.pymc_dist_params, size=size)
actual = tuple(pymc_rv.shape.eval())
assert actual == expected, f"size={size}, expected={expected}, actual={actual}"

# test negative sizes raise
for size in [-2, (3, -2)]:
with pytest.raises(ValueError):
change_rv_size(self.pymc_rv, size).eval()

# test multi-parameters sampling for univariate distributions (with univariate inputs)
if self.pymc_dist.rv_op.ndim_supp == 0 and sum(self.pymc_dist.rv_op.ndims_params) == 0:
params = {
Expand All @@ -400,7 +396,8 @@ def check_rv_size(self):
(5, self.repeated_params_shape),
]
for size, expected in zip(sizes_to_check, sizes_expected):
actual = change_rv_size(self.pymc_rv, size).eval().shape
pymc_rv = self.pymc_dist.dist(**params, size=size)
actual = tuple(pymc_rv.shape.eval())
assert actual == expected

def validate_tests_list(self):
Expand Down