Skip to content

Make do interventions shared variables by default #7596

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions pymc/model/transform/conditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from collections.abc import Mapping, Sequence
from typing import Any, Union

from pytensor.graph import ancestors
import pytensor

from pytensor.graph import Constant, ancestors
from pytensor.tensor import TensorVariable

from pymc.logprob.transforms import Transform
Expand Down Expand Up @@ -126,7 +128,9 @@ def observe(
def do(
model: Model,
vars_to_interventions: Mapping[Union["str", TensorVariable], Any],
prune_vars=False,
*,
make_interventions_shared: bool = True,
prune_vars: bool = False,
) -> Model:
"""Replace model variables by intervention variables.

Expand All @@ -140,6 +144,8 @@ def do(
Dictionary that maps model variables (or names) to intervention expressions.
Intervention expressions must have a shape and data type that is compatible
with the original model variable.
make_interventions_shared: bool, defaults to True,
Whether to make constant interventions shared variables.
prune_vars: bool, defaults to False
Whether to prune model variables that are not connected to any observed variables,
after the interventions.
Expand Down Expand Up @@ -170,11 +176,14 @@ def do(

"""
do_mapping = {}
for var, obs in vars_to_interventions.items():
for var, intervention in vars_to_interventions.items():
if isinstance(var, str):
var = model[var]
try:
do_mapping[var] = var.type.filter_variable(obs)
intervention = var.type.filter_variable(intervention)
if make_interventions_shared and isinstance(intervention, Constant):
intervention = pytensor.shared(intervention.data, name=var.name)
do_mapping[var] = intervention
except TypeError as err:
raise TypeError(
"Incompatible replacement type. Make sure the shape and datatype of the interventions match the original variables"
Expand Down
43 changes: 43 additions & 0 deletions tests/model/transform/test_conditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
import pytest

from pytensor import config
from pytensor.compile import SharedVariable
from pytensor.graph import Constant

import pymc as pm

from pymc import sample_posterior_predictive, set_data
from pymc.distributions.transforms import logodds
from pymc.model.transform.conditioning import (
change_value_transforms,
Expand Down Expand Up @@ -253,6 +256,46 @@ def test_do_self_reference():
np.testing.assert_allclose(draw_x + 100, draw_do_x)


def test_do_make_intervenstions_shared():
with pm.Model(coords={"obs": [1]}) as m:
x = pm.Normal("x", dims="obs")
y = pm.Normal("y", dims="obs")

constant_m = do(m, {x: [0.5]}, make_interventions_shared=False)
constant_x = constant_m["x"]
assert isinstance(constant_x, Constant)
np.testing.assert_array_equal(constant_x.data, [0.5])

shared_m = do(m, {x: [0.5]}, make_interventions_shared=True)
shared_x = shared_m["x"]
assert isinstance(shared_x, SharedVariable)
np.testing.assert_array_equal(shared_x.get_value(borrow=True), [0.5])

with shared_m:
set_data({"x": [0.6, 0.9]}, coords={"obs": [2, 3]})
pp_y = pm.sample_prior_predictive(draws=3).prior["y"]
assert pp_y.sizes == {"chain": 1, "draw": 3, "obs": 2}
assert pp_y.shape == (1, 3, 2)


@pytest.mark.parametrize(
"make_interventions_shared",
[True, pytest.param(False, marks=pytest.mark.xfail(reason="#6876"))],
)
def test_do_sample_posterior_predictive(make_interventions_shared):
# Regression test for https://github.com/pymc-devs/pymc/issues/6977
with pm.Model() as model:
a = pm.Normal("a")
b = pm.Deterministic("b", a * 2)
c = pm.Normal("c", b / 2)

idata = az.from_dict({"a": [[1.0]], "b": [[2.0]], "c": [[1.0]]})

with do(model, {a: 1000}, make_interventions_shared=make_interventions_shared):
pp = sample_posterior_predictive(idata, var_names=["c"], predictions=True).predictions
assert (pp["c"] > 500).all()


def test_change_value_transforms():
with pm.Model() as base_m:
p = pm.Uniform("p", 0, 1, default_transform=None)
Expand Down
Loading