Skip to content

Function to optimize prior under constraints #5231

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 54 commits into main from optim-prior
Jan 4, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
8ca3ded
Replace print statement by AttributeError
AlexAndorra Nov 30, 2021
9dc0096
pre-commit formatting
AlexAndorra Nov 30, 2021
9675e4f
Mention in release notes
AlexAndorra Nov 30, 2021
d132364
Handle 1-param and 3-param distributions
AlexAndorra Dec 1, 2021
6f9ccd4
Update tests
AlexAndorra Dec 1, 2021
fea6643
Fix some wording
AlexAndorra Dec 1, 2021
524a900
pre-commit formatting
AlexAndorra Dec 3, 2021
91174b9
Only raise UserWarning when mass_in_interval not optimal
AlexAndorra Dec 3, 2021
29741f1
Raise NotImplementedError for non-scalar params
AlexAndorra Dec 3, 2021
1ad4297
Remove pipe operator for old python versions
AlexAndorra Dec 3, 2021
a708e6d
Update tests
AlexAndorra Dec 3, 2021
e1c5125
Add test with discrete distrib & wrap in pytest.warns(None)
AlexAndorra Dec 7, 2021
bc9b543
Remove pipe operator for good
AlexAndorra Dec 7, 2021
18ad975
Fix TypeError in dist_params
AlexAndorra Dec 7, 2021
e92d6d8
Relax tolerance for tests
AlexAndorra Dec 7, 2021
94b406b
Force float64 config in find_optim_prior
AlexAndorra Dec 14, 2021
76dbb1f
Rename file name to func_utils.py
AlexAndorra Dec 14, 2021
53bfc00
Replace print statement by AttributeError
AlexAndorra Nov 30, 2021
77a0bb1
pre-commit formatting
AlexAndorra Nov 30, 2021
fd5f498
Mention in release notes
AlexAndorra Nov 30, 2021
171a4aa
Handle 1-param and 3-param distributions
AlexAndorra Dec 1, 2021
36b95cb
Update tests
AlexAndorra Dec 1, 2021
55138d9
Fix some wording
AlexAndorra Dec 1, 2021
4bed2cd
pre-commit formatting
AlexAndorra Dec 3, 2021
02d117b
Only raise UserWarning when mass_in_interval not optimal
AlexAndorra Dec 3, 2021
7742571
Raise NotImplementedError for non-scalar params
AlexAndorra Dec 3, 2021
8a6e0e7
Remove pipe operator for old python versions
AlexAndorra Dec 3, 2021
602391b
Update tests
AlexAndorra Dec 3, 2021
9bb14a3
Add test with discrete distrib & wrap in pytest.warns(None)
AlexAndorra Dec 7, 2021
ab0ef0f
Remove pipe operator for good
AlexAndorra Dec 7, 2021
58f5d56
Fix TypeError in dist_params
AlexAndorra Dec 7, 2021
a6c7f0d
Relax tolerance for tests
AlexAndorra Dec 7, 2021
c9c24d6
Force float64 config in find_optim_prior
AlexAndorra Dec 14, 2021
c75f8c9
Rename file name to func_utils.py
AlexAndorra Dec 14, 2021
3ffd7ff
Change optimization error function and refactor tests
ricardoV94 Dec 16, 2021
a1a6bdf
Use aesaraf.compile_pymc
ricardoV94 Dec 21, 2021
7cd0e55
Merge branch 'optim-prior' of https://github.com/pymc-devs/pymc into …
AlexAndorra Dec 22, 2021
1d868fa
Add and test AssertionError for mass value
AlexAndorra Dec 22, 2021
063bc96
Fix type error in warning message
AlexAndorra Dec 23, 2021
cb7908c
Split up Poisson test
AlexAndorra Dec 24, 2021
16ed438
Use scipy default for Exponential and reactivate tests
AlexAndorra Dec 24, 2021
1b84e18
Refactor Poisson tests
AlexAndorra Dec 24, 2021
6ea7861
Reduce Poisson test tol to 1% for float32
AlexAndorra Dec 25, 2021
d63b652
Remove Exponential logic
AlexAndorra Dec 27, 2021
37e6251
Rename function
AlexAndorra Dec 27, 2021
b912ac6
Refactor test functions names
AlexAndorra Dec 28, 2021
d4bce39
Use more precise exception for gradient
AlexAndorra Dec 30, 2021
9a51289
Don't catch TypeError
AlexAndorra Jan 3, 2022
90a88ff
Merge branch 'main' into optim-prior
AlexAndorra Jan 3, 2022
8b9ae6e
Remove specific Poisson test
AlexAndorra Jan 3, 2022
d53154a
Remove typo from old Poisson test
AlexAndorra Jan 3, 2022
1f42835
Put tests for constrained priors into their own file
AlexAndorra Jan 3, 2022
bad236c
Add code examples in docstrings
AlexAndorra Jan 4, 2022
d89e375
Merge branch 'main' into optim-prior
AlexAndorra Jan 4, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 28 additions & 22 deletions pymc/find_optim_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings

from typing import Dict, Optional

Expand Down Expand Up @@ -68,18 +69,11 @@ def find_optim_prior(
The optimized distribution parameters as a dictionary with the parameters'
name as key and the optimized value as value.
"""
if len(init_guess) > 2:
if (fixed_params is None) or (len(fixed_params) < (len(pm_dist.rv_op.ndims_params) - 2)):
raise NotImplementedError(
"This function can only optimize two parameters. "
f"{pm_dist} has {len(pm_dist.rv_op.ndims_params)} parameters. "
f"You need to fix {len(pm_dist.rv_op.ndims_params) - 2} parameters in the "
"`fixed_params` dictionary."
)
elif (len(init_guess) < 2) and (len(init_guess) < len(pm_dist.rv_op.ndims_params)):
raise ValueError(
f"{pm_dist} has {len(pm_dist.rv_op.ndims_params)} parameters, but you provided only "
f"{len(init_guess)} initial guess. You need to provide 2."
# exit when any parameter is not scalar:
if np.any(np.asarray(pm_dist.rv_op.ndims_params) != 0):
raise NotImplementedError(
"`pm.find_optim_prior` does not work with non-scalar parameters yet.\n"
"Feel free to open a pull request on PyMC repo if you really need this feature."
)

dist_params = aet.vector("dist_params")
Expand All @@ -99,12 +93,11 @@ def find_optim_prior(
except AttributeError:
raise AttributeError(
f"You cannot use `find_optim_prior` with {pm_dist} -- it doesn't have a logcdf "
"method yet. Open an issue or, even better, a pull request on PyMC repo if you really "
"method yet.\nOpen an issue or, even better, a pull request on PyMC repo if you really "
"need it."
)

alpha = 1 - mass
out = [logcdf_lower - np.log(alpha / 2), logcdf_upper - np.log(1 - alpha / 2)]
out = pm.math.logdiffexp(logcdf_upper, logcdf_lower) - np.log(mass)
logcdf = aesara.function([dist_params, lower_, upper_], out)

try:
Expand All @@ -119,11 +112,24 @@ def find_optim_prior(
if not opt.success:
raise ValueError("Optimization of parameters failed.")

# save optimal parameters
opt_params = {
param_name: param_value for param_name, param_value in zip(init_guess.keys(), opt.x)
}
if fixed_params is not None:
return {
param_name: param_value for param_name, param_value in zip(init_guess.keys(), opt.x)
} | fixed_params
else:
return {
param_name: param_value for param_name, param_value in zip(init_guess.keys(), opt.x)
}
opt_params.update(fixed_params)

# check mass in interval is not too far from `mass`
opt_dist = pm_dist.dist(**opt_params)
mass_in_interval = (
pm.math.exp(pm.logcdf(opt_dist, upper)) - pm.math.exp(pm.logcdf(opt_dist, lower))
).eval()
if (np.abs(mass_in_interval - mass)) >= 0.01:
warnings.warn(
f"Final optimization has {mass_in_interval * 100:.0f}% of probability mass between "
f"{lower} and {upper} instead of the requested {mass * 100:.0f}%.\n"
"You may need to use a more flexible distribution, change the fixed parameters in the "
"`fixed_params` dictionary, or provide better initial guesses."
)

return opt_params
61 changes: 44 additions & 17 deletions pymc/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,25 +147,23 @@ def fn(a=UNSET):
def test_find_optim_prior():
MASS = 0.95

# normal case
# Gamma, normal case
opt_params = pm.find_optim_prior(
pm.Gamma, lower=0.1, upper=0.4, mass=MASS, init_guess={"alpha": 1, "beta": 10}
)
np.testing.assert_allclose(np.asarray(opt_params.values()), np.array([8.47481597, 37.65435601]))
np.testing.assert_allclose(
list(opt_params.values()), np.array([8.506023352404027, 37.59626616198404])
)

# normal case, other distribution
# Normal, normal case
opt_params = pm.find_optim_prior(
pm.Normal, lower=155, upper=180, mass=MASS, init_guess={"mu": 170, "sigma": 3}
)
np.testing.assert_allclose(np.asarray(opt_params.values()), np.array([167.5000001, 6.37766828]))

# 1-param case
opt_params = pm.find_optim_prior(
pm.Exponential, lower=0.1, upper=0.4, mass=MASS, init_guess={"lam": 10}
np.testing.assert_allclose(
list(opt_params.values()), np.array([170.76059047372624, 5.542895384602784])
)
np.testing.assert_allclose(np.asarray(opt_params.values()), np.array([0.79929324]))

# 3-param case
# Student, works as expected
opt_params = pm.find_optim_prior(
pm.StudentT,
lower=0.1,
Expand All @@ -174,9 +172,36 @@ def test_find_optim_prior():
init_guess={"mu": 170, "sigma": 3},
fixed_params={"nu": 7},
)
np.testing.assert_allclose(np.asarray(opt_params.values()), np.array([0.25, 0.06343503]))
assert "nu" in opt_params
np.testing.assert_allclose(
list(opt_params.values()), np.array([0.24995405785756986, 0.06343501657095188, 7])
)

with pytest.raises(ValueError, match="parameters, but you provided only"):
# Student not deterministic but without warning
with pytest.warns(None) as record:
pm.find_optim_prior(
pm.StudentT,
lower=0,
upper=1,
mass=MASS,
init_guess={"mu": 5, "sigma": 2, "nu": 7},
)
assert len(record) == 0

# Exponential without warning
with pytest.warns(None) as record:
opt_params = pm.find_optim_prior(
pm.Exponential, lower=0, upper=1, mass=MASS, init_guess={"lam": 1}
)
assert len(record) == 0
np.testing.assert_allclose(list(opt_params.values()), np.array([2.9957322673241604]))

# Exponential too constraining
with pytest.warns(UserWarning, match="instead of the requested 95%"):
pm.find_optim_prior(pm.Exponential, lower=0.1, upper=1, mass=MASS, init_guess={"lam": 1})

# Gamma too constraining
with pytest.warns(UserWarning, match="instead of the requested 95%"):
pm.find_optim_prior(
pm.Gamma,
lower=0.1,
Expand All @@ -186,16 +211,18 @@ def test_find_optim_prior():
fixed_params={"beta": 10},
)

# missing param
with pytest.raises(TypeError, match="required positional argument"):
pm.find_optim_prior(
pm.StudentT, lower=0.1, upper=0.4, mass=MASS, init_guess={"mu": 170, "sigma": 3}
)

with pytest.raises(NotImplementedError, match="This function can only optimize two parameters"):
# non-scalar params
with pytest.raises(NotImplementedError, match="does not work with non-scalar parameters yet"):
pm.find_optim_prior(
pm.StudentT,
lower=0.1,
upper=0.4,
pm.MvNormal,
lower=0,
upper=1,
mass=MASS,
init_guess={"mu": 170, "sigma": 3, "nu": 7},
init_guess={"mu": 5, "cov": np.asarray([[1, 0.2], [0.2, 1]])},
)