Skip to content

Apply ruff/flake8-simplify rules (SIM) #9727

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion xarray/backends/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def _find_absolute_paths(
['common.py']
"""
if isinstance(paths, str):
if is_remote_uri(paths) and kwargs.get("engine", None) == "zarr":
if is_remote_uri(paths) and kwargs.get("engine") == "zarr":
try:
from fsspec.core import get_fs_token_paths
except ImportError as e:
Expand Down
2 changes: 1 addition & 1 deletion xarray/backends/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def sort_backends(
) -> dict[str, type[BackendEntrypoint]]:
ordered_backends_entrypoints = {}
for be_name in STANDARD_BACKENDS_ORDER:
if backend_entrypoints.get(be_name, None) is not None:
if backend_entrypoints.get(be_name) is not None:
ordered_backends_entrypoints[be_name] = backend_entrypoints.pop(be_name)
ordered_backends_entrypoints.update(
{name: backend_entrypoints[name] for name in sorted(backend_entrypoints)}
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,7 +1108,7 @@ def create_coords_with_default_indexes(

# extract and merge coordinates and indexes from input DataArrays
if dataarray_coords:
prioritized = {k: (v, indexes.get(k, None)) for k, v in variables.items()}
prioritized = {k: (v, indexes.get(k)) for k, v in variables.items()}
variables, indexes = merge_coordinates_without_align(
dataarray_coords + [new_coords],
prioritized=prioritized,
Expand Down
10 changes: 3 additions & 7 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1727,7 +1727,7 @@ def _setitem_check(self, key, value):
new_value[name] = duck_array_ops.astype(val, dtype=var_k.dtype, copy=False)

# check consistency of dimension sizes and dimension coordinates
if isinstance(value, DataArray) or isinstance(value, Dataset):
if isinstance(value, DataArray | Dataset):
align(self[key], value, join="exact", copy=False)

return new_value
Expand Down Expand Up @@ -7002,7 +7002,7 @@ def reduce(
math_scores (student) float64 24B 91.0 82.5 96.5
english_scores (student) float64 24B 91.0 80.5 94.5
"""
if kwargs.get("axis", None) is not None:
if kwargs.get("axis") is not None:
raise ValueError(
"passing 'axis' to Dataset reduce methods is ambiguous."
" Please use 'dim' instead."
Expand Down Expand Up @@ -10036,11 +10036,7 @@ def curvefit(
else:
reduce_dims_ = list(reduce_dims)

if (
isinstance(coords, str)
or isinstance(coords, DataArray)
or not isinstance(coords, Iterable)
):
if isinstance(coords, str | DataArray) or not isinstance(coords, Iterable):
coords = [coords]
coords_: Sequence[DataArray] = [
self[coord] if isinstance(coord, str) else coord for coord in coords
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ def _get_interpolator(
# take higher dimensional data but scipy.interp1d can.
if (
method == "linear"
and not kwargs.get("fill_value", None) == "extrapolate"
and not kwargs.get("fill_value") == "extrapolate"
and not vectorizeable_only
):
kwargs.update(method=method)
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/nputils.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def __setitem__(self, key, value):

def _create_method(name, npmodule=np) -> Callable:
def f(values, axis=None, **kwargs):
dtype = kwargs.get("dtype", None)
dtype = kwargs.get("dtype")
bn_func = getattr(bn, name, None)

if (
Expand Down
8 changes: 2 additions & 6 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,7 @@ def equivalent(first: T, second: T) -> bool:
def list_equiv(first: Sequence[T], second: Sequence[T]) -> bool:
if len(first) != len(second):
return False
for f, s in zip(first, second, strict=True):
if not equivalent(f, s):
return False
return True
return all(equivalent(f, s) for f, s in zip(first, second, strict=True))


def peek_at(iterable: Iterable[T]) -> tuple[T, Iterator[T]]:
Expand Down Expand Up @@ -1073,8 +1070,7 @@ def contains_only_chunked_or_numpy(obj) -> bool:

return all(
[
isinstance(var._data, ExplicitlyIndexed)
or isinstance(var._data, np.ndarray)
isinstance(var._data, ExplicitlyIndexed | np.ndarray)
or is_chunked_array(var._data)
for var in obj._variables.values()
]
Expand Down
2 changes: 1 addition & 1 deletion xarray/plot/dataarray_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,7 +1030,7 @@ def newplotfunc(
cbar_kwargs["label"] = label_from_attrs(hueplt_norm.data)

_add_colorbar(
primitive, ax, kwargs.get("cbar_ax", None), cbar_kwargs, cmap_params
primitive, ax, kwargs.get("cbar_ax"), cbar_kwargs, cmap_params
)

if add_legend_:
Expand Down
12 changes: 6 additions & 6 deletions xarray/plot/facetgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def map_dataarray(

"""

if kwargs.get("cbar_ax", None) is not None:
if kwargs.get("cbar_ax") is not None:
raise ValueError("cbar_ax not supported by FacetGrid.")

cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(
Expand All @@ -363,7 +363,7 @@ def map_dataarray(
x=x,
y=y,
imshow=func.__name__ == "imshow",
rgb=kwargs.get("rgb", None),
rgb=kwargs.get("rgb"),
)

for d, ax in zip(self.name_dicts.flat, self.axs.flat, strict=True):
Expand Down Expand Up @@ -421,7 +421,7 @@ def map_plot1d(
# not sure how much that is used outside these tests.
self.data = self.data.copy()

if kwargs.get("cbar_ax", None) is not None:
if kwargs.get("cbar_ax") is not None:
raise ValueError("cbar_ax not supported by FacetGrid.")

if func.__name__ == "scatter":
Expand Down Expand Up @@ -537,8 +537,8 @@ def map_plot1d(
add_colorbar, add_legend = _determine_guide(
hueplt_norm,
sizeplt_norm,
kwargs.get("add_colorbar", None),
kwargs.get("add_legend", None),
kwargs.get("add_colorbar"),
kwargs.get("add_legend"),
# kwargs.get("add_guide", None),
# kwargs.get("hue_style", None),
)
Expand Down Expand Up @@ -622,7 +622,7 @@ def map_dataset(

kwargs["add_guide"] = False

if kwargs.get("markersize", None):
if kwargs.get("markersize"):
kwargs["size_mapping"] = _parse_size(
self.data[kwargs["markersize"]], kwargs.pop("size_norm", None)
)
Expand Down
4 changes: 2 additions & 2 deletions xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,7 +917,7 @@ def _process_cmap_cbar_kwargs(
# Leave user to specify cmap settings for surface plots
kwargs["cmap"] = cmap
return {
k: kwargs.get(k, None)
k: kwargs.get(k)
for k in ["vmin", "vmax", "cmap", "extend", "levels", "norm"]
}, {}

Expand Down Expand Up @@ -1828,7 +1828,7 @@ def _guess_coords_to_plot(
default_guess, available_coords, ignore_guess_kwargs, strict=False
):
if coords_to_plot.get(k, None) is None and all(
kwargs.get(ign_kw, None) is None for ign_kw in ign_kws
kwargs.get(ign_kw) is None for ign_kw in ign_kws
):
coords_to_plot[k] = dim

Expand Down
5 changes: 1 addition & 4 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,7 @@ def substring_in_axes(substring: str, ax: mpl.axes.Axes) -> bool:
Return True if a substring is found anywhere in an axes
"""
alltxt: set[str] = {t.get_text() for t in ax.findobj(mpl.text.Text)} # type: ignore[attr-defined] # mpl error?
for txt in alltxt:
if substring in txt:
return True
return False
return any(substring in txt for txt in alltxt)


def substring_not_in_axes(substring: str, ax: mpl.axes.Axes) -> bool:
Expand Down
Loading