Skip to content

Commit e7e8a54

Browse files
committed
Move all custom Exceptions to exceptions.py
1 parent 4856e22 commit e7e8a54

29 files changed

+101
-98
lines changed

pymc/backends/base.py

-4
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,6 @@
4545
logger = logging.getLogger(__name__)
4646

4747

48-
class BackendError(Exception):
49-
pass
50-
51-
5248
class IBaseTrace(ABC, Sized):
5349
"""Minimal interface needed to record and access draws and stats for one MCMC chain."""
5450

pymc/data.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
import pymc as pm
3838

39+
from pymc.exceptions import ShapeError
3940
from pymc.pytensorf import convert_observed_data
4041

4142
__all__ = [
@@ -237,7 +238,7 @@ def determine_coords(
237238

238239
if isinstance(value, np.ndarray) and dims is not None:
239240
if len(dims) != value.ndim:
240-
raise pm.exceptions.ShapeError(
241+
raise ShapeError(
241242
"Invalid data shape. The rank of the dataset must match the " "length of `dims`.",
242243
actual=value.shape,
243244
expected=value.ndim,
@@ -445,7 +446,7 @@ def Data(
445446
if isinstance(dims, str):
446447
dims = (dims,)
447448
if not (dims is None or len(dims) == x.ndim):
448-
raise pm.exceptions.ShapeError(
449+
raise ShapeError(
449450
"Length of `dims` must match the dimensions of the dataset.",
450451
actual=len(dims),
451452
expected=x.ndim,

pymc/exceptions.py

+52-7
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
__all__ = [
16-
"SamplingError",
17-
"ImputationWarning",
18-
"ShapeWarning",
19-
"ShapeError",
20-
]
21-
2215

2316
class SamplingError(RuntimeError):
2417
pass
@@ -74,3 +67,55 @@ class NotConstantValueError(ValueError):
7467

7568
class BlockModelAccessError(RuntimeError):
7669
pass
70+
71+
72+
class ParallelSamplingError(Exception):
73+
def __init__(self, message, chain):
74+
super().__init__(message)
75+
self._chain = chain
76+
77+
78+
class RemoteTraceback(Exception):
79+
def __init__(self, tb):
80+
self.tb = tb
81+
82+
def __str__(self):
83+
return self.tb
84+
85+
86+
class VariationalInferenceError(Exception):
87+
"""Exception for VI specific cases"""
88+
89+
90+
class NotImplementedInference(VariationalInferenceError, NotImplementedError):
91+
"""Marking non functional parts of code"""
92+
93+
94+
class ExplicitInferenceError(VariationalInferenceError, TypeError):
95+
"""Exception for bad explicit inference"""
96+
97+
98+
class ParametrizationError(VariationalInferenceError, ValueError):
99+
"""Error raised in case of bad parametrization"""
100+
101+
102+
class GroupError(VariationalInferenceError, TypeError):
103+
"""Error related to VI groups"""
104+
105+
106+
class IntegrationError(RuntimeError):
107+
pass
108+
109+
110+
class PositiveDefiniteError(ValueError):
111+
def __init__(self, msg, idx):
112+
super().__init__(msg)
113+
self.idx = idx
114+
self.msg = msg
115+
116+
def __str__(self):
117+
return f"Scaling is not positive definite: {self.msg}. Check indexes {self.idx}."
118+
119+
120+
class ParameterValueError(ValueError):
121+
"""Exception for invalid parameters values in logprob graphs"""

pymc/logprob/utils.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
from pytensor.tensor.random.op import RandomVariable
6464
from pytensor.tensor.variable import TensorVariable
6565

66+
from pymc.exceptions import ParameterValueError
6667
from pymc.logprob.abstract import MeasurableVariable, _logprob
6768
from pymc.util import makeiter
6869

@@ -231,10 +232,6 @@ def check_potential_measurability(
231232
return False
232233

233234

234-
class ParameterValueError(ValueError):
235-
"""Exception for invalid parameters values in logprob graphs"""
236-
237-
238235
class CheckParameterValue(CheckAndRaise):
239236
"""Implements a parameter value check in a logprob graph.
240237

pymc/model/core.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,13 @@
5959
from pymc.exceptions import (
6060
BlockModelAccessError,
6161
ImputationWarning,
62+
ParameterValueError,
6263
SamplingError,
6364
ShapeError,
6465
ShapeWarning,
6566
)
6667
from pymc.initial_point import make_initial_point_fn
6768
from pymc.logprob.basic import transformed_conditional_logp
68-
from pymc.logprob.utils import ParameterValueError
6969
from pymc.model_graph import model_to_graphviz
7070
from pymc.pytensorf import (
7171
PointFunc,

pymc/sampling/mcmc.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
)
5757
from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains
5858
from pymc.blocking import DictToArrayBijection
59-
from pymc.exceptions import SamplingError
59+
from pymc.exceptions import ParallelSamplingError, SamplingError
6060
from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain
6161
from pymc.model import Model, modelcontext
6262
from pymc.sampling.parallel import Draw, _cpu_count
@@ -1199,7 +1199,7 @@ def _mp_sample(
11991199
if callback is not None:
12001200
callback(trace=strace, draw=draw)
12011201

1202-
except ps.ParallelSamplingError as error:
1202+
except ParallelSamplingError as error:
12031203
strace = traces[error._chain]
12041204
for strace in traces:
12051205
strace.close()

pymc/sampling/parallel.py

+1-13
Original file line numberDiff line numberDiff line change
@@ -29,25 +29,13 @@
2929
from fastprogress.fastprogress import progress_bar
3030

3131
from pymc.blocking import DictToArrayBijection
32-
from pymc.exceptions import SamplingError
32+
from pymc.exceptions import ParallelSamplingError, RemoteTraceback, SamplingError
3333
from pymc.util import RandomSeed
3434

3535
logger = logging.getLogger(__name__)
3636

3737

38-
class ParallelSamplingError(Exception):
39-
def __init__(self, message, chain):
40-
super().__init__(message)
41-
self._chain = chain
42-
43-
4438
# Taken from https://hg.python.org/cpython/rev/c4f92b597074
45-
class RemoteTraceback(Exception):
46-
def __init__(self, tb):
47-
self.tb = tb
48-
49-
def __str__(self):
50-
return self.tb
5139

5240

5341
class ExceptionWithTraceback:

pymc/step_methods/hmc/base_hmc.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@
2323
import numpy as np
2424

2525
from pymc.blocking import DictToArrayBijection, RaveledVars, StatsType
26-
from pymc.exceptions import SamplingError
26+
from pymc.exceptions import IntegrationError, SamplingError
2727
from pymc.model import Point, modelcontext
2828
from pymc.pytensorf import floatX
2929
from pymc.stats.convergence import SamplerWarning, WarningType
3030
from pymc.step_methods import step_sizes
3131
from pymc.step_methods.arraystep import GradientSharedStep
3232
from pymc.step_methods.hmc import integration
33-
from pymc.step_methods.hmc.integration import IntegrationError, State
33+
from pymc.step_methods.hmc.integration import State
3434
from pymc.step_methods.hmc.quadpotential import QuadPotentialDiagAdapt, quad_potential
3535
from pymc.tuning import guess_scaling
3636
from pymc.util import get_value_vars_from_user_vars

pymc/step_methods/hmc/hmc.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@
1818

1919
import numpy as np
2020

21+
from pymc.exceptions import IntegrationError
2122
from pymc.stats.convergence import SamplerWarning
2223
from pymc.step_methods.compound import Competence
2324
from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
24-
from pymc.step_methods.hmc.integration import IntegrationError, State
25+
from pymc.step_methods.hmc.integration import State
2526
from pymc.vartypes import discrete_types
2627

2728
__all__ = ["HamiltonianMC"]

pymc/step_methods/hmc/integration.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from scipy import linalg
2020

2121
from pymc.blocking import RaveledVars
22+
from pymc.exceptions import IntegrationError
2223
from pymc.step_methods.hmc.quadpotential import QuadPotential
2324

2425

@@ -32,10 +33,6 @@ class State(NamedTuple):
3233
index_in_trajectory: int
3334

3435

35-
class IntegrationError(RuntimeError):
36-
pass
37-
38-
3936
class CpuLeapfrogIntegrator:
4037
def __init__(self, potential: QuadPotential, logp_dlogp_func):
4138
"""Leapfrog integrator using CPU."""

pymc/step_methods/hmc/nuts.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@
1818

1919
import numpy as np
2020

21+
from pymc.exceptions import IntegrationError
2122
from pymc.math import logbern
2223
from pymc.pytensorf import floatX
2324
from pymc.stats.convergence import SamplerWarning
2425
from pymc.step_methods.compound import Competence
2526
from pymc.step_methods.hmc import integration
2627
from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
27-
from pymc.step_methods.hmc.integration import IntegrationError, State
28+
from pymc.step_methods.hmc.integration import State
2829
from pymc.vartypes import continuous_types
2930

3031
__all__ = ["NUTS"]

pymc/step_methods/hmc/quadpotential.py

+1-10
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from numpy.random import normal
2626
from scipy.sparse import issparse
2727

28+
from pymc.exceptions import PositiveDefiniteError
2829
from pymc.pytensorf import floatX
2930

3031
__all__ = [
@@ -87,16 +88,6 @@ def partial_check_positive_definite(C):
8788
raise PositiveDefiniteError("Simple check failed. Diagonal contains negatives", i)
8889

8990

90-
class PositiveDefiniteError(ValueError):
91-
def __init__(self, msg, idx):
92-
super().__init__(msg)
93-
self.idx = idx
94-
self.msg = msg
95-
96-
def __str__(self):
97-
return f"Scaling is not positive definite: {self.msg}. Check indexes {self.idx}."
98-
99-
10091
class QuadPotential:
10192
dtype: np.dtype
10293

pymc/testing.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,10 @@
3434

3535
from pymc.distributions.distribution import Distribution
3636
from pymc.distributions.shape_utils import change_dist_size
37+
from pymc.exceptions import ParameterValueError
3738
from pymc.initial_point import make_initial_point_fn
3839
from pymc.logprob.basic import icdf, logcdf, logp, transformed_conditional_logp
39-
from pymc.logprob.utils import ParameterValueError, find_rvs_in_graph
40+
from pymc.logprob.utils import find_rvs_in_graph
4041
from pymc.pytensorf import (
4142
compile_pymc,
4243
floatX,

pymc/variational/approximations.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,11 @@
2626

2727
from pymc.blocking import DictToArrayBijection
2828
from pymc.distributions.dist_math import rho2sigma
29+
from pymc.exceptions import NotImplementedInference
2930
from pymc.util import makeiter
30-
from pymc.variational import opvi
3131
from pymc.variational.opvi import (
3232
Approximation,
3333
Group,
34-
NotImplementedInference,
3534
_known_scan_ignored_inputs,
3635
node_property,
3736
)
@@ -212,7 +211,7 @@ def __init_group__(self, group):
212211
def create_shared_params(self, trace=None, size=None, jitter=1, start=None):
213212
if trace is None:
214213
if size is None:
215-
raise opvi.ParametrizationError("Need `trace` or `size` to initialize")
214+
raise pymc.exceptions.ParametrizationError("Need `trace` or `size` to initialize")
216215
else:
217216
start = self._prepare_start(start)
218217
# Initialize particles

pymc/variational/operators.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919

2020
import pymc as pm
2121

22+
from pymc.exceptions import NotImplementedInference, ParametrizationError
2223
from pymc.variational import opvi
2324
from pymc.variational.opvi import (
24-
NotImplementedInference,
2525
ObjectiveFunction,
2626
Operator,
2727
_known_scan_ignored_inputs,
@@ -81,7 +81,7 @@ class KSDObjective(ObjectiveFunction):
8181

8282
def __init__(self, op: KSD, tf: opvi.TestFunction):
8383
if not isinstance(op, KSD):
84-
raise opvi.ParametrizationError("Op should be KSD")
84+
raise ParametrizationError("Op should be KSD")
8585
super().__init__(op, tf)
8686

8787
@pytensor.config.change_flags(compute_test_value="off")

pymc/variational/opvi.py

+6-24
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@
6868
from pymc.backends.base import MultiTrace
6969
from pymc.backends.ndarray import NDArray
7070
from pymc.blocking import DictToArrayBijection
71+
from pymc.exceptions import (
72+
ExplicitInferenceError,
73+
GroupError,
74+
ParametrizationError,
75+
VariationalInferenceError,
76+
)
7177
from pymc.initial_point import make_initial_point_fn
7278
from pymc.model import modelcontext
7379
from pymc.pytensorf import (
@@ -91,30 +97,6 @@
9197
__all__ = ["ObjectiveFunction", "Operator", "TestFunction", "Group", "Approximation"]
9298

9399

94-
class VariationalInferenceError(Exception):
95-
"""Exception for VI specific cases"""
96-
97-
98-
class NotImplementedInference(VariationalInferenceError, NotImplementedError):
99-
"""Marking non functional parts of code"""
100-
101-
102-
class ExplicitInferenceError(VariationalInferenceError, TypeError):
103-
"""Exception for bad explicit inference"""
104-
105-
106-
class AEVBInferenceError(VariationalInferenceError, TypeError):
107-
"""Exception for bad aevb inference"""
108-
109-
110-
class ParametrizationError(VariationalInferenceError, ValueError):
111-
"""Error raised in case of bad parametrization"""
112-
113-
114-
class GroupError(VariationalInferenceError, TypeError):
115-
"""Error related to VI groups"""
116-
117-
118100
def _known_scan_ignored_inputs(terms):
119101
# TODO: remove when scan issue with grads is fixed
120102
from pymc.data import MinibatchIndexRV

tests/distributions/test_continuous.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@
2929

3030
from pymc.distributions.continuous import Normal, Uniform, get_tau_sigma, interpolated
3131
from pymc.distributions.dist_math import clipped_beta_rvs
32+
from pymc.exceptions import ParameterValueError
3233
from pymc.logprob.basic import icdf, logcdf, logp
33-
from pymc.logprob.utils import ParameterValueError
3434
from pymc.pytensorf import floatX
3535
from pymc.testing import (
3636
BaseTestDistributionRandom,

tests/distributions/test_discrete.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@
2929
import pymc as pm
3030

3131
from pymc.distributions.discrete import Geometric, _OrderedLogistic, _OrderedProbit
32+
from pymc.exceptions import ParameterValueError
3233
from pymc.logprob.basic import icdf, logcdf, logp
33-
from pymc.logprob.utils import ParameterValueError
3434
from pymc.pytensorf import floatX
3535
from pymc.testing import (
3636
BaseTestDistributionRandom,

0 commit comments

Comments
 (0)