Skip to content

Commit 8ea8b3b

Browse files
committed
Do not squeeze single chain in sample_stats_to_xarray
1 parent 6ab0c03 commit 8ea8b3b

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

pymc/backends/arviz.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,10 +297,12 @@ def sample_stats_to_xarray(self):
297297
continue
298298
if self.warmup_trace:
299299
data_warmup[name] = np.array(
300-
self.warmup_trace.get_sampler_stats(stat, combine=False)
300+
self.warmup_trace.get_sampler_stats(stat, combine=False, squeeze=False)
301301
)
302302
if self.posterior_trace:
303-
data[name] = np.array(self.posterior_trace.get_sampler_stats(stat, combine=False))
303+
data[name] = np.array(
304+
self.posterior_trace.get_sampler_stats(stat, combine=False, squeeze=False)
305+
)
304306

305307
return (
306308
dict_to_dataset(

pymc/tests/backends/test_arviz.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
)
2222
from pymc.exceptions import ImputationWarning
2323

24+
# Turn all warnings into errors for this module
25+
pytestmark = pytest.mark.filterwarnings("error")
26+
2427

2528
@pytest.fixture(scope="module")
2629
def eight_schools_params():
@@ -635,7 +638,9 @@ def test_include_transformed(self):
635638
pm.Uniform("p", 0, 1)
636639

637640
# First check that the default is to exclude the transformed variables
638-
sample_kwargs = dict(tune=5, draws=7, chains=2, cores=1)
641+
sample_kwargs = dict(
642+
tune=5, draws=7, chains=2, cores=1, compute_convergence_checks=False
643+
)
639644
inference_data = pm.sample(**sample_kwargs, step=pm.Metropolis())
640645
assert "p_interval__" not in inference_data.posterior
641646

@@ -647,6 +652,17 @@ def test_include_transformed(self):
647652
)
648653
assert "p_interval__" in inference_data.posterior
649654

655+
@pytest.mark.parametrize("chains", (1, 2))
656+
def test_single_chain(self, chains):
657+
# Test that no UserWarning is raised when sampling with NUTS defaults
658+
659+
# When this test was added, a `UserWarning: More chains (500) than draws (1)` used to be issued
660+
# when sampling with a single chain
661+
warnings.simplefilter("error")
662+
with pm.Model():
663+
pm.Normal("x")
664+
pm.sample(chains=chains, return_inferencedata=True)
665+
650666

651667
class TestPyMCWarmupHandling:
652668
@pytest.mark.parametrize("save_warmup", [False, True])

0 commit comments

Comments
 (0)