@@ -333,7 +333,7 @@ def sample(
333
333
compute_convergence_checks : bool = True ,
334
334
keep_warning_stat : bool = False ,
335
335
return_inferencedata : bool = True ,
336
- idata_kwargs : dict = None ,
336
+ idata_kwargs : Optional [ Dict [ str , Any ]] = None ,
337
337
callback = None ,
338
338
mp_ctx = None ,
339
339
model : Optional [Model ] = None ,
@@ -687,7 +687,36 @@ def sample(
687
687
688
688
t_sampling = time .time () - t_start
689
689
690
- # Wrap chain traces in a MultiTrace
690
+ # Packaging, validating and returning the result was extracted
691
+ # into a function to make it easier to test and refactor.
692
+ return _sample_return (
693
+ traces = traces ,
694
+ tune = tune ,
695
+ t_sampling = t_sampling ,
696
+ discard_tuned_samples = discard_tuned_samples ,
697
+ compute_convergence_checks = compute_convergence_checks ,
698
+ return_inferencedata = return_inferencedata ,
699
+ keep_warning_stat = keep_warning_stat ,
700
+ idata_kwargs = idata_kwargs or {},
701
+ model = model ,
702
+ )
703
+
704
+
705
+ def _sample_return (
706
+ * ,
707
+ traces : Sequence [IBaseTrace ],
708
+ tune : int ,
709
+ t_sampling : float ,
710
+ discard_tuned_samples : bool ,
711
+ compute_convergence_checks : bool ,
712
+ return_inferencedata : bool ,
713
+ keep_warning_stat : bool ,
714
+ idata_kwargs : Dict [str , Any ],
715
+ model : Model ,
716
+ ) -> Union [InferenceData , MultiTrace ]:
717
+ """Final step of `pm.sampler` that picks/slices chains,
718
+ runs diagnostics and converts to the desired return type."""
719
+ # Pick and slice chains to keep the maximum number of samples
691
720
if discard_tuned_samples :
692
721
traces , length = _choose_chains (traces , tune )
693
722
else :
@@ -725,8 +754,7 @@ def sample(
725
754
idata = None
726
755
if compute_convergence_checks or return_inferencedata :
727
756
ikwargs : Dict [str , Any ] = dict (model = model , save_warmup = not discard_tuned_samples )
728
- if idata_kwargs :
729
- ikwargs .update (idata_kwargs )
757
+ ikwargs .update (idata_kwargs )
730
758
idata = pm .to_inference_data (mtrace , ** ikwargs )
731
759
732
760
if compute_convergence_checks :
0 commit comments