Commit b9a3dcc

Fix some code lints found with ruff (#6545)
* Make isort happy with imports order
* Avoid shadowing the logprob fn with a local name
* Deal with unused error names
* Remove unreachable return
* Remove unread variable
* Fix type checking for variational/approximations
* A few typing fixes
1 parent e1d36ca commit b9a3dcc

File tree: 12 files changed (+47, -54 lines)


pymc/distributions/dist_math.py

Lines changed: 2 additions & 2 deletions
@@ -61,10 +61,10 @@ def check_parameters(logp: Variable, *conditions: Iterable[Variable], msg: str =
         check_bounds = False in pm.Model()
     """
     # at.all does not accept True/False, but accepts np.array(True)/np.array(False)
-    conditions = [
+    conditions_ = [
         cond if (cond is not True and cond is not False) else np.array(cond) for cond in conditions
     ]
-    all_true_scalar = at.all([at.all(cond) for cond in conditions])
+    all_true_scalar = at.all([at.all(cond) for cond in conditions_])
     return CheckParameterValue(msg)(logp, all_true_scalar)

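The rename above is the usual fix when a function re-binds one of its own parameters to a value of a different type, which static checkers flag. A minimal sketch of the pattern, with invented names (not the PyMC code):

from typing import Iterable

def all_truthy(flags: Iterable[object]) -> bool:
    # Re-binding `flags` here would change its type from Iterable to list;
    # a fresh name keeps both types stable for the checker and the reader.
    flags_ = [bool(f) for f in flags]
    return all(flags_)

print(all_truthy([1, "yes", 0]))  # False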
pymc/distributions/distribution.py

Lines changed: 13 additions & 15 deletions
@@ -988,21 +988,19 @@ def __new__(
             ndim_supp=ndim_supp,
             **kwargs,
         )
-        else:
-            return _CustomDist(
-                name,
-                *dist_params,
-                class_name=name,
-                random=random,
-                logp=logp,
-                logcdf=logcdf,
-                moment=moment,
-                ndim_supp=ndim_supp,
-                ndims_params=ndims_params,
-                dtype=dtype,
-                **kwargs,
-            )
-        return super().__new__(cls, name, *args, **kwargs)
+        return _CustomDist(
+            name,
+            *dist_params,
+            class_name=name,
+            random=random,
+            logp=logp,
+            logcdf=logcdf,
+            moment=moment,
+            ndim_supp=ndim_supp,
+            ndims_params=ndims_params,
+            dtype=dtype,
+            **kwargs,
+        )

     @classmethod
     def dist(
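This is the "remove unreachable return" fix from the commit message: once both branches of the `if`/`else` return, anything after the conditional can never run, so the `else` is dropped and the dead tail deleted. A hedged before/after sketch with toy names:

def dispatch_old(use_symbolic: bool) -> str:
    if use_symbolic:
        return "symbolic"
    else:
        return "custom"
    return "fallback"  # unreachable: both branches above already returned

def dispatch_new(use_symbolic: bool) -> str:
    if use_symbolic:
        return "symbolic"
    return "custom"  # `else` dropped, dead tail removed; behavior unchanged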

pymc/gp/util.py

Lines changed: 1 addition & 1 deletion
@@ -28,10 +28,10 @@

 # Avoid circular dependency when importing modelcontext
 from pymc.distributions.distribution import Distribution
+from pymc.model import modelcontext
 from pymc.pytensorf import compile_pymc, walk_model

 _ = Distribution  # keep both pylint and black happy
-from pymc.model import modelcontext

 JITTER_DEFAULT = 1e-6
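isort sorts imports into sections (stdlib, then third-party, then first-party) and alphabetizes within each, so the deferred `modelcontext` import moves up into the `pymc` block even though it was originally placed lower to sidestep a circular import. Roughly the layout isort expects, sketched with illustrative modules:

import os  # stdlib section first

import numpy as np  # third-party section next

from pymc.model import modelcontext  # first-party section, alphabetized
from pymc.pytensorf import compile_pymc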

pymc/logprob/transforms.py

Lines changed: 3 additions & 3 deletions
@@ -961,16 +961,16 @@ def transformed_logprob(op, values, *inputs, use_jacobian=True, **kwargs):
     if use_jacobian:
         assert len(values) == len(logprobs) == len(op.transforms)
         logprobs_jac = []
-        for value, transform, logprob in zip(values, op.transforms, logprobs):
+        for value, transform, logp in zip(values, op.transforms, logprobs):
             if transform is None:
-                logprobs_jac.append(logprob)
+                logprobs_jac.append(logp)
                 continue
             assert isinstance(value.owner.op, TransformedVariable)
             original_forward_value = value.owner.inputs[1]
             jacobian = transform.log_jac_det(original_forward_value, *inputs).copy()
             if value.name:
                 jacobian.name = f"{value.name}_jacobian"
-            logprobs_jac.append(logprob + jacobian)
+            logprobs_jac.append(logp + jacobian)
         logprobs = logprobs_jac

     return logprobs
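The loop-variable rename avoids shadowing the module-level `logprob` function for the rest of the scope, the hazard the commit message calls out. A self-contained sketch of why that matters, using a toy function rather than the PyMC one:

def logprob(x: float) -> float:  # stand-in for the real logprob function
    return -0.5 * x**2

def total_bad(values):
    for logprob in values:  # re-binds the name: the function is hidden here
        pass
    return logprob  # the last float from `values`, not the function!

def total_good(values):
    return sum(logprob(v) for v in values)  # `logprob` still names the function

print(total_good([1.0, 2.0]))  # -2.5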

pymc/model.py

Lines changed: 1 addition & 1 deletion
@@ -195,7 +195,7 @@ def get_context(cls, error_if_none=True) -> Optional[T]:
         raise a ``TypeError`` instead of returning ``None``."""
         try:
             candidate: Optional[T] = cls.get_contexts()[-1]
-        except IndexError as e:
+        except IndexError:
             # Calling code expects to get a TypeError if the entity
             # is unfound, and there's too much to fix.
             if error_if_none:
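Dropping the unused `as e` binding is what ruff's unused-variable checks ask for when the exception object is never read; the binding is kept only where it is actually used, for example to chain with `raise ... from err` as in the integration.py change below. A minimal sketch:

items = []
try:
    last = items[-1]
except IndexError:  # no `as e` needed: the exception object is never read
    last = None
print(last)  # None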

pymc/model_graph.py

Lines changed: 3 additions & 3 deletions
@@ -73,7 +73,7 @@ def _expand(x):
             return []

         parents = {
-            get_var_name(x)
+            VarName(get_var_name(x))
             for x in walk(nodes=_filter_non_parameter_inputs(var), expand=_expand)
             # Only consider nodes that are in the named model variables.
             if x.name and x.name in self._all_var_names
@@ -109,7 +109,7 @@ def vars_to_plot(self, var_names: Optional[Iterable[VarName]] = None) -> List[VarName]:
                 selected_ancestors.add(self.model.rvs_to_values[var])

         # ordering of self._all_var_names is important
-        return [var.name for var in selected_ancestors]
+        return [VarName(var.name) for var in selected_ancestors]

     def make_compute_graph(
         self, var_names: Optional[Iterable[VarName]] = None
@@ -230,7 +230,7 @@ def get_plates(self, var_names: Optional[Iterable[VarName]] = None) -> Dict[str, Set[VarName]]:
             plate_label = " x ".join(dim_labels)
         else:
             # The RV has no `dims` information.
-            dim_labels = map(str, shape)
+            dim_labels = [str(x) for x in shape]
             plate_label = " x ".join(map(str, shape))
         plates[plate_label].add(var_name)
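The last hunk swaps `map()` for a list comprehension. Besides giving `dim_labels` a consistent `List[str]` type across both branches, this avoids the one-shot-iterator trap:

shape = (3, 4)

labels_iter = map(str, shape)  # a one-shot iterator
print(" x ".join(labels_iter))  # 3 x 4
print(" x ".join(labels_iter))  # "" -- the iterator is already exhausted

labels_list = [str(x) for x in shape]  # reusable, and typed List[str]
print(" x ".join(labels_list))  # 3 x 4, every time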

pymc/sampling/jax.py

Lines changed: 12 additions & 17 deletions
@@ -15,20 +15,10 @@
 import re
 import sys

+from datetime import datetime
 from functools import partial
 from typing import Any, Callable, Dict, List, Optional, Sequence, Union

-from pytensor.tensor.random.type import RandomType
-
-from pymc.initial_point import StartDict
-from pymc.sampling.mcmc import _init_jitter
-
-xla_flags = os.getenv("XLA_FLAGS", "")
-xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
-os.environ["XLA_FLAGS"] = " ".join([f"--xla_force_host_platform_device_count={100}"] + xla_flags)
-
-from datetime import datetime
-
 import arviz as az
 import jax
 import numpy as np
@@ -43,18 +33,25 @@
 from pytensor.link.jax.dispatch import jax_funcify
 from pytensor.raise_op import Assert
 from pytensor.tensor import TensorVariable
+from pytensor.tensor.random.type import RandomType
 from pytensor.tensor.shape import SpecifyShape

 from pymc import Model, modelcontext
 from pymc.backends.arviz import find_constants, find_observations
+from pymc.initial_point import StartDict
 from pymc.logprob.utils import CheckParameterValue
+from pymc.sampling.mcmc import _init_jitter
 from pymc.util import (
     RandomSeed,
     RandomState,
     _get_seeds_per_chain,
     get_default_varnames,
 )

+xla_flags_env = os.getenv("XLA_FLAGS", "")
+xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags_env).split()
+os.environ["XLA_FLAGS"] = " ".join([f"--xla_force_host_platform_device_count={100}"] + xla_flags)
+
 __all__ = (
     "get_jaxified_graph",
     "get_jaxified_logp",
@@ -111,7 +108,7 @@ def get_jaxified_graph(
 ) -> List[TensorVariable]:
     """Compile an PyTensor graph into an optimized JAX function"""

-    graph = _replace_shared_variables(outputs)
+    graph = _replace_shared_variables(outputs) if outputs is not None else None

     fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
     # We need to add a Supervisor to the fgraph to be able to run the
@@ -254,12 +251,10 @@ def _get_batched_jittered_initial_points(
         jitter=jitter,
         jitter_max_retries=jitter_max_retries,
     )
-    initial_points = [list(initial_point.values()) for initial_point in initial_points]
+    initial_points_values = [list(initial_point.values()) for initial_point in initial_points]
     if chains == 1:
-        initial_points = initial_points[0]
-    else:
-        initial_points = [np.stack(init_state) for init_state in zip(*initial_points)]
-    return initial_points
+        return initial_points_values[0]
+    return [np.stack(init_state) for init_state in zip(*initial_points_values)]


 def _update_coords_and_dims(
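The import shuffle in the first two hunks fixes ruff's E402 (module-level statement before imports): the `XLA_FLAGS` mutation used to sit in the middle of the import block, forcing every import after it to violate import-ordering rules. The general shape of the fix, sketched with a hypothetical environment variable:

import os
import re  # every import, stdlib or third-party, stays in the top block

# module-level side effects run only after the import block is complete
flags = re.sub(r"\s+", " ", os.getenv("MY_TOOL_FLAGS", ""))  # hypothetical env var
os.environ["MY_TOOL_FLAGS"] = flags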

pymc/step_methods/hmc/integration.py

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ def step(self, epsilon, state):
             return self._step(epsilon, state)
         except linalg.LinAlgError as err:
             msg = "LinAlgError during leapfrog step."
-            raise IntegrationError(msg)
+            raise IntegrationError(msg) from err
         except ValueError as err:
             # Raised by many scipy.linalg functions
             scipy_msg = "array must not contain infs or nans"
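`raise ... from err` chains the original exception as `__cause__`, so the numerical error that triggered the failure still appears in the traceback instead of being silently replaced. A runnable sketch using a stand-in error type:

class IntegrationFailure(Exception):
    pass

def leapfrog_step() -> None:
    raise ZeroDivisionError("step size underflow")  # stand-in for LinAlgError

try:
    leapfrog_step()
except ZeroDivisionError as err:
    # the traceback now reports both exceptions, linked by __cause__
    raise IntegrationFailure("leapfrog step failed") from err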

pymc/step_methods/hmc/nuts.py

Lines changed: 0 additions & 1 deletion
@@ -210,7 +210,6 @@ def _hamiltonian_step(self, start, p0, step_size):
     def competence(var, has_grad):
         """Check how appropriate this class is for sampling a random variable."""

-        dist = getattr(var.owner, "op", None)
         if var.dtype in continuous_types and has_grad:
             return Competence.PREFERRED
         return Competence.INCOMPATIBLE

pymc/variational/approximations.py

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+from typing import Optional

 import numpy as np
 import pytensor
@@ -331,7 +331,7 @@ def sample_approx(approx, draws=100, include_transformed=True):
 class SingleGroupApproximation(Approximation):
     """Base class for Single Group Approximation"""

-    _group_class = None
+    _group_class: Optional[type] = None

     def __init__(self, *args, **kwargs):
         groups = [self._group_class(None, *args, **kwargs)]

pymc/variational/opvi.py

Lines changed: 9 additions & 7 deletions
@@ -51,6 +51,8 @@
 import itertools
 import warnings

+from typing import Any
+
 import numpy as np
 import pytensor
 import pytensor.tensor as at
@@ -673,11 +675,11 @@ class Group(WithMemoization):
     initial_dist_map = 0.0

     # for handy access using class methods
-    __param_spec__ = dict()
+    __param_spec__: dict = dict()
     short_name = ""
-    alias_names = frozenset()
-    __param_registry = dict()
-    __name_registry = dict()
+    alias_names: frozenset[str] = frozenset()
+    __param_registry: dict[frozenset, Any] = dict()
+    __name_registry: dict[str, Any] = dict()

     @classmethod
     def register(cls, sbcls):
@@ -1552,11 +1554,11 @@ def sample(
         finally:
             trace.close()

-        trace = MultiTrace([trace])
+        multi_trace = MultiTrace([trace])
         if not return_inferencedata:
-            return trace
+            return multi_trace
         else:
-            return pm.to_inference_data(trace, model=self.model, **kwargs)
+            return pm.to_inference_data(multi_trace, model=self.model, **kwargs)

     @property
     def ndim(self):
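The new annotations give mypy the intended wide types for class-level defaults: an unannotated `= dict()` or `= frozenset()` is inferred from the empty literal, and a bare `= None` would pin the attribute to type `None` forever. A hedged sketch of the same idiom with invented names:

from typing import Any, Optional

class Registry:
    _group_class: Optional[type] = None  # replaced with a real class later
    _by_name: dict[str, Any] = {}  # shared registry, filled in by register()
    aliases: frozenset[str] = frozenset()

    @classmethod
    def register(cls, name: str, obj: Any) -> None:
        cls._by_name[name] = obj

Registry.register("demo", object())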

scripts/run_mypy.py

Lines changed: 0 additions & 1 deletion
@@ -42,7 +42,6 @@
 pymc/printing.py
 pymc/pytensorf.py
 pymc/sampling/jax.py
-pymc/variational/approximations.py
 pymc/variational/opvi.py
 """
