Skip to content

Commit 94790c7

Browse files
committed
Handle latest PyMC/PyTensor breaking changes
1 parent dfe3fe0 commit 94790c7

13 files changed

+62
-51
lines changed

Diff for: conda-envs/environment-test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@ dependencies:
1010
- xhistogram
1111
- statsmodels
1212
- pip:
13-
- pymc>=5.13.0 # CI was failing to resolve
13+
- pymc>=5.16.1 # CI was failing to resolve
1414
- blackjax
1515
- scikit-learn

Diff for: conda-envs/windows-environment-test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@ dependencies:
1010
- xhistogram
1111
- statsmodels
1212
- pip:
13-
- pymc>=5.13.0 # CI was failing to resolve
13+
- pymc>=5.16.1 # CI was failing to resolve
1414
- blackjax
1515
- scikit-learn

Diff for: pymc_experimental/distributions/continuous.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
The imports from pymc are not fully replicated here: add imports as necessary.
2020
"""
2121

22-
from typing import List, Tuple, Union
22+
from typing import Tuple, Union
2323

2424
import numpy as np
2525
import pytensor.tensor as pt
@@ -37,8 +37,7 @@
3737

3838
class GenExtremeRV(RandomVariable):
3939
name: str = "Generalized Extreme Value"
40-
ndim_supp: int = 0
41-
ndims_params: List[int] = [0, 0, 0]
40+
signature = "(),(),()->()"
4241
dtype: str = "floatX"
4342
_print_name: Tuple[str, str] = ("Generalized Extreme Value", "\\operatorname{GEV}")
4443

@@ -275,7 +274,7 @@ def chi_dist(nu: TensorVariable, size: TensorVariable) -> TensorVariable:
275274

276275
def __new__(cls, name, nu, **kwargs):
277276
if "observed" not in kwargs:
278-
kwargs.setdefault("transform", transforms.log)
277+
kwargs.setdefault("default_transform", transforms.log)
279278
return CustomDist(name, nu, dist=cls.chi_dist, class_name="Chi", **kwargs)
280279

281280
@classmethod

Diff for: pymc_experimental/distributions/discrete.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ def log1mexp(x):
3131

3232
class GeneralizedPoissonRV(RandomVariable):
3333
name = "generalized_poisson"
34-
ndim_supp = 0
35-
ndims_params = [0, 0]
34+
signature = "(),()->()"
3635
dtype = "int64"
3736
_print_name = ("GeneralizedPoisson", "\\operatorname{GeneralizedPoisson}")
3837

Diff for: pymc_experimental/model/marginal_model.py

+26-10
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from pytensor.graph.replace import graph_replace, vectorize_graph
2222
from pytensor.scan import map as scan_map
2323
from pytensor.tensor import TensorType, TensorVariable
24-
from pytensor.tensor.elemwise import Elemwise
24+
from pytensor.tensor.elemwise import DimShuffle, Elemwise
2525
from pytensor.tensor.shape import Shape
2626
from pytensor.tensor.special import log_softmax
2727

@@ -598,7 +598,18 @@ def is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, output_rvs):
598598
fg = FunctionGraph(outputs=output_rvs, clone=False)
599599

600600
non_elemwise_blockers = [
601-
o for node in fg.apply_nodes if not isinstance(node.op, Elemwise) for o in node.outputs
601+
o
602+
for node in fg.apply_nodes
603+
if not (
604+
isinstance(node.op, Elemwise)
605+
# Allow expand_dims on the left
606+
or (
607+
isinstance(node.op, DimShuffle)
608+
and not node.op.drop
609+
and node.op.shuffle == sorted(node.op.shuffle)
610+
)
611+
)
612+
for o in node.outputs
602613
]
603614
blocker_candidates = [rv_to_marginalize] + other_input_rvs + non_elemwise_blockers
604615
blockers = [var for var in blocker_candidates if var not in output_rvs]
@@ -698,16 +709,17 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
698709

699710
def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
700711
op = rv.owner.op
712+
dist_params = rv.owner.op.dist_params(rv.owner)
701713
if isinstance(op, Bernoulli):
702714
return (0, 1)
703715
elif isinstance(op, Categorical):
704-
p_param = rv.owner.inputs[3]
716+
[p_param] = dist_params
705717
return tuple(range(pt.get_vector_length(p_param)))
706718
elif isinstance(op, DiscreteUniform):
707-
lower, upper = constant_fold(rv.owner.inputs[3:])
719+
lower, upper = constant_fold(dist_params)
708720
return tuple(np.arange(lower, upper + 1))
709721
elif isinstance(op, DiscreteMarkovChain):
710-
P = rv.owner.inputs[0]
722+
P, *_ = dist_params
711723
return tuple(range(pt.get_vector_length(P[-1])))
712724

713725
raise NotImplementedError(f"Cannot compute domain for op {op}")
@@ -827,11 +839,15 @@ def marginal_hmm_logp(op, values, *inputs, **kwargs):
827839
# This is the "forward algorithm", alpha_t = p(y | s_t) * sum_{s_{t-1}}(p(s_t | s_{t-1}) * alpha_{t-1})
828840
# We do it entirely in logs, though.
829841

830-
# To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states) under
831-
# the initial distribution. This is robust to everything the user can throw at it.
832-
batch_logp_init_dist = pt.vectorize(lambda x: logp(init_dist_, x), "()->()")(
833-
batch_chain_value[..., 0]
834-
)
842+
# To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states)
843+
# under the initial distribution. This is robust to everything the user can throw at it.
844+
init_dist_value = init_dist_.type()
845+
logp_init_dist = logp(init_dist_, init_dist_value)
846+
# There is a degenerate batch dim for lags=1 (the only supported case),
847+
# that we have to work around, by expanding the batch value and then squeezing it out of the logp
848+
batch_logp_init_dist = vectorize_graph(
849+
logp_init_dist, {init_dist_value: batch_chain_value[:, None, ..., 0]}
850+
).squeeze(1)
835851
log_alpha_init = batch_logp_init_dist + batch_logp_emissions[..., 0]
836852

837853
def step_alpha(logp_emission, log_alpha, log_P):

Diff for: pymc_experimental/model/transforms/autoreparam.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import pytensor
88
import pytensor.tensor as pt
99
import scipy.special
10+
from pymc.distributions import SymbolicRandomVariable
11+
from pymc.exceptions import NotConstantValueError
1012
from pymc.logprob.transforms import Transform
1113
from pymc.model.fgraph import (
1214
ModelDeterministic,
@@ -17,7 +19,7 @@
1719
model_from_fgraph,
1820
model_named,
1921
)
20-
from pymc.pytensorf import toposort_replace
22+
from pymc.pytensorf import constant_fold, toposort_replace
2123
from pytensor.graph.basic import Apply, Variable
2224
from pytensor.tensor.random.op import RandomVariable
2325

@@ -170,14 +172,16 @@ def vip_reparam_node(
170172
dims: List[Variable],
171173
transform: Optional[Transform],
172174
) -> Tuple[ModelDeterministic, ModelNamed]:
173-
if not isinstance(node.op, RandomVariable):
175+
if not isinstance(node.op, RandomVariable | SymbolicRandomVariable):
174176
raise TypeError("Op should be RandomVariable type")
175-
size = node.inputs[1]
176-
if not isinstance(size, pt.TensorConstant):
177+
rv = node.default_output()
178+
try:
179+
[rv_shape] = constant_fold([rv.shape])
180+
except NotConstantValueError:
177181
raise ValueError("Size should be static for autoreparametrization.")
178182
logit_lam_ = pytensor.shared(
179-
np.zeros(size.data),
180-
shape=size.data,
183+
np.zeros(rv_shape),
184+
shape=rv_shape,
181185
name=f"{name}::lam_logit__",
182186
)
183187
logit_lam = model_named(logit_lam_, *dims)
@@ -216,7 +220,7 @@ def _(
216220
transform: Optional[Transform],
217221
lam: pt.TensorVariable,
218222
) -> ModelDeterministic:
219-
rng, size, _, loc, scale = node.inputs
223+
rng, size, loc, scale = node.inputs
220224
if transform is not None:
221225
raise NotImplementedError("Reparametrization of Normal with Transform is not implemented")
222226
vip_rv_ = pm.Normal.dist(

Diff for: pymc_experimental/tests/model/test_marginal_model.py

+9-16
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
import pytensor.tensor as pt
88
import pytest
99
from arviz import InferenceData, dict_to_dataset
10-
from pymc import ImputationWarning, inputvars
1110
from pymc.distributions import transforms
1211
from pymc.logprob.abstract import _logprob
1312
from pymc.model.fgraph import fgraph_from_model
13+
from pymc.pytensorf import inputvars
1414
from pymc.util import UNSET
1515
from scipy.special import log_softmax, logsumexp
1616
from scipy.stats import halfnorm, norm
@@ -45,9 +45,7 @@ def disaster_model():
4545
early_rate = pm.Exponential("early_rate", 1.0, initval=3)
4646
late_rate = pm.Exponential("late_rate", 1.0, initval=1)
4747
rate = pm.math.switch(switchpoint >= years, early_rate, late_rate)
48-
with pytest.warns(ImputationWarning), pytest.warns(
49-
RuntimeWarning, match="invalid value encountered in cast"
50-
):
48+
with pytest.warns(Warning):
5149
disasters = pm.Poisson("disasters", rate, observed=disaster_data)
5250

5351
return disaster_model, years
@@ -294,7 +292,7 @@ def test_recover_marginals_basic():
294292

295293
with m:
296294
prior = pm.sample_prior_predictive(
297-
samples=20,
295+
draws=20,
298296
random_seed=rng,
299297
return_inferencedata=False,
300298
)
@@ -337,7 +335,7 @@ def test_recover_marginals_coords():
337335

338336
with m:
339337
prior = pm.sample_prior_predictive(
340-
samples=20,
338+
draws=20,
341339
random_seed=rng,
342340
return_inferencedata=False,
343341
)
@@ -364,7 +362,7 @@ def test_recover_batched_marginal():
364362

365363
with m:
366364
prior = pm.sample_prior_predictive(
367-
samples=20,
365+
draws=20,
368366
random_seed=rng,
369367
return_inferencedata=False,
370368
)
@@ -394,7 +392,7 @@ def test_nested_recover_marginals():
394392

395393
with m:
396394
prior = pm.sample_prior_predictive(
397-
samples=20,
395+
draws=20,
398396
random_seed=rng,
399397
return_inferencedata=False,
400398
)
@@ -565,7 +563,7 @@ def test_marginalized_transforms(transform, expected_warning):
565563
w=w,
566564
comp_dists=pm.HalfNormal.dist([1, 2, 3]),
567565
initval=initval,
568-
transform=transform,
566+
default_transform=transform,
569567
)
570568
y = pm.Normal("y", 0, sigma, observed=data)
571569

@@ -583,7 +581,7 @@ def test_marginalized_transforms(transform, expected_warning):
583581
),
584582
),
585583
initval=initval,
586-
transform=transform,
584+
default_transform=transform,
587585
)
588586
y = pm.Normal("y", 0, sigma, observed=data)
589587

@@ -710,12 +708,7 @@ def test_marginalized_hmm_normal_emission(batch_chain, batch_emission):
710708

711709
@pytest.mark.parametrize(
712710
"categorical_emission",
713-
[
714-
False,
715-
# Categorical has a core vector parameter,
716-
# so it is not possible to build a graph that uses elemwise operations exclusively
717-
pytest.param(True, marks=pytest.mark.xfail(raises=NotImplementedError)),
718-
],
711+
[False, True],
719712
)
720713
def test_marginalized_hmm_categorical_emission(categorical_emission):
721714
"""Example adapted from https://www.youtube.com/watch?v=9-sPm4CfcD0"""

Diff for: pymc_experimental/tests/statespace/test_SARIMAX.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def test_interpretable_raises_if_d_nonzero():
331331

332332
def test_interpretable_states_are_interpretable(arima_mod_interp, pymc_mod_interp):
333333
with pymc_mod_interp:
334-
prior = pm.sample_prior_predictive(samples=10)
334+
prior = pm.sample_prior_predictive(draws=10)
335335

336336
prior_outputs = arima_mod_interp.sample_unconditional_prior(prior)
337337
ar_lags = prior.prior.coords["ar_lag"].values - 1

Diff for: pymc_experimental/tests/statespace/test_VARMAX.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def pymc_mod(varma_mod, data):
7171
@pytest.fixture(scope="session")
7272
def idata(pymc_mod, rng):
7373
with pymc_mod:
74-
idata = pm.sample_prior_predictive(samples=10, random_seed=rng)
74+
idata = pm.sample_prior_predictive(draws=10, random_seed=rng)
7575

7676
return idata
7777

Diff for: pymc_experimental/tests/statespace/test_distributions.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def test_lgss_distribution_from_steps(output_name, ss_mod_me, pymc_model_2):
126126
latent_states, obs_states = LinearGaussianStateSpace("states", *matrices, steps=100)
127127
# pylint: enable=unpacking-non-sequence
128128

129-
idata = pm.sample_prior_predictive(samples=10)
129+
idata = pm.sample_prior_predictive(draws=10)
130130
delete_rvs_from_model(["states_latent", "states_observed", "states_combined"])
131131

132132
assert idata.prior.coords["states_latent_dim_0"].shape == (101,)
@@ -144,7 +144,7 @@ def test_lgss_distribution_with_dims(output_name, ss_mod_me, pymc_model_2):
144144
"states", *matrices, steps=100, dims=[TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM]
145145
)
146146
# pylint: enable=unpacking-non-sequence
147-
idata = pm.sample_prior_predictive(samples=10)
147+
idata = pm.sample_prior_predictive(draws=10)
148148
delete_rvs_from_model(["states_latent", "states_observed", "states_combined"])
149149

150150
assert idata.prior.coords["time"].shape == (101,)
@@ -198,7 +198,7 @@ def test_lgss_with_time_varying_inputs(output_name, rng):
198198
dims=[TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM]
199199
)
200200
# pylint: enable=unpacking-non-sequence
201-
idata = pm.sample_prior_predictive(samples=10)
201+
idata = pm.sample_prior_predictive(draws=10)
202202

203203
assert idata.prior.coords["time"].shape == (10,)
204204
assert all(

Diff for: pymc_experimental/tests/statespace/test_statespace.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def exog_pymc_mod(exog_ss_mod, rng):
135135
def idata(pymc_mod, rng):
136136
with pymc_mod:
137137
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
138-
idata_prior = pm.sample_prior_predictive(samples=10, random_seed=rng)
138+
idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
139139

140140
idata.extend(idata_prior)
141141
return idata
@@ -145,7 +145,7 @@ def idata(pymc_mod, rng):
145145
def idata_exog(exog_pymc_mod, rng):
146146
with exog_pymc_mod:
147147
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
148-
idata_prior = pm.sample_prior_predictive(samples=10, random_seed=rng)
148+
idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
149149
idata.extend(idata_prior)
150150
return idata
151151

Diff for: pymc_experimental/tests/statespace/test_structural.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -756,7 +756,7 @@ def test_filter_scans_time_varying_design_matrix(rng):
756756
x0, P0, c, d, T, Z, R, H, Q = mod.unpack_statespace()
757757
pm.Deterministic("Z", Z)
758758

759-
prior = pm.sample_prior_predictive(samples=10)
759+
prior = pm.sample_prior_predictive(draws=10)
760760

761761
prior_Z = prior.prior.Z.values
762762
assert prior_Z.shape == (1, 10, 100, 1, 2)
@@ -790,7 +790,7 @@ def test_extract_components_from_idata(rng):
790790
mod.build_statespace_graph(y)
791791

792792
x0, P0, c, d, T, Z, R, H, Q = mod.unpack_statespace()
793-
prior = pm.sample_prior_predictive(samples=10)
793+
prior = pm.sample_prior_predictive(draws=10)
794794

795795
filter_prior = mod.sample_conditional_prior(prior)
796796
comp_prior = mod.extract_components_from_idata(filter_prior)

Diff for: requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
pymc>=5.13.0
1+
pymc>=5.16.1
22
scikit-learn

0 commit comments

Comments
 (0)