Skip to content

Commit 6e0d317

Browse files
committed
Allow Minibatch logp on derived RVs
1 parent 623ca42 commit 6e0d317

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

pymc/variational/minibatch_rv.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
from pytensor.graph import Apply, Op
2121
from pytensor.tensor import NoneConst, TensorVariable, as_tensor_variable
2222

23-
from pymc.logprob.abstract import MeasurableOp, _logprob, _logprob_helper
23+
from pymc.logprob.abstract import MeasurableOp, _logprob
24+
from pymc.logprob.basic import logp
2425

2526

2627
class MinibatchRandomVariable(MeasurableOp, Op):
@@ -99,4 +100,4 @@ def get_scaling(total_size: Sequence[Variable], shape: TensorVariable) -> Tensor
99100
def minibatch_rv_logprob(op, values, *inputs, **kwargs):
100101
[value] = values
101102
rv, *total_size = inputs
102-
return _logprob_helper(rv, value, **kwargs) * get_scaling(total_size, value.shape)
103+
return logp(rv, value, **kwargs) * get_scaling(total_size, value.shape)

tests/variational/test_minibatch_rv.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import numpy as np
1515
import pytensor
16+
import pytensor.tensor as pt
1617
import pytest
1718

1819
from scipy import stats as st
@@ -186,3 +187,12 @@ def test_minibatch_parameter_and_value(self):
186187
with m:
187188
pm.set_data({"AD": rng.normal(size=1000)})
188189
assert logp_fn(ip) != logp_fn(ip)
190+
191+
def test_derived_rv(self):
192+
"""Test we can obtain a minibatch logp out of a derived RV."""
193+
dist = pt.clip(pm.Normal.dist(0, 1, size=(1,)), -1, 1)
194+
mb_dist = create_minibatch_rv(dist, total_size=(2,))
195+
np.testing.assert_allclose(
196+
pm.logp(mb_dist, -1).eval(),
197+
pm.logp(dist, -1).eval() * 2,
198+
)

0 commit comments

Comments
 (0)