Skip to content

Commit 48a8b6b

Browse files
committed
Remove special logprob case for MaxNeg (used for Min logprob)
1 parent 4fb475b commit 48a8b6b

File tree

2 files changed

+57
-181
lines changed

2 files changed

+57
-181
lines changed

pymc/logprob/order.py

Lines changed: 43 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -33,28 +33,25 @@
3333
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
3434
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
3535
# SOFTWARE.
36-
37-
3836
from typing import cast
3937

4038
import pytensor.tensor as pt
4139

4240
from pytensor.graph.basic import Apply
4341
from pytensor.graph.fg import FunctionGraph
4442
from pytensor.graph.rewriting.basic import node_rewriter
45-
from pytensor.tensor.elemwise import Elemwise
4643
from pytensor.tensor.math import Max
47-
from pytensor.tensor.random.op import RandomVariable
4844
from pytensor.tensor.variable import TensorVariable
4945

5046
from pymc.logprob.abstract import (
47+
MeasurableElemwise,
48+
MeasurableOp,
5149
MeasurableOpMixin,
5250
_logcdf_helper,
5351
_logprob,
5452
_logprob_helper,
5553
)
5654
from pymc.logprob.rewriting import measurable_ir_rewrites_db
57-
from pymc.logprob.utils import find_negated_var
5855
from pymc.math import logdiffexp
5956
from pymc.pytensorf import constant_fold
6057

@@ -73,25 +70,41 @@ def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariab
7370
if rv_map_feature is None:
7471
return None # pragma: no cover
7572

76-
if isinstance(node.op, MeasurableMax):
77-
return None # pragma: no cover
73+
if isinstance(node.op, MeasurableMax | MeasurableMaxDiscrete):
74+
return None
7875

79-
base_var = cast(TensorVariable, node.inputs[0])
76+
[base_var] = node.inputs
8077

8178
if base_var.owner is None:
8279
return None
8380

8481
if not rv_map_feature.request_measurable(node.inputs):
8582
return None
8683

87-
# Non-univariate distributions and non-RVs must be rejected
88-
if not (isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.ndim_supp == 0):
84+
# We allow Max of RandomVariables or Elemwise of univariate RandomVariables
85+
if isinstance(base_var.owner.op, MeasurableElemwise):
86+
latent_base_vars = [
87+
var
88+
for var in base_var.owner.inputs
89+
if (var.owner and isinstance(var.owner.op, MeasurableOp))
90+
]
91+
if len(latent_base_vars) != 1:
92+
return None
93+
[latent_base_var] = latent_base_vars
94+
else:
95+
latent_base_var = base_var
96+
97+
latent_op = latent_base_var.owner.op
98+
if not (hasattr(latent_op, "dist_params") and getattr(latent_op, "ndim_supp") == 0):
8999
return None
90100

91101
# univariate i.i.d. test which also rules out other distributions
92-
for params in base_var.owner.op.dist_params(base_var.owner):
93-
if not all(params.type.broadcastable):
94-
return None
102+
if not all(
103+
all(params.type.broadcastable) for params in latent_op.dist_params(latent_base_var.owner)
104+
):
105+
return None
106+
107+
base_var = cast(TensorVariable, base_var)
95108

96109
if node.op.axis is None:
97110
axis = tuple(range(base_var.ndim))
@@ -102,16 +115,11 @@ def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariab
102115
return None
103116

104117
# distinguish measurable discrete and continuous (because logprob is different)
105-
measurable_max: Max
106-
if base_var.type.dtype.startswith("int"):
107-
measurable_max = MeasurableMaxDiscrete(axis)
108-
else:
109-
measurable_max = MeasurableMax(axis)
110-
111-
max_rv_node = measurable_max.make_node(base_var)
112-
max_rv = max_rv_node.outputs
113-
114-
return max_rv
118+
measurable_max_class = (
119+
MeasurableMaxDiscrete if latent_base_var.type.dtype.startswith("int") else MeasurableMax
120+
)
121+
max_rv = cast(TensorVariable, measurable_max_class(axis)(base_var))
122+
return [max_rv]
115123

116124

117125
measurable_ir_rewrites_db.register(
@@ -127,13 +135,13 @@ def max_logprob(op, values, base_rv, **kwargs):
127135
r"""Compute the log-likelihood graph for the `Max` operation."""
128136
(value,) = values
129137

130-
logprob = _logprob_helper(base_rv, value)
131-
logcdf = _logcdf_helper(base_rv, value)
138+
base_rv_shape = constant_fold(tuple(base_rv.shape), raise_not_constant=False)
139+
bcast_value = pt.broadcast_to(value, base_rv_shape)
140+
logprob = _logprob_helper(base_rv, bcast_value)[0]
141+
logcdf = _logcdf_helper(base_rv, bcast_value)[0]
132142

133-
[n] = constant_fold([base_rv.size])
134-
logprob = (n - 1) * logcdf + logprob + pt.math.log(n)
135-
136-
return logprob
143+
n = pt.prod(base_rv_shape)
144+
return (n - 1) * logcdf + logprob + pt.math.log(n)
137145

138146

139147
@_logprob.register(MeasurableMaxDiscrete)
@@ -146,126 +154,11 @@ def max_logprob_discrete(op, values, base_rv, **kwargs):
146154
where $P_{(n)}(x)$ represents the p.m.f of the maximum statistic and $F(x)$ represents the c.d.f of the i.i.d. variables.
147155
"""
148156
(value,) = values
149-
logcdf = _logcdf_helper(base_rv, value)
150-
logcdf_prev = _logcdf_helper(base_rv, value - 1)
151-
152-
[n] = constant_fold([base_rv.size])
153-
154-
logprob = logdiffexp(n * logcdf, n * logcdf_prev)
155-
156-
return logprob
157-
158-
159-
class MeasurableMaxNeg(MeasurableOpMixin, Max):
160-
"""A placeholder used to specify a log-likelihood for a max(neg(x)) sub-graph.
161-
This shows up in the graph of min, which is neg(max(neg(x)))."""
162-
163-
164-
class MeasurableDiscreteMaxNeg(MeasurableOpMixin, Max):
165-
"""A placeholder used to specify a log-likelihood for sub-graphs of negative maxima of discrete variables"""
166-
167-
168-
@node_rewriter(tracks=[Max])
169-
def find_measurable_max_neg(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
170-
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
171-
172-
if rv_map_feature is None:
173-
return None # pragma: no cover
174-
175-
if isinstance(node.op, MeasurableMaxNeg):
176-
return None # pragma: no cover
177-
178-
base_var = cast(TensorVariable, node.inputs[0])
179-
180-
# Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
181-
if not (base_var.owner is not None and isinstance(base_var.owner.op, Elemwise)):
182-
return None
183-
184-
base_rv = find_negated_var(base_var)
185-
186-
# negation is rv * (-1). Hence the scalar_op must be Mul
187-
if base_rv is None:
188-
return None
189-
190-
# Non-univariate distributions and non-RVs must be rejected
191-
if not (isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.ndim_supp == 0):
192-
return None
193-
194-
# univariate i.i.d. test which also rules out other distributions
195-
for params in base_rv.owner.op.dist_params(base_rv.owner):
196-
if not all(params.type.broadcastable):
197-
return None
198157

199-
if node.op.axis is None:
200-
axis = tuple(range(base_var.ndim))
201-
else:
202-
# Check whether axis is supported or not
203-
axis = tuple(sorted(node.op.axis))
204-
if axis != tuple(range(base_var.ndim)):
205-
return None
206-
207-
if not rv_map_feature.request_measurable([base_rv]):
208-
return None
209-
210-
# distinguish measurable discrete and continuous (because logprob is different)
211-
measurable_min: Max
212-
if base_rv.type.dtype.startswith("int"):
213-
measurable_min = MeasurableDiscreteMaxNeg(axis)
214-
else:
215-
measurable_min = MeasurableMaxNeg(axis)
216-
217-
return measurable_min.make_node(base_rv).outputs
218-
219-
220-
measurable_ir_rewrites_db.register(
221-
"find_measurable_max_neg",
222-
find_measurable_max_neg,
223-
"basic",
224-
"min",
225-
)
226-
227-
228-
@_logprob.register(MeasurableMaxNeg)
229-
def max_neg_logprob(op, values, base_rv, **kwargs):
230-
r"""Compute the log-likelihood graph for the `Max` operation.
231-
The formula that we use here is :
232-
\ln(f_{(n)}(x)) = \ln(n) + (n-1) \ln(1 - F(x)) + \ln(f(x))
233-
where f(x) represents the p.d.f and F(x) represents the c.d.f of the distribution respectively.
234-
"""
235-
(value,) = values
236-
237-
logprob = _logprob_helper(base_rv, -value)
238-
logcdf = _logcdf_helper(base_rv, -value)
239-
240-
[n] = constant_fold([base_rv.size])
241-
logprob = (n - 1) * pt.math.log(1 - pt.math.exp(logcdf)) + logprob + pt.math.log(n)
242-
243-
return logprob
244-
245-
246-
@_logprob.register(MeasurableDiscreteMaxNeg)
247-
def discrete_max_neg_logprob(op, values, base_rv, **kwargs):
248-
r"""Compute the log-likelihood graph for the `Max` operation.
249-
250-
The formula that we use here is :
251-
.. math::
252-
\ln(P_{(n)}(x)) = \ln((1 - F(x - 1))^n - (1 - F(x))^n)
253-
where $P_{(n)}(x)$ represents the p.m.f of the maximum statistic and $F(x)$ represents the c.d.f of the i.i.d. variables.
254-
"""
255-
256-
(value,) = values
257-
258-
# The cdf of a negative variable is the survival at the negated value
259-
logcdf = pt.log1mexp(_logcdf_helper(base_rv, -value))
260-
logcdf_prev = pt.log1mexp(_logcdf_helper(base_rv, -(value + 1)))
261-
262-
[n] = constant_fold([base_rv.size])
263-
264-
# Now we can use the same expression as the discrete max
265-
logprob = pt.where(
266-
pt.and_(pt.eq(logcdf, -pt.inf), pt.eq(logcdf_prev, -pt.inf)),
267-
-pt.inf,
268-
logdiffexp(n * logcdf_prev, n * logcdf),
269-
)
158+
base_rv_shape = constant_fold(tuple(base_rv.shape), raise_not_constant=False)
159+
bcast_value = pt.broadcast_to(value, base_rv_shape)
160+
logcdf = _logcdf_helper(base_rv, bcast_value)[0]
161+
logcdf_prev = _logcdf_helper(base_rv, bcast_value - 1)[0]
270162

271-
return logprob
163+
n = pt.prod(base_rv_shape)
164+
return logdiffexp(n * logcdf, n * logcdf_prev)

tests/logprob/test_order.py

Lines changed: 14 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,11 @@ def test_argmax():
5353
"""Test whether the logprob for ```pt.argmax``` is correctly rejected"""
5454
x = pt.random.normal(0, 1, size=(3,))
5555
x.name = "x"
56-
x_max = pt.argmax(x, axis=-1)
57-
x_max_value = pt.vector("x_max_value")
56+
x_argmax = pt.argmax(x, axis=-1)
57+
x_max_value = pt.scalar("x_max_value", dtype=x_argmax.type.dtype)
5858

5959
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented for Argmax")):
60-
x_max_logprob = logp(x_max, x_max_value)
60+
logp(x_argmax, x_max_value)
6161

6262

6363
@pytest.mark.parametrize(
@@ -72,26 +72,9 @@ def test_non_iid_fails(pt_op):
7272
x = pm.Normal.dist([0, 1, 2, 3, 4], 1, shape=(5,))
7373
x.name = "x"
7474
x_m = pt_op(x, axis=-1)
75-
x_m_value = pt.vector("x_value")
75+
x_m_value = pt.scalar("x_value")
7676
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
77-
x_max_logprob = logp(x_m, x_m_value)
78-
79-
80-
@pytest.mark.parametrize(
81-
"pt_op",
82-
[
83-
pt.max,
84-
pt.min,
85-
],
86-
)
87-
def test_non_rv_fails(pt_op):
88-
"""Test whether the logprob for ```pt.max``` for non-RVs is correctly rejected"""
89-
x = pt.exp(pt.random.beta(0, 1, size=(3,)))
90-
x.name = "x"
91-
x_m = pt_op(x, axis=-1)
92-
x_m_value = pt.vector("x_value")
93-
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
94-
x_max_logprob = logp(x_m, x_m_value)
77+
logp(x_m, x_m_value)
9578

9679

9780
@pytest.mark.parametrize(
@@ -107,9 +90,9 @@ def test_multivariate_rv_fails(pt_op):
10790
x = pm.StickBreakingWeights.dist(_alpha, _k)
10891
x.name = "x"
10992
x_m = pt_op(x, axis=-1)
110-
x_m_value = pt.vector("x_value")
93+
x_m_value = pt.scalar("x_value")
11194
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
112-
x_max_logprob = logp(x_m, x_m_value)
95+
logp(x_m, x_m_value)
11396

11497

11598
@pytest.mark.parametrize(
@@ -124,9 +107,9 @@ def test_categorical(pt_op):
124107
x = pm.Categorical.dist([1, 1, 1, 1], shape=(5,))
125108
x.name = "x"
126109
x_m = pt_op(x, axis=-1)
127-
x_m_value = pt.vector("x_value")
110+
x_m_value = pt.scalar("x_value", dtype=x.type.dtype)
128111
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
129-
x_max_logprob = logp(x_m, x_m_value)
112+
logp(x_m, x_m_value)
130113

131114

132115
@pytest.mark.parametrize(
@@ -230,19 +213,19 @@ def test_min_non_mul_elemwise_fails():
230213
x = pt.log(pt.random.beta(0, 1, size=(3,)))
231214
x.name = "x"
232215
x_min = pt.min(x, axis=-1)
233-
x_min_value = pt.vector("x_min_value")
216+
x_min_value = pt.scalar("x_min_value")
234217
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
235-
x_min_logprob = logp(x_min, x_min_value)
218+
logp(x_min, x_min_value)
236219

237220

238221
@pytest.mark.parametrize(
239222
"mu, size, value, axis",
240223
[(2, 3, 1, -1), (2, 3, 1, 0), (1, 2, 2, None), (0, 4, 0, 0)],
241224
)
242225
def test_max_discrete(mu, size, value, axis):
243-
x = pm.Poisson.dist(name="x", mu=mu, size=(size))
226+
x = pm.Poisson.dist(name="x", mu=mu, size=size)
244227
x_max = pt.max(x, axis=axis)
245-
x_max_value = pt.scalar("x_max_value")
228+
x_max_value = pt.scalar("x_max_value", dtype=x.type.dtype)
246229
x_max_logprob = logp(x_max, x_max_value)
247230

248231
test_value = value
@@ -265,7 +248,7 @@ def test_max_discrete(mu, size, value, axis):
265248
def test_min_discrete(mu, n, test_value, axis):
266249
x = pm.Poisson.dist(name="x", mu=mu, size=(n,))
267250
x_min = pt.min(x, axis=axis)
268-
x_min_value = pt.scalar("x_min_value")
251+
x_min_value = pt.scalar("x_min_value", dtype=x.type.dtype)
269252
x_min_logprob = logp(x_min, x_min_value)
270253

271254
sf_before = 1 - sp.poisson(mu).cdf(test_value - 1)

0 commit comments

Comments
 (0)