diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py index 3c4bbb8cf..b29cbd45e 100644 --- a/numpyro/infer/elbo.py +++ b/numpyro/infer/elbo.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from collections import OrderedDict, defaultdict -from functools import partial +from functools import partial, reduce from operator import itemgetter import warnings @@ -959,6 +959,19 @@ def single_particle_elbo(rng_key): *(frozenset(f.inputs) & group_plates for f in group_factors) ) elim_plates = group_plates - outermost_plates + plate_to_scale = {} + for name in group_names: + for plate, value in ( + model_trace[name].get("plate_to_scale", {}).items() + ): + if plate in plate_to_scale: + if value != plate_to_scale[plate]: + raise ValueError( + "Expected all enumerated sample sites to share a common scale factor, " + f"but found different scales at plate('{plate}')." + ) + else: + plate_to_scale[plate] = value with funsor.interpretations.normalize: cost = funsor.sum_product.sum_product( funsor.ops.logaddexp, @@ -966,26 +979,20 @@ def single_particle_elbo(rng_key): group_factors, plates=group_plates, eliminate=group_sum_vars | elim_plates, + plate_to_scale=plate_to_scale, ) # TODO: add memoization cost = funsor.optimizer.apply_optimizer(cost) # incorporate the effects of subsampling and handlers.scale through a common scale factor - scales_set = set() - for name in group_names | group_sum_vars: - site_scale = model_trace[name]["scale"] - if site_scale is None: - site_scale = 1.0 - if isinstance(site_scale, jnp.ndarray): - raise ValueError( - "Enumeration only supports scalar handlers.scale" - ) - scales_set.add(float(site_scale)) - if len(scales_set) != 1: - raise ValueError( - "Expected all enumerated sample sites to share a common scale, " - f"but found {len(scales_set)} different scales." - ) - scale = next(iter(scales_set)) + scale = reduce( + funsor.ops.mul, + [ + value + for plate, value in plate_to_scale.items() + if plate not in elim_plates + ], + 1.0, + ) # combine deps deps = frozenset().union( *[model_deps[name] for name in group_names] diff --git a/setup.py b/setup.py index bba87d335..334905d17 100644 --- a/setup.py +++ b/setup.py @@ -61,7 +61,7 @@ "dev": [ "dm-haiku", "flax", - "funsor>=0.4.1", + "funsor>=0.4.6", "graphviz", "jaxns>=2.0.1", "matplotlib", diff --git a/test/contrib/test_enum_elbo.py b/test/contrib/test_enum_elbo.py index d464fb33a..2d166cce7 100644 --- a/test/contrib/test_enum_elbo.py +++ b/test/contrib/test_enum_elbo.py @@ -2314,14 +2314,10 @@ def actual_loss_fn(params_raw): } return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, params) - with pytest.raises( - ValueError, match="Expected all enumerated sample sites to share a common scale" - ): - # This never gets run because we don't support this yet. - actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw) + actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw) - assert_equal(actual_loss, expected_loss, prec=1e-5) - assert_equal(actual_grads, expected_grads, prec=1e-5) + assert_equal(actual_loss, expected_loss, prec=1e-5) + assert_equal(actual_grads, expected_grads, prec=1e-5) @pytest.mark.parametrize("scale", [1, 10]) @@ -2389,20 +2385,16 @@ def actual_loss_fn(params_raw): } return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, params) - with pytest.raises( - ValueError, match="Expected all enumerated sample sites to share a common scale" - ): - # This never gets run because we don't support this yet. - actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw) + actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw) - assert_equal(actual_loss, expected_loss, prec=1e-5) - assert_equal(actual_grads, expected_grads, prec=1e-5) + assert_equal(actual_loss, expected_loss, prec=1e-5) + assert_equal(actual_grads, expected_grads, prec=1e-5) @pytest.mark.parametrize("scale", [1, 10]) def test_model_enum_subsample_3(scale): # Enumerate: a - # Subsample: a, b, c + # Subsample: b, c # [ a - [----> b ] # [ \ [ ] # [ - [- [-> c ] ] @@ -2464,14 +2456,10 @@ def actual_loss_fn(params_raw): } return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, params) - with pytest.raises( - ValueError, match="Expected all enumerated sample sites to share a common scale" - ): - # This never gets run because we don't support this yet. - actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw) + actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw) - assert_equal(actual_loss, expected_loss, prec=1e-3) - assert_equal(actual_grads, expected_grads, prec=1e-5) + assert_equal(actual_loss, expected_loss, prec=1e-3) + assert_equal(actual_grads, expected_grads, prec=1e-5) def test_guide_plate_contraction():