Commit ce447cc
Make forward sampling functions return InferenceData by default (#5073)

The `return_inferencedata=True` option is new & default now for:

* `sample_posterior_predictive`
* `sample_posterior_predictive_w`
* `sample_prior_predictive`

Co-authored-by: Osvaldo Martin <[email protected]>

1 parent 5012274 · commit ce447cc

11 files changed: +236 −150 lines

Diff for: RELEASE-NOTES.md (+2)

```diff
@@ -8,6 +8,8 @@
 - The GLM submodule has been removed, please use [Bambi](https://bambinos.github.io/bambi/) instead.
 - The `Distribution` keyword argument `testval` has been deprecated in favor of `initval`. Furthermore `initval` no longer assigns a `tag.test_value` on tensors since the initial values are now kept track of by the model object ([see #4913](https://github.com/pymc-devs/pymc/pull/4913)).
 - `pm.sample` now returns results as `InferenceData` instead of `MultiTrace` by default (see [#4744](https://github.com/pymc-devs/pymc/pull/4744)).
+- `pm.sample_prior_predictive`, `pm.sample_posterior_predictive` and `pm.sample_posterior_predictive_w` now return an `InferenceData` object
+  by default, instead of a dictionary (see [#5073](https://github.com/pymc-devs/pymc/pull/5073)).
 - `pm.sample_prior_predictive` no longer returns transformed variable values by default. Pass them by name in `var_names` if you want to obtain these draws (see [4769](https://github.com/pymc-devs/pymc/pull/4769)).
 - `pm.Bound` interface no longer accepts a callable class as argument, instead it requires an instantiated distribution (created via the `.dist()` API) to be passed as an argument. In addition, Bound no longer returns a class instance but works as a normal PyMC distribution. Finally, it is no longer possible to do predictive random sampling from Bounded variables. Please, consult the new documentation for details on how to use Bounded variables (see [4815](https://github.com/pymc-devs/pymc/pull/4815)).
 - `pm.DensityDist` no longer accepts the `logp` as its first position argument. It is now an optional keyword argument. If you pass a callable as the first positional argument, a `TypeError` will be raised (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
```
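For orientation, a minimal sketch of what the new default means for user code; the model and draw counts below are illustrative, not taken from this commit:

```python
import numpy as np
import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu, 1.0, observed=np.random.randn(100))

    # New default: an arviz.InferenceData with `prior` (and `prior_predictive`)
    # groups, shaped (chain, draw, *event_shape).
    idata = pm.sample_prior_predictive(samples=500)
    assert idata.prior["mu"].shape == (1, 500)

    # The pre-#5073 behavior (a plain dict of numpy arrays) is one flag away.
    prior = pm.sample_prior_predictive(samples=500, return_inferencedata=False)
    assert prior["mu"].shape == (500,)
```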

Diff for: pymc/sampling.py (+57 −21)

```diff
@@ -272,11 +272,11 @@ def sample(
     callback=None,
     jitter_max_retries=10,
     *,
-    return_inferencedata=None,
+    return_inferencedata=True,
     idata_kwargs: dict = None,
     mp_ctx=None,
     **kwargs,
-):
+) -> Union[InferenceData, MultiTrace]:
     r"""Draw samples from the posterior using the given step methods.
 
     Multiple step methods are supported via compound step methods.
@@ -341,9 +341,9 @@ def sample(
         Maximum number of repeated attempts (per chain) at creating an initial matrix with uniform jitter
         that yields a finite probability. This applies to ``jitter+adapt_diag`` and ``jitter+adapt_full``
         init methods.
-    return_inferencedata : bool, default=True
+    return_inferencedata : bool
         Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a `MultiTrace` (False)
-        Defaults to `False`, but we'll switch to `True` in an upcoming release.
+        Defaults to `True`.
     idata_kwargs : dict, optional
         Keyword arguments for :func:`pymc.to_inference_data`
     mp_ctx : multiprocessing.context.BaseContent
@@ -455,9 +455,6 @@ def sample(
     if not isinstance(random_seed, abc.Iterable):
         raise TypeError("Invalid value for `random_seed`. Must be tuple, list or int")
 
-    if return_inferencedata is None:
-        return_inferencedata = True
-
     if not discard_tuned_samples and not return_inferencedata:
         warnings.warn(
             "Tuning samples will be included in the returned `MultiTrace` object, which can lead to"
@@ -1539,7 +1536,9 @@ def sample_posterior_predictive(
     random_seed=None,
     progressbar: bool = True,
     mode: Optional[Union[str, Mode]] = None,
-) -> Dict[str, np.ndarray]:
+    return_inferencedata=True,
+    idata_kwargs: dict = None,
+) -> Union[InferenceData, Dict[str, np.ndarray]]:
     """Generate posterior predictive samples from a model given a trace.
 
     Parameters
@@ -1574,12 +1573,17 @@ def sample_posterior_predictive(
         time until completion ("expected time of arrival"; ETA).
     mode:
         The mode used by ``aesara.function`` to compile the graph.
+    return_inferencedata : bool
+        Whether to return an :class:`arviz:arviz.InferenceData` (True) object or a dictionary (False).
+        Defaults to True.
+    idata_kwargs : dict, optional
+        Keyword arguments for :func:`pymc.to_inference_data`
 
     Returns
     -------
-    samples : dict
-        Dictionary with the variable names as keys, and values numpy arrays containing
-        posterior predictive samples.
+    arviz.InferenceData or Dict
+        An ArviZ ``InferenceData`` object containing the posterior predictive samples (default), or
+        a dictionary with variable names as keys, and samples as numpy arrays.
     """
 
     _trace: Union[MultiTrace, PointList]
@@ -1728,7 +1732,12 @@ def sample_posterior_predictive(
     for k, ary in ppc_trace.items():
         ppc_trace[k] = ary.reshape((nchain, len_trace, *ary.shape[1:]))
 
-    return ppc_trace
+    if not return_inferencedata:
+        return ppc_trace
+    ikwargs = dict(model=model)
+    if idata_kwargs:
+        ikwargs.update(idata_kwargs)
+    return pm.to_inference_data(posterior_predictive=ppc_trace, **ikwargs)
 
 
 def sample_posterior_predictive_w(
@@ -1738,6 +1747,8 @@ def sample_posterior_predictive_w(
     weights: Optional[ArrayLike] = None,
     random_seed: Optional[int] = None,
     progressbar: bool = True,
+    return_inferencedata=True,
+    idata_kwargs: dict = None,
 ):
     """Generate weighted posterior predictive samples from a list of models and
     a list of traces according to a set of weights.
@@ -1764,12 +1775,18 @@ def sample_posterior_predictive_w(
         Whether or not to display a progress bar in the command line. The bar shows the percentage
         of completion, the sampling speed in samples per second (SPS), and the estimated remaining
         time until completion ("expected time of arrival"; ETA).
+    return_inferencedata : bool
+        Whether to return an :class:`arviz:arviz.InferenceData` (True) object or a dictionary (False).
+        Defaults to True.
+    idata_kwargs : dict, optional
+        Keyword arguments for :func:`pymc.to_inference_data`
 
     Returns
     -------
-    samples : dict
-        Dictionary with the variables as keys. The values corresponding to the
-        posterior predictive samples from the weighted models.
+    arviz.InferenceData or Dict
+        An ArviZ ``InferenceData`` object containing the posterior predictive samples from the
+        weighted models (default), or a dictionary with variable names as keys, and samples as
+        numpy arrays.
     """
     if isinstance(traces[0], InferenceData):
         n_samples = [
@@ -1888,7 +1905,13 @@ def sample_posterior_predictive_w(
     except KeyboardInterrupt:
         pass
     else:
-        return {k: np.asarray(v) for k, v in ppc.items()}
+        ppc = {k: np.asarray(v) for k, v in ppc.items()}
+        if not return_inferencedata:
+            return ppc
+        ikwargs = dict(model=models)
+        if idata_kwargs:
+            ikwargs.update(idata_kwargs)
+        return pm.to_inference_data(posterior_predictive=ppc, **ikwargs)
 
 
 def sample_prior_predictive(
@@ -1897,7 +1920,9 @@ def sample_prior_predictive(
     var_names: Optional[Iterable[str]] = None,
     random_seed=None,
     mode: Optional[Union[str, Mode]] = None,
-) -> Dict[str, np.ndarray]:
+    return_inferencedata=True,
+    idata_kwargs: dict = None,
+) -> Union[InferenceData, Dict[str, np.ndarray]]:
     """Generate samples from the prior predictive distribution.
 
     Parameters
@@ -1913,12 +1938,17 @@ def sample_prior_predictive(
         Seed for the random number generator.
     mode:
        The mode used by ``aesara.function`` to compile the graph.
+    return_inferencedata : bool
+        Whether to return an :class:`arviz:arviz.InferenceData` (True) object or a dictionary (False).
+        Defaults to True.
+    idata_kwargs : dict, optional
+        Keyword arguments for :func:`pymc.to_inference_data`
 
     Returns
     -------
-    dict
-        Dictionary with variable names as keys. The values are numpy arrays of prior
-        samples.
+    arviz.InferenceData or Dict
+        An ArviZ ``InferenceData`` object containing the prior and prior predictive samples (default),
+        or a dictionary with variable names as keys and samples as numpy arrays.
     """
     model = modelcontext(model)
 
@@ -1984,7 +2014,13 @@ def sample_prior_predictive(
     for var_name in vars_:
         if var_name in data:
             prior[var_name] = data[var_name]
-    return prior
+
+    if not return_inferencedata:
+        return prior
+    ikwargs = dict(model=model)
+    if idata_kwargs:
+        ikwargs.update(idata_kwargs)
+    return pm.to_inference_data(prior=prior, **ikwargs)
 
 
 def _init_jitter(
```
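All three forward samplers now share the same tail: return the raw dictionary when `return_inferencedata=False`, otherwise merge any user-supplied `idata_kwargs` over the `model=...` default and delegate to `pm.to_inference_data`. A hedged usage sketch of that surface (model, data, and the `dims`/`coords` keys below are illustrative assumptions, not from this commit):

```python
import numpy as np
import pymc as pm

y = np.random.normal(0.5, 1.0, size=50)

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("y", mu, 1.0, observed=y)
    idata = pm.sample()  # InferenceData by default since #4744

    # New default: posterior predictive draws arrive as a fresh InferenceData.
    ppc = pm.sample_posterior_predictive(idata)
    assert "posterior_predictive" in ppc.groups()

    # idata_kwargs is forwarded verbatim to pm.to_inference_data, so converter
    # options (e.g. dims/coords) can be attached here; user keys override the
    # model=... default the sampler fills in.
    ppc2 = pm.sample_posterior_predictive(
        idata,
        idata_kwargs={"dims": {"y": ["obs_id"]}, "coords": {"obs_id": np.arange(50)}},
    )

    # Legacy dict-of-arrays behavior remains one flag away.
    ppc_dict = pm.sample_posterior_predictive(idata, return_inferencedata=False)
    assert isinstance(ppc_dict, dict)
```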

Diff for: pymc/smc/smc.py (+1)

```diff
@@ -179,6 +179,7 @@ def initialize_population(self) -> Dict[str, NDArray]:
             self.draws,
             var_names=[v.name for v in self.model.unobserved_value_vars],
             model=self.model,
+            return_inferencedata=False,
         )
 
     def _initialize_kernel(self):
```
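The SMC initializer above shows the pattern for internal callers: code that immediately consumes the draws as flat `{name: np.ndarray}` mappings must now opt out explicitly. A minimal sketch of that consumer pattern (the model here is illustrative):

```python
import numpy as np
import pymc as pm

with pm.Model() as model:
    x = pm.Normal("x")

    # Internal-style consumer: wants arrays keyed by variable name, without
    # the (chain, draw) InferenceData packaging -- so it passes the flag.
    draws = pm.sample_prior_predictive(
        1000,
        var_names=["x"],
        return_inferencedata=False,
    )

population = np.asarray(draws["x"])  # plain (1000,) array, as SMC expects
assert population.shape == (1000,)
```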

Diff for: pymc/tests/test_data_container.py (+23 −17)

```diff
@@ -55,17 +55,19 @@ def test_sample(self):
             prior_trace1 = pm.sample_prior_predictive(1000)
             pp_trace1 = pm.sample_posterior_predictive(idata, samples=1000)
 
-        assert prior_trace0["b"].shape == (1000,)
-        assert prior_trace0["obs"].shape == (1000, 100)
-        assert prior_trace1["obs"].shape == (1000, 200)
+        assert prior_trace0.prior["b"].shape == (1, 1000)
+        assert prior_trace0.prior_predictive["obs"].shape == (1, 1000, 100)
+        assert prior_trace1.prior_predictive["obs"].shape == (1, 1000, 200)
 
-        assert pp_trace0["obs"].shape == (1000, 100)
-
-        np.testing.assert_allclose(x, pp_trace0["obs"].mean(axis=0), atol=1e-1)
-
-        assert pp_trace1["obs"].shape == (1000, 200)
+        assert pp_trace0.posterior_predictive["obs"].shape == (1, 1000, 100)
+        np.testing.assert_allclose(
+            x, pp_trace0.posterior_predictive["obs"].mean(("chain", "draw")), atol=1e-1
+        )
 
-        np.testing.assert_allclose(x_pred, pp_trace1["obs"].mean(axis=0), atol=1e-1)
+        assert pp_trace1.posterior_predictive["obs"].shape == (1, 1000, 200)
+        np.testing.assert_allclose(
+            x_pred, pp_trace1.posterior_predictive["obs"].mean(("chain", "draw")), atol=1e-1
+        )
 
     def test_sample_posterior_predictive_after_set_data(self):
         with pm.Model() as model:
@@ -86,8 +88,10 @@ def test_sample_posterior_predictive_after_set_data(self):
             pm.set_data(new_data={"x": x_test})
             y_test = pm.sample_posterior_predictive(trace)
 
-        assert y_test["obs"].shape == (1000, 3)
-        np.testing.assert_allclose(x_test, y_test["obs"].mean(axis=0), atol=1e-1)
+        assert y_test.posterior_predictive["obs"].shape == (1, 1000, 3)
+        np.testing.assert_allclose(
+            x_test, y_test.posterior_predictive["obs"].mean(("chain", "draw")), atol=1e-1
+        )
 
     def test_sample_after_set_data(self):
         with pm.Model() as model:
@@ -116,8 +120,10 @@ def test_sample_after_set_data(self):
             )
             pp_trace = pm.sample_posterior_predictive(new_idata, 1000)
 
-        assert pp_trace["obs"].shape == (1000, 3)
-        np.testing.assert_allclose(new_y, pp_trace["obs"].mean(axis=0), atol=1e-1)
+        assert pp_trace.posterior_predictive["obs"].shape == (1, 1000, 3)
+        np.testing.assert_allclose(
+            new_y, pp_trace.posterior_predictive["obs"].mean(("chain", "draw")), atol=1e-1
+        )
 
     def test_shared_data_as_index(self):
         """
@@ -130,7 +136,7 @@ def test_shared_data_as_index(self):
             alpha = pm.Normal("alpha", 0, 1.5, size=3)
             pm.Normal("obs", alpha[index], np.sqrt(1e-2), observed=y)
 
-            prior_trace = pm.sample_prior_predictive(1000, var_names=["alpha"])
+            prior_trace = pm.sample_prior_predictive(1000)
             idata = pm.sample(
                 1000,
                 init=None,
@@ -146,10 +152,10 @@ def test_shared_data_as_index(self):
             pm.set_data(new_data={"index": new_index, "y": new_y})
             pp_trace = pm.sample_posterior_predictive(idata, 1000, var_names=["alpha", "obs"])
 
-        assert prior_trace["alpha"].shape == (1000, 3)
+        assert prior_trace.prior["alpha"].shape == (1, 1000, 3)
         assert idata.posterior["alpha"].shape == (1, 1000, 3)
-        assert pp_trace["alpha"].shape == (1000, 3)
-        assert pp_trace["obs"].shape == (1000, 3)
+        assert pp_trace.posterior_predictive["alpha"].shape == (1, 1000, 3)
+        assert pp_trace.posterior_predictive["obs"].shape == (1, 1000, 3)
 
     def test_shared_data_as_rv_input(self):
         """
```

Diff for: pymc/tests/test_distributions.py (+2 −2)

```diff
@@ -3249,7 +3249,7 @@ def test_distinct_rvs():
         X_rv = pm.Normal("x")
         Y_rv = pm.Normal("y")
 
-        pp_samples = pm.sample_prior_predictive(samples=2)
+        pp_samples = pm.sample_prior_predictive(samples=2, return_inferencedata=False)
 
     assert X_rv.owner.inputs[0] != Y_rv.owner.inputs[0]
 
@@ -3259,7 +3259,7 @@ def test_distinct_rvs():
         X_rv = pm.Normal("x")
         Y_rv = pm.Normal("y")
 
-        pp_samples_2 = pm.sample_prior_predictive(samples=2)
+        pp_samples_2 = pm.sample_prior_predictive(samples=2, return_inferencedata=False)
 
     assert np.array_equal(pp_samples["y"], pp_samples_2["y"])
```
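These two call sites keep the dictionary return because the test compares raw draws across models; with the new default the same arrays live under the `prior` group. A hedged equivalent of that access (model is illustrative):

```python
import pymc as pm

with pm.Model():
    pm.Normal("y")
    idata = pm.sample_prior_predictive(samples=2)  # InferenceData by default

# Equivalent raw-array access on the new return type:
y_draws = idata.prior["y"].values  # shape (1, 2): (chain, draw)
assert y_draws.shape == (1, 2)
```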

Diff for: pymc/tests/test_distributions_random.py (+3 −3)

```diff
@@ -1583,7 +1583,7 @@ def ref_rand(mu, rowcov, colcov):
             rowcov=np.eye(3),
             colcov=np.eye(3),
         )
-        check = pm.sample_prior_predictive(n_fails)
+        check = pm.sample_prior_predictive(n_fails, return_inferencedata=False)
 
     ref_smp = ref_rand(mu=np.random.random((3, 3)), rowcov=np.eye(3), colcov=np.eye(3))
 
@@ -1922,7 +1922,7 @@ def sample_prior(self, distribution, shape, nested_rvs_info, prior_samples):
             nested_rvs_info,
         )
         with model:
-            return pm.sample_prior_predictive(prior_samples)
+            return pm.sample_prior_predictive(prior_samples, return_inferencedata=False)
 
     @pytest.mark.parametrize(
         ["prior_samples", "shape", "mu", "alpha"],
@@ -2380,7 +2380,7 @@ def test_car_rng_fn(sparse):
     with pm.Model(rng_seeder=1):
         car = pm.CAR("car", mu, W, alpha, tau, size=size)
         mn = pm.MvNormal("mn", mu, cov, size=size)
-        check = pm.sample_prior_predictive(n_fails)
+        check = pm.sample_prior_predictive(n_fails, return_inferencedata=False)
 
     p, f = delta, n_fails
     while p <= delta and f > 0:
```
