@@ -693,37 +693,42 @@ def test_not_implemented_discrete_rv_transform():
693
693
694
694
def test_negated_discrete_rv_transform ():
695
695
p = 0.7
696
- rv = - Bernoulli .dist (p = p )
696
+ rv = - Bernoulli .dist (p = p , shape = ( 4 ,) )
697
697
vv = rv .type ()
698
- logp_fn = pytensor .function ([vv ], logp (rv , vv ))
699
698
700
699
# A negated Bernoulli has pmf {p if x == -1; 1-p if x == 0; 0 otherwise}
701
- assert logp_fn ( - 2 ) == - np . inf
702
- np .testing .assert_allclose (logp_fn ( - 1 ), np . log ( p ))
703
- np .testing . assert_allclose ( logp_fn ( 0 ), np .log (1 - p ))
704
- assert logp_fn ( 1 ) == - np . inf
700
+ logp_fn = pytensor . function ([ vv ], logp ( rv , vv ))
701
+ np .testing .assert_allclose (
702
+ logp_fn ([ - 2 , - 1 , 0 , 1 ]), [ - np .inf , np . log ( p ), np .log (1 - p ), - np . inf ]
703
+ )
705
704
706
- # Logcdf and icdf not supported yet
707
- for func in (logcdf , icdf ):
708
- with pytest .raises (NotImplementedError ):
709
- func (rv , 0 )
705
+ logcdf_fn = pytensor .function ([vv ], logcdf (rv , vv ))
706
+ np .testing .assert_allclose (logcdf_fn ([- 2 , - 1 , 0 , 1 ]), [- np .inf , np .log (p ), 0 , 0 ])
707
+
708
+ with pytest .raises (NotImplementedError ):
709
+ icdf (rv , [- 2 , - 1 , 0 , 1 ])
710
710
711
711
712
712
def test_shifted_discrete_rv_transform ():
713
713
p = 0.7
714
714
rv = Bernoulli .dist (p = p ) + 5
715
715
vv = rv .type ()
716
- rv_logp_fn = pytensor .function ([vv ], logp (rv , vv ))
717
716
717
+ rv_logp_fn = pytensor .function ([vv ], logp (rv , vv ))
718
718
assert rv_logp_fn (4 ) == - np .inf
719
719
np .testing .assert_allclose (rv_logp_fn (5 ), np .log (1 - p ))
720
720
np .testing .assert_allclose (rv_logp_fn (6 ), np .log (p ))
721
721
assert rv_logp_fn (7 ) == - np .inf
722
722
723
- # Logcdf and icdf not supported yet
724
- for func in (logcdf , icdf ):
725
- with pytest .raises (NotImplementedError ):
726
- func (rv , 0 )
723
+ rv_logcdf_fn = pytensor .function ([vv ], logcdf (rv , vv ))
724
+ assert rv_logcdf_fn (4 ) == - np .inf
725
+ np .testing .assert_allclose (rv_logcdf_fn (5 ), np .log (1 - p ))
726
+ np .testing .assert_allclose (rv_logcdf_fn (6 ), 0 )
727
+ assert rv_logcdf_fn (7 ) == 0
728
+
729
+ # icdf not supported yet
730
+ with pytest .raises (NotImplementedError ):
731
+ icdf (rv , 0 )
727
732
728
733
729
734
@pytest .mark .xfail (reason = "Check not implemented yet" )
0 commit comments