21
21
)
22
22
from pymc .exceptions import ImputationWarning
23
23
24
+ # Turn all warnings into errors for this module
25
+ pytestmark = pytest .mark .filterwarnings ("error" )
26
+
24
27
25
28
@pytest .fixture (scope = "module" )
26
29
def eight_schools_params ():
@@ -635,7 +638,9 @@ def test_include_transformed(self):
635
638
pm .Uniform ("p" , 0 , 1 )
636
639
637
640
# First check that the default is to exclude the transformed variables
638
- sample_kwargs = dict (tune = 5 , draws = 7 , chains = 2 , cores = 1 )
641
+ sample_kwargs = dict (
642
+ tune = 5 , draws = 7 , chains = 2 , cores = 1 , compute_convergence_checks = False
643
+ )
639
644
inference_data = pm .sample (** sample_kwargs , step = pm .Metropolis ())
640
645
assert "p_interval__" not in inference_data .posterior
641
646
@@ -647,6 +652,17 @@ def test_include_transformed(self):
647
652
)
648
653
assert "p_interval__" in inference_data .posterior
649
654
655
+ @pytest .mark .parametrize ("chains" , (1 , 2 ))
656
+ def test_single_chain (self , chains ):
657
+ # Test that no UserWarning is raised when sampling with NUTS defaults
658
+
659
+ # When this test was added, a `UserWarning: More chains (500) than draws (1)` used to be issued
660
+ # when sampling with a single chain
661
+ warnings .simplefilter ("error" )
662
+ with pm .Model ():
663
+ pm .Normal ("x" )
664
+ pm .sample (chains = chains , return_inferencedata = True )
665
+
650
666
651
667
class TestPyMCWarmupHandling :
652
668
@pytest .mark .parametrize ("save_warmup" , [False , True ])
0 commit comments