Skip to content

Commit 4072488

Browse files
committed
Add deprecation warning for Bound
1 parent a1f9d00 commit 4072488

File tree

2 files changed

+69
-46
lines changed

2 files changed

+69
-46
lines changed

pymc/distributions/bound.py

+9
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import warnings
1415

1516
import aesara.tensor as at
1617
import numpy as np
@@ -182,6 +183,14 @@ def __new__(
182183
**kwargs,
183184
):
184185

186+
warnings.warn(
187+
"Bound has been deprecated in favor of Truncated, and will be removed in a "
188+
"future release. If Truncated is not an option, Bound can be implemented by"
189+
"adding an IntervalTransform between lower and upper to a continuous "
190+
"variable. A Potential that returns negative infinity for values outside "
191+
"of the bounds can be used for discrete variables.",
192+
FutureWarning,
193+
)
185194
cls._argument_checks(dist, **kwargs)
186195

187196
if dims is not None:

pymc/tests/distributions/test_bound.py

+60-46
Original file line numberDiff line numberDiff line change
@@ -31,20 +31,21 @@ class TestBound:
3131
def test_continuous(self):
3232
with pm.Model() as model:
3333
dist = pm.Normal.dist(mu=0, sigma=1)
34-
with warnings.catch_warnings():
35-
warnings.filterwarnings(
36-
"ignore", "invalid value encountered in add", RuntimeWarning
37-
)
38-
UnboundedNormal = pm.Bound("unbound", dist, transform=None)
39-
InfBoundedNormal = pm.Bound(
40-
"infbound", dist, lower=-np.inf, upper=np.inf, transform=None
41-
)
42-
LowerNormal = pm.Bound("lower", dist, lower=0, transform=None)
43-
UpperNormal = pm.Bound("upper", dist, upper=0, transform=None)
44-
BoundedNormal = pm.Bound("bounded", dist, lower=1, upper=10, transform=None)
45-
LowerNormalTransform = pm.Bound("lowertrans", dist, lower=1)
46-
UpperNormalTransform = pm.Bound("uppertrans", dist, upper=10)
47-
BoundedNormalTransform = pm.Bound("boundedtrans", dist, lower=1, upper=10)
34+
with pytest.warns(FutureWarning, match="Bound has been deprecated"):
35+
with warnings.catch_warnings():
36+
warnings.filterwarnings(
37+
"ignore", "invalid value encountered in add", RuntimeWarning
38+
)
39+
UnboundedNormal = pm.Bound("unbound", dist, transform=None)
40+
InfBoundedNormal = pm.Bound(
41+
"infbound", dist, lower=-np.inf, upper=np.inf, transform=None
42+
)
43+
LowerNormal = pm.Bound("lower", dist, lower=0, transform=None)
44+
UpperNormal = pm.Bound("upper", dist, upper=0, transform=None)
45+
BoundedNormal = pm.Bound("bounded", dist, lower=1, upper=10, transform=None)
46+
LowerNormalTransform = pm.Bound("lowertrans", dist, lower=1)
47+
UpperNormalTransform = pm.Bound("uppertrans", dist, upper=10)
48+
BoundedNormalTransform = pm.Bound("boundedtrans", dist, lower=1, upper=10)
4849

4950
assert joint_logp(LowerNormal, -1).eval() == -np.inf
5051
assert joint_logp(UpperNormal, 1).eval() == -np.inf
@@ -73,14 +74,15 @@ def test_continuous(self):
7374
def test_discrete(self):
7475
with pm.Model() as model:
7576
dist = pm.Poisson.dist(mu=4)
76-
with warnings.catch_warnings():
77-
warnings.filterwarnings(
78-
"ignore", "invalid value encountered in add", RuntimeWarning
79-
)
80-
UnboundedPoisson = pm.Bound("unbound", dist)
81-
LowerPoisson = pm.Bound("lower", dist, lower=1)
82-
UpperPoisson = pm.Bound("upper", dist, upper=10)
83-
BoundedPoisson = pm.Bound("bounded", dist, lower=1, upper=10)
77+
with pytest.warns(FutureWarning, match="Bound has been deprecated"):
78+
with warnings.catch_warnings():
79+
warnings.filterwarnings(
80+
"ignore", "invalid value encountered in add", RuntimeWarning
81+
)
82+
UnboundedPoisson = pm.Bound("unbound", dist)
83+
LowerPoisson = pm.Bound("lower", dist, lower=1)
84+
UpperPoisson = pm.Bound("upper", dist, upper=10)
85+
BoundedPoisson = pm.Bound("bounded", dist, lower=1, upper=10)
8486

8587
assert joint_logp(LowerPoisson, 0).eval() == -np.inf
8688
assert joint_logp(UpperPoisson, 11).eval() == -np.inf
@@ -118,8 +120,9 @@ def test_arguments_checks(self):
118120
msg = "Observed Bound distributions are not supported"
119121
with pm.Model() as m:
120122
x = pm.Normal("x", 0, 1)
121-
with pytest.raises(ValueError, match=msg):
122-
pm.Bound("bound", x, observed=5)
123+
with pytest.warns(FutureWarning, match="Bound has been deprecated"):
124+
with pytest.raises(ValueError, match=msg):
125+
pm.Bound("bound", x, observed=5)
123126

124127
msg = "Cannot transform discrete variable."
125128
with pm.Model() as m:
@@ -128,52 +131,60 @@ def test_arguments_checks(self):
128131
warnings.filterwarnings(
129132
"ignore", "invalid value encountered in add", RuntimeWarning
130133
)
131-
with pytest.raises(ValueError, match=msg):
132-
pm.Bound("bound", x, transform=pm.distributions.transforms.log)
134+
with pytest.warns(FutureWarning, match="Bound has been deprecated"):
135+
with pytest.raises(ValueError, match=msg):
136+
pm.Bound("bound", x, transform=pm.distributions.transforms.log)
133137

134138
msg = "Given dims do not exist in model coordinates."
135139
with pm.Model() as m:
136140
x = pm.Poisson.dist(0.5)
137-
with pytest.raises(ValueError, match=msg):
138-
pm.Bound("bound", x, dims="random_dims")
141+
with pytest.warns(FutureWarning, match="Bound has been deprecated"):
142+
with pytest.raises(ValueError, match=msg):
143+
pm.Bound("bound", x, dims="random_dims")
139144

140145
msg = "The dist x was already registered in the current model"
141146
with pm.Model() as m:
142147
x = pm.Normal("x", 0, 1)
143-
with pytest.raises(ValueError, match=msg):
144-
pm.Bound("bound", x)
148+
with pytest.warns(FutureWarning, match="Bound has been deprecated"):
149+
with pytest.raises(ValueError, match=msg):
150+
pm.Bound("bound", x)
145151

146152
msg = "Passing a distribution class to `Bound` is no longer supported"
147153
with pm.Model() as m:
148-
with pytest.raises(ValueError, match=msg):
149-
pm.Bound("bound", pm.Normal)
154+
with pytest.warns(FutureWarning, match="Bound has been deprecated"):
155+
with pytest.raises(ValueError, match=msg):
156+
pm.Bound("bound", pm.Normal)
150157

151158
msg = "Bounding of MultiVariate RVs is not yet supported"
152159
with pm.Model() as m:
153160
x = pm.MvNormal.dist(np.zeros(3), np.eye(3))
154-
with pytest.raises(NotImplementedError, match=msg):
155-
pm.Bound("bound", x)
161+
with pytest.warns(FutureWarning, match="Bound has been deprecated"):
162+
with pytest.raises(NotImplementedError, match=msg):
163+
pm.Bound("bound", x)
156164

157165
msg = "must be a Discrete or Continuous distribution subclass"
158166
with pm.Model() as m:
159167
x = self.create_invalid_distribution().dist()
160-
with pytest.raises(ValueError, match=msg):
161-
pm.Bound("bound", x)
168+
with pytest.warns(FutureWarning, match="Bound has been deprecated"):
169+
with pytest.raises(ValueError, match=msg):
170+
pm.Bound("bound", x)
162171

163172
def test_invalid_sampling(self):
164173
msg = "Cannot sample from a bounded variable"
165174
with pm.Model() as m:
166175
dist = pm.Normal.dist(mu=0, sigma=1)
167-
BoundedNormal = pm.Bound("bounded", dist, lower=1, upper=10)
176+
with pytest.warns(FutureWarning, match="Bound has been deprecated"):
177+
BoundedNormal = pm.Bound("bounded", dist, lower=1, upper=10)
168178
with pytest.raises(NotImplementedError, match=msg):
169179
pm.sample_prior_predictive()
170180

171181
def test_bound_shapes(self):
172182
with pm.Model(coords={"sample": np.ones((2, 5))}) as m:
173183
dist = pm.Normal.dist(mu=0, sigma=1)
174-
bound_sized = pm.Bound("boundedsized", dist, lower=1, upper=10, size=(4, 5))
175-
bound_shaped = pm.Bound("boundedshaped", dist, lower=1, upper=10, shape=(3, 5))
176-
bound_dims = pm.Bound("boundeddims", dist, lower=1, upper=10, dims="sample")
184+
with pytest.warns(FutureWarning, match="Bound has been deprecated"):
185+
bound_sized = pm.Bound("boundedsized", dist, lower=1, upper=10, size=(4, 5))
186+
bound_shaped = pm.Bound("boundedshaped", dist, lower=1, upper=10, shape=(3, 5))
187+
bound_dims = pm.Bound("boundeddims", dist, lower=1, upper=10, dims="sample")
177188

178189
initial_point = m.initial_point()
179190
dist_size = initial_point["boundedsized_interval__"].shape
@@ -198,13 +209,16 @@ def test_bound_dist(self):
198209
def test_array_bound(self):
199210
with pm.Model() as model:
200211
dist = pm.Normal.dist()
201-
with warnings.catch_warnings():
202-
warnings.filterwarnings(
203-
"ignore", "invalid value encountered in add", RuntimeWarning
212+
with pytest.warns(FutureWarning, match="Bound has been deprecated"):
213+
with warnings.catch_warnings():
214+
warnings.filterwarnings(
215+
"ignore", "invalid value encountered in add", RuntimeWarning
216+
)
217+
LowerPoisson = pm.Bound("lower", dist, lower=[1, None], transform=None)
218+
UpperPoisson = pm.Bound("upper", dist, upper=[np.inf, 10], transform=None)
219+
BoundedPoisson = pm.Bound(
220+
"bounded", dist, lower=[1, 2], upper=[9, 10], transform=None
204221
)
205-
LowerPoisson = pm.Bound("lower", dist, lower=[1, None], transform=None)
206-
UpperPoisson = pm.Bound("upper", dist, upper=[np.inf, 10], transform=None)
207-
BoundedPoisson = pm.Bound("bounded", dist, lower=[1, 2], upper=[9, 10], transform=None)
208222

209223
first, second = joint_logp(LowerPoisson, [0, 0], sum=False)[0].eval()
210224
assert first == -np.inf

0 commit comments

Comments
 (0)