@@ -31,20 +31,21 @@ class TestBound:
31
31
def test_continuous (self ):
32
32
with pm .Model () as model :
33
33
dist = pm .Normal .dist (mu = 0 , sigma = 1 )
34
- with warnings .catch_warnings ():
35
- warnings .filterwarnings (
36
- "ignore" , "invalid value encountered in add" , RuntimeWarning
37
- )
38
- UnboundedNormal = pm .Bound ("unbound" , dist , transform = None )
39
- InfBoundedNormal = pm .Bound (
40
- "infbound" , dist , lower = - np .inf , upper = np .inf , transform = None
41
- )
42
- LowerNormal = pm .Bound ("lower" , dist , lower = 0 , transform = None )
43
- UpperNormal = pm .Bound ("upper" , dist , upper = 0 , transform = None )
44
- BoundedNormal = pm .Bound ("bounded" , dist , lower = 1 , upper = 10 , transform = None )
45
- LowerNormalTransform = pm .Bound ("lowertrans" , dist , lower = 1 )
46
- UpperNormalTransform = pm .Bound ("uppertrans" , dist , upper = 10 )
47
- BoundedNormalTransform = pm .Bound ("boundedtrans" , dist , lower = 1 , upper = 10 )
34
+ with pytest .warns (FutureWarning , match = "Bound has been deprecated" ):
35
+ with warnings .catch_warnings ():
36
+ warnings .filterwarnings (
37
+ "ignore" , "invalid value encountered in add" , RuntimeWarning
38
+ )
39
+ UnboundedNormal = pm .Bound ("unbound" , dist , transform = None )
40
+ InfBoundedNormal = pm .Bound (
41
+ "infbound" , dist , lower = - np .inf , upper = np .inf , transform = None
42
+ )
43
+ LowerNormal = pm .Bound ("lower" , dist , lower = 0 , transform = None )
44
+ UpperNormal = pm .Bound ("upper" , dist , upper = 0 , transform = None )
45
+ BoundedNormal = pm .Bound ("bounded" , dist , lower = 1 , upper = 10 , transform = None )
46
+ LowerNormalTransform = pm .Bound ("lowertrans" , dist , lower = 1 )
47
+ UpperNormalTransform = pm .Bound ("uppertrans" , dist , upper = 10 )
48
+ BoundedNormalTransform = pm .Bound ("boundedtrans" , dist , lower = 1 , upper = 10 )
48
49
49
50
assert joint_logp (LowerNormal , - 1 ).eval () == - np .inf
50
51
assert joint_logp (UpperNormal , 1 ).eval () == - np .inf
@@ -73,14 +74,15 @@ def test_continuous(self):
73
74
def test_discrete (self ):
74
75
with pm .Model () as model :
75
76
dist = pm .Poisson .dist (mu = 4 )
76
- with warnings .catch_warnings ():
77
- warnings .filterwarnings (
78
- "ignore" , "invalid value encountered in add" , RuntimeWarning
79
- )
80
- UnboundedPoisson = pm .Bound ("unbound" , dist )
81
- LowerPoisson = pm .Bound ("lower" , dist , lower = 1 )
82
- UpperPoisson = pm .Bound ("upper" , dist , upper = 10 )
83
- BoundedPoisson = pm .Bound ("bounded" , dist , lower = 1 , upper = 10 )
77
+ with pytest .warns (FutureWarning , match = "Bound has been deprecated" ):
78
+ with warnings .catch_warnings ():
79
+ warnings .filterwarnings (
80
+ "ignore" , "invalid value encountered in add" , RuntimeWarning
81
+ )
82
+ UnboundedPoisson = pm .Bound ("unbound" , dist )
83
+ LowerPoisson = pm .Bound ("lower" , dist , lower = 1 )
84
+ UpperPoisson = pm .Bound ("upper" , dist , upper = 10 )
85
+ BoundedPoisson = pm .Bound ("bounded" , dist , lower = 1 , upper = 10 )
84
86
85
87
assert joint_logp (LowerPoisson , 0 ).eval () == - np .inf
86
88
assert joint_logp (UpperPoisson , 11 ).eval () == - np .inf
@@ -118,8 +120,9 @@ def test_arguments_checks(self):
118
120
msg = "Observed Bound distributions are not supported"
119
121
with pm .Model () as m :
120
122
x = pm .Normal ("x" , 0 , 1 )
121
- with pytest .raises (ValueError , match = msg ):
122
- pm .Bound ("bound" , x , observed = 5 )
123
+ with pytest .warns (FutureWarning , match = "Bound has been deprecated" ):
124
+ with pytest .raises (ValueError , match = msg ):
125
+ pm .Bound ("bound" , x , observed = 5 )
123
126
124
127
msg = "Cannot transform discrete variable."
125
128
with pm .Model () as m :
@@ -128,52 +131,60 @@ def test_arguments_checks(self):
128
131
warnings .filterwarnings (
129
132
"ignore" , "invalid value encountered in add" , RuntimeWarning
130
133
)
131
- with pytest .raises (ValueError , match = msg ):
132
- pm .Bound ("bound" , x , transform = pm .distributions .transforms .log )
134
+ with pytest .warns (FutureWarning , match = "Bound has been deprecated" ):
135
+ with pytest .raises (ValueError , match = msg ):
136
+ pm .Bound ("bound" , x , transform = pm .distributions .transforms .log )
133
137
134
138
msg = "Given dims do not exist in model coordinates."
135
139
with pm .Model () as m :
136
140
x = pm .Poisson .dist (0.5 )
137
- with pytest .raises (ValueError , match = msg ):
138
- pm .Bound ("bound" , x , dims = "random_dims" )
141
+ with pytest .warns (FutureWarning , match = "Bound has been deprecated" ):
142
+ with pytest .raises (ValueError , match = msg ):
143
+ pm .Bound ("bound" , x , dims = "random_dims" )
139
144
140
145
msg = "The dist x was already registered in the current model"
141
146
with pm .Model () as m :
142
147
x = pm .Normal ("x" , 0 , 1 )
143
- with pytest .raises (ValueError , match = msg ):
144
- pm .Bound ("bound" , x )
148
+ with pytest .warns (FutureWarning , match = "Bound has been deprecated" ):
149
+ with pytest .raises (ValueError , match = msg ):
150
+ pm .Bound ("bound" , x )
145
151
146
152
msg = "Passing a distribution class to `Bound` is no longer supported"
147
153
with pm .Model () as m :
148
- with pytest .raises (ValueError , match = msg ):
149
- pm .Bound ("bound" , pm .Normal )
154
+ with pytest .warns (FutureWarning , match = "Bound has been deprecated" ):
155
+ with pytest .raises (ValueError , match = msg ):
156
+ pm .Bound ("bound" , pm .Normal )
150
157
151
158
msg = "Bounding of MultiVariate RVs is not yet supported"
152
159
with pm .Model () as m :
153
160
x = pm .MvNormal .dist (np .zeros (3 ), np .eye (3 ))
154
- with pytest .raises (NotImplementedError , match = msg ):
155
- pm .Bound ("bound" , x )
161
+ with pytest .warns (FutureWarning , match = "Bound has been deprecated" ):
162
+ with pytest .raises (NotImplementedError , match = msg ):
163
+ pm .Bound ("bound" , x )
156
164
157
165
msg = "must be a Discrete or Continuous distribution subclass"
158
166
with pm .Model () as m :
159
167
x = self .create_invalid_distribution ().dist ()
160
- with pytest .raises (ValueError , match = msg ):
161
- pm .Bound ("bound" , x )
168
+ with pytest .warns (FutureWarning , match = "Bound has been deprecated" ):
169
+ with pytest .raises (ValueError , match = msg ):
170
+ pm .Bound ("bound" , x )
162
171
163
172
def test_invalid_sampling (self ):
164
173
msg = "Cannot sample from a bounded variable"
165
174
with pm .Model () as m :
166
175
dist = pm .Normal .dist (mu = 0 , sigma = 1 )
167
- BoundedNormal = pm .Bound ("bounded" , dist , lower = 1 , upper = 10 )
176
+ with pytest .warns (FutureWarning , match = "Bound has been deprecated" ):
177
+ BoundedNormal = pm .Bound ("bounded" , dist , lower = 1 , upper = 10 )
168
178
with pytest .raises (NotImplementedError , match = msg ):
169
179
pm .sample_prior_predictive ()
170
180
171
181
def test_bound_shapes (self ):
172
182
with pm .Model (coords = {"sample" : np .ones ((2 , 5 ))}) as m :
173
183
dist = pm .Normal .dist (mu = 0 , sigma = 1 )
174
- bound_sized = pm .Bound ("boundedsized" , dist , lower = 1 , upper = 10 , size = (4 , 5 ))
175
- bound_shaped = pm .Bound ("boundedshaped" , dist , lower = 1 , upper = 10 , shape = (3 , 5 ))
176
- bound_dims = pm .Bound ("boundeddims" , dist , lower = 1 , upper = 10 , dims = "sample" )
184
+ with pytest .warns (FutureWarning , match = "Bound has been deprecated" ):
185
+ bound_sized = pm .Bound ("boundedsized" , dist , lower = 1 , upper = 10 , size = (4 , 5 ))
186
+ bound_shaped = pm .Bound ("boundedshaped" , dist , lower = 1 , upper = 10 , shape = (3 , 5 ))
187
+ bound_dims = pm .Bound ("boundeddims" , dist , lower = 1 , upper = 10 , dims = "sample" )
177
188
178
189
initial_point = m .initial_point ()
179
190
dist_size = initial_point ["boundedsized_interval__" ].shape
@@ -198,13 +209,16 @@ def test_bound_dist(self):
198
209
def test_array_bound (self ):
199
210
with pm .Model () as model :
200
211
dist = pm .Normal .dist ()
201
- with warnings .catch_warnings ():
202
- warnings .filterwarnings (
203
- "ignore" , "invalid value encountered in add" , RuntimeWarning
212
+ with pytest .warns (FutureWarning , match = "Bound has been deprecated" ):
213
+ with warnings .catch_warnings ():
214
+ warnings .filterwarnings (
215
+ "ignore" , "invalid value encountered in add" , RuntimeWarning
216
+ )
217
+ LowerPoisson = pm .Bound ("lower" , dist , lower = [1 , None ], transform = None )
218
+ UpperPoisson = pm .Bound ("upper" , dist , upper = [np .inf , 10 ], transform = None )
219
+ BoundedPoisson = pm .Bound (
220
+ "bounded" , dist , lower = [1 , 2 ], upper = [9 , 10 ], transform = None
204
221
)
205
- LowerPoisson = pm .Bound ("lower" , dist , lower = [1 , None ], transform = None )
206
- UpperPoisson = pm .Bound ("upper" , dist , upper = [np .inf , 10 ], transform = None )
207
- BoundedPoisson = pm .Bound ("bounded" , dist , lower = [1 , 2 ], upper = [9 , 10 ], transform = None )
208
222
209
223
first , second = joint_logp (LowerPoisson , [0 , 0 ], sum = False )[0 ].eval ()
210
224
assert first == - np .inf
0 commit comments