Skip to content

Commit 618634b

Browse files
committed
Do not include RVs in graph of symbolic_normalizing_constant
1 parent a87d95e commit 618634b

File tree

3 files changed

+26
-3
lines changed

3 files changed

+26
-3
lines changed

pymc/pytensorf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -979,7 +979,7 @@ def constant_fold(
979979
"""
980980
fg = FunctionGraph(outputs=xs, features=[ShapeFeature()], copy_inputs=False, clone=True)
981981

982-
# The default rewrite_graph includes a constand_folding that is not always applied.
982+
# The default rewrite_graph includes a constant_folding that is not always applied.
983983
# We use an unconditional constant_folding as the last pass to ensure a thorough constant folding.
984984
rewrite_graph(fg)
985985
topo_unconditional_constant_folding.apply(fg)

pymc/variational/opvi.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
from pymc.pytensorf import (
7575
SeedSequenceSeed,
7676
compile,
77+
constant_fold,
7778
find_rng_nodes,
7879
reseed_rngs,
7980
)
@@ -1105,7 +1106,10 @@ def symbolic_normalizing_constant(self):
11051106
t = self.to_flat_input(
11061107
pt.max(
11071108
[
1108-
get_scaling(v.owner.inputs[1:], v.shape)
1109+
get_scaling(
1110+
v.owner.inputs[1:],
1111+
constant_fold([v.owner.inputs[0].shape], raise_not_constant=False),
1112+
)
11091113
for v in self.group
11101114
if isinstance(v.owner.op, MinibatchRandomVariable)
11111115
]
@@ -1272,7 +1276,10 @@ def symbolic_normalizing_constant(self):
12721276
t = pt.max(
12731277
self.collect("symbolic_normalizing_constant")
12741278
+ [
1275-
get_scaling(obs.owner.inputs[1:], obs.shape)
1279+
get_scaling(
1280+
obs.owner.inputs[1:],
1281+
constant_fold([obs.owner.inputs[0].shape], raise_not_constant=False),
1282+
)
12761283
for obs in self.model.observed_RVs
12771284
if isinstance(obs.owner.op, MinibatchRandomVariable)
12781285
]

tests/variational/test_opvi.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import pymc as pm
2222

23+
from pymc.testing import assert_no_rvs
2324
from pymc.variational import opvi
2425
from pymc.variational.approximations import (
2526
Empirical,
@@ -278,3 +279,18 @@ def test_logq_globals(three_var_approx):
278279
es = symbolic_logq.eval()
279280
assert e.shape == ()
280281
assert es.shape == (2,)
282+
283+
284+
def test_symbolic_normalizing_constant_no_rvs():
285+
# Test that RVs aren't included in the graph of symbolic_normalizing_constant
286+
rng = np.random.default_rng()
287+
288+
with pm.Model() as m:
289+
obs = pm.Data("obs", rng.normal(size=(1000,)))
290+
obs_batch = pm.Minibatch(obs, batch_size=128)
291+
x = pm.Normal("x") # Need at least one Free_RV in the graph
292+
y_hat = pm.Flat("y_hat", observed=obs_batch, total_size=1000)
293+
294+
step = pm.ADVI()
295+
296+
assert_no_rvs(step.approx.symbolic_normalizing_constant)

0 commit comments

Comments
 (0)