Skip to content

Commit 4fb475b

Browse files
committed
Allow logcdf inference of discrete variables
1 parent 772825e commit 4fb475b

File tree

2 files changed

+26
-22
lines changed

2 files changed

+26
-22
lines changed

pymc/logprob/transforms.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -232,20 +232,20 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg
232232
"""Compute the log-CDF graph for a `MeasurabeTransform`."""
233233
other_inputs = list(inputs)
234234
measurable_input = other_inputs.pop(op.measurable_input_idx)
235-
236-
# Do not apply rewrite to discrete variables
237-
if measurable_input.type.dtype.startswith("int"):
238-
raise NotImplementedError("logcdf of transformed discrete variables not implemented")
239-
240235
backward_value = op.transform_elemwise.backward(value, *other_inputs)
241236

242237
# Fail if transformation is not injective
243238
# A TensorVariable is returned in 1-to-1 inversions, and a tuple in 1-to-many
244239
if isinstance(backward_value, tuple):
245240
raise NotImplementedError
246241

242+
is_discrete = measurable_input.type.dtype.startswith("int")
243+
247244
logcdf = _logcdf_helper(measurable_input, backward_value)
248-
logccdf = pt.log1mexp(logcdf)
245+
if is_discrete:
246+
logccdf = pt.log1mexp(_logcdf_helper(measurable_input, backward_value - 1))
247+
else:
248+
logccdf = pt.log1mexp(logcdf)
249249

250250
if isinstance(op.scalar_op, MONOTONICALLY_INCREASING_OPS):
251251
pass
@@ -271,7 +271,6 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg
271271

272272
# The jacobian is used to ensure a value in the supported domain was provided
273273
jacobian = op.transform_elemwise.log_jac_det(value, *other_inputs)
274-
275274
return pt.switch(pt.isnan(jacobian), -np.inf, logcdf)
276275

277276

tests/logprob/test_transforms.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -693,37 +693,42 @@ def test_not_implemented_discrete_rv_transform():
693693

694694
def test_negated_discrete_rv_transform():
695695
p = 0.7
696-
rv = -Bernoulli.dist(p=p)
696+
rv = -Bernoulli.dist(p=p, shape=(4,))
697697
vv = rv.type()
698-
logp_fn = pytensor.function([vv], logp(rv, vv))
699698

700699
# A negated Bernoulli has pmf {p if x == -1; 1-p if x == 0; 0 otherwise}
701-
assert logp_fn(-2) == -np.inf
702-
np.testing.assert_allclose(logp_fn(-1), np.log(p))
703-
np.testing.assert_allclose(logp_fn(0), np.log(1 - p))
704-
assert logp_fn(1) == -np.inf
700+
logp_fn = pytensor.function([vv], logp(rv, vv))
701+
np.testing.assert_allclose(
702+
logp_fn([-2, -1, 0, 1]), [-np.inf, np.log(p), np.log(1 - p), -np.inf]
703+
)
705704

706-
# Logcdf and icdf not supported yet
707-
for func in (logcdf, icdf):
708-
with pytest.raises(NotImplementedError):
709-
func(rv, 0)
705+
logcdf_fn = pytensor.function([vv], logcdf(rv, vv))
706+
np.testing.assert_allclose(logcdf_fn([-2, -1, 0, 1]), [-np.inf, np.log(p), 0, 0])
707+
708+
with pytest.raises(NotImplementedError):
709+
icdf(rv, [-2, -1, 0, 1])
710710

711711

712712
def test_shifted_discrete_rv_transform():
713713
p = 0.7
714714
rv = Bernoulli.dist(p=p) + 5
715715
vv = rv.type()
716-
rv_logp_fn = pytensor.function([vv], logp(rv, vv))
717716

717+
rv_logp_fn = pytensor.function([vv], logp(rv, vv))
718718
assert rv_logp_fn(4) == -np.inf
719719
np.testing.assert_allclose(rv_logp_fn(5), np.log(1 - p))
720720
np.testing.assert_allclose(rv_logp_fn(6), np.log(p))
721721
assert rv_logp_fn(7) == -np.inf
722722

723-
# Logcdf and icdf not supported yet
724-
for func in (logcdf, icdf):
725-
with pytest.raises(NotImplementedError):
726-
func(rv, 0)
723+
rv_logcdf_fn = pytensor.function([vv], logcdf(rv, vv))
724+
assert rv_logcdf_fn(4) == -np.inf
725+
np.testing.assert_allclose(rv_logcdf_fn(5), np.log(1 - p))
726+
np.testing.assert_allclose(rv_logcdf_fn(6), 0)
727+
assert rv_logcdf_fn(7) == 0
728+
729+
# icdf not supported yet
730+
with pytest.raises(NotImplementedError):
731+
icdf(rv, 0)
727732

728733

729734
@pytest.mark.xfail(reason="Check not implemented yet")

0 commit comments

Comments
 (0)