@@ -481,16 +481,6 @@ def test_normal_scalar(self):
481
481
chains = nchains ,
482
482
)
483
483
484
- # test that trace is used in ppc
485
- with pm .Model () as model_ppc :
486
- mu = pm .Normal ("mu" , 0.0 , 1.0 )
487
- a = pm .Normal ("a" , mu = mu , sigma = 1 )
488
-
489
- ppc = pm .sample_posterior_predictive (
490
- trace = trace , model = model_ppc , return_inferencedata = False
491
- )
492
- assert "a" in ppc
493
-
494
484
with model :
495
485
# test list input
496
486
ppc0 = pm .sample_posterior_predictive (
@@ -550,6 +540,51 @@ def test_normal_scalar_idata(self):
550
540
ppc = pm .sample_posterior_predictive (idata , return_inferencedata = False )
551
541
assert ppc ["a" ].shape == (nchains , ndraws )
552
542
543
+ def test_external_trace (self ):
544
+ nchains = 2
545
+ ndraws = 500
546
+ with pm .Model () as model :
547
+ mu = pm .Normal ("mu" , 0.0 , 1.0 )
548
+ a = pm .Normal ("a" , mu = mu , sigma = 1 , observed = 0.0 )
549
+ trace = pm .sample (
550
+ draws = ndraws ,
551
+ chains = nchains ,
552
+ )
553
+
554
+ # test that trace is used in ppc
555
+ with pm .Model () as model_ppc :
556
+ mu = pm .Normal ("mu" , 0.0 , 1.0 )
557
+ a = pm .Normal ("a" , mu = mu , sigma = 1 )
558
+
559
+ ppc = pm .sample_posterior_predictive (
560
+ trace = trace , model = model_ppc , return_inferencedata = False
561
+ )
562
+ assert list (ppc .keys ()) == ["a" ]
563
+
564
+ @pytest .mark .xfail (reason = "Auto-imputation of variables not supported in this setting" )
565
+ def test_external_trace_det (self ):
566
+ nchains = 2
567
+ ndraws = 500
568
+ with pm .Model () as model :
569
+ mu = pm .Normal ("mu" , 0.0 , 1.0 )
570
+ a = pm .Normal ("a" , mu = mu , sigma = 1 , observed = 0.0 )
571
+ b = pm .Deterministic ("b" , a + 1 )
572
+ trace = pm .sample (
573
+ draws = ndraws ,
574
+ chains = nchains ,
575
+ )
576
+
577
+ # test that trace is used in ppc
578
+ with pm .Model () as model_ppc :
579
+ mu = pm .Normal ("mu" , 0.0 , 1.0 )
580
+ a = pm .Normal ("a" , mu = mu , sigma = 1 )
581
+ b = pm .Deterministic ("b" , a + 1 )
582
+
583
+ ppc = pm .sample_posterior_predictive (
584
+ trace = trace , model = model_ppc , return_inferencedata = False
585
+ )
586
+ assert list (ppc .keys ()) == ["a" , "b" ]
587
+
553
588
def test_normal_vector (self ):
554
589
with pm .Model () as model :
555
590
mu = pm .Normal ("mu" , 0.0 , 1.0 )
0 commit comments