@@ -410,15 +410,15 @@ def test_not_supported_marginalized_deterministic_and_potential():
410
410
(None , does_not_warn ()),
411
411
(UNSET , does_not_warn ()),
412
412
(transforms .log , does_not_warn ()),
413
- (transforms .Chain ([transforms .log , transforms .logodds ]), does_not_warn ()),
413
+ (transforms .Chain ([transforms .logodds , transforms .log ]), does_not_warn ()),
414
414
(
415
- transforms .Interval (0 , 1 ),
415
+ transforms .Interval (0 , 2 ),
416
416
pytest .warns (
417
417
UserWarning , match = "which depends on the marginalized idx may no longer work"
418
418
),
419
419
),
420
420
(
421
- transforms .Chain ([transforms .log , transforms .Interval (0 , 1 )]),
421
+ transforms .Chain ([transforms .log , transforms .Interval (- 1 , 1 )]),
422
422
pytest .warns (
423
423
UserWarning , match = "which depends on the marginalized idx may no longer work"
424
424
),
@@ -428,7 +428,7 @@ def test_not_supported_marginalized_deterministic_and_potential():
428
428
def test_marginalized_transforms (transform , expected_warning ):
429
429
w = [0.1 , 0.3 , 0.6 ]
430
430
data = [0 , 5 , 10 ]
431
- initval = 0.5 # Value that will be negative on the unconstrained space
431
+ initval = 0.7 # Value that will be negative on the unconstrained space
432
432
433
433
with pm .Model () as m_ref :
434
434
sigma = pm .Mixture (
@@ -467,7 +467,7 @@ def test_marginalized_transforms(transform, expected_warning):
467
467
transform_name = "log"
468
468
else :
469
469
transform_name = transform .name
470
- assert f"sigma_{ transform_name } __" in ip
470
+ assert - np . inf < ip [ f"sigma_{ transform_name } __" ] < 0.0
471
471
np .testing .assert_allclose (m .compile_logp ()(ip ), m_ref .compile_logp ()(ip ))
472
472
473
473
0 commit comments