Skip to content

Commit 14a10b7

Browse files
Refactor sample return (#6546)
* Extract return part of `pm.sample` * Speed up `test_mcmc` * Group tests related to `pm.sample` return parameters * Consolidate tests of return options Removes a regression test added in #3821 because it took 14 seconds.
1 parent 9b59771 commit 14a10b7

File tree

3 files changed

+209
-220
lines changed

3 files changed

+209
-220
lines changed

pymc/sampling/mcmc.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def sample(
333333
compute_convergence_checks: bool = True,
334334
keep_warning_stat: bool = False,
335335
return_inferencedata: bool = True,
336-
idata_kwargs: dict = None,
336+
idata_kwargs: Optional[Dict[str, Any]] = None,
337337
callback=None,
338338
mp_ctx=None,
339339
model: Optional[Model] = None,
@@ -687,7 +687,36 @@ def sample(
687687

688688
t_sampling = time.time() - t_start
689689

690-
# Wrap chain traces in a MultiTrace
690+
# Packaging, validating and returning the result was extracted
691+
# into a function to make it easier to test and refactor.
692+
return _sample_return(
693+
traces=traces,
694+
tune=tune,
695+
t_sampling=t_sampling,
696+
discard_tuned_samples=discard_tuned_samples,
697+
compute_convergence_checks=compute_convergence_checks,
698+
return_inferencedata=return_inferencedata,
699+
keep_warning_stat=keep_warning_stat,
700+
idata_kwargs=idata_kwargs or {},
701+
model=model,
702+
)
703+
704+
705+
def _sample_return(
706+
*,
707+
traces: Sequence[IBaseTrace],
708+
tune: int,
709+
t_sampling: float,
710+
discard_tuned_samples: bool,
711+
compute_convergence_checks: bool,
712+
return_inferencedata: bool,
713+
keep_warning_stat: bool,
714+
idata_kwargs: Dict[str, Any],
715+
model: Model,
716+
) -> Union[InferenceData, MultiTrace]:
717+
"""Final step of `pm.sampler` that picks/slices chains,
718+
runs diagnostics and converts to the desired return type."""
719+
# Pick and slice chains to keep the maximum number of samples
691720
if discard_tuned_samples:
692721
traces, length = _choose_chains(traces, tune)
693722
else:
@@ -725,8 +754,7 @@ def sample(
725754
idata = None
726755
if compute_convergence_checks or return_inferencedata:
727756
ikwargs: Dict[str, Any] = dict(model=model, save_warmup=not discard_tuned_samples)
728-
if idata_kwargs:
729-
ikwargs.update(idata_kwargs)
757+
ikwargs.update(idata_kwargs)
730758
idata = pm.to_inference_data(mtrace, **ikwargs)
731759

732760
if compute_convergence_checks:

tests/backends/test_ndarray.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,12 @@ def test_multitrace_nonunique(self):
123123
with pytest.raises(ValueError):
124124
base.MultiTrace([self.strace0, self.strace1])
125125

126+
def test_multitrace_iter_notimplemented(self):
127+
mtrace = base.MultiTrace([self.strace0])
128+
with pytest.raises(NotImplementedError):
129+
for _ in mtrace:
130+
pass
131+
126132

127133
class TestSqueezeCat:
128134
def setup_method(self):

0 commit comments

Comments
 (0)