Skip to content

MAINT: general deps bump #176

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4,420 changes: 1,232 additions & 3,188 deletions pixi.lock

Large diffs are not rendered by default.

27 changes: 16 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ numpy = "=1.22.0"
pytorch = "*"
dask = "*"
numba = "*" # sparse dependency
llvmlite = "*" # sparse dependency

[tool.pixi.feature.backends.pypi-dependencies]
sparse = { version = ">= 0.16.0b3" }
Expand Down Expand Up @@ -166,19 +167,23 @@ cupy = "*"
# jaxlib = { version = "*", build = "cuda12*" } # unavailable

[tool.pixi.environments]
default = { solve-group = "default" }
lint = { features = ["lint"], solve-group = "default" }
tests = { features = ["tests"], solve-group = "default" }
docs = { features = ["docs"], solve-group = "default" }
dev = { features = ["lint", "tests", "docs", "dev", "backends"], solve-group = "default" }
dev-cuda = { features = ["lint", "tests", "docs", "dev", "backends", "cuda-backends"] }
dev-numpy1 = { features = ["lint", "tests", "dev", "numpy1"] }
default = { features = ["py313"], solve-group = "py313" }
lint = { features = ["py313", "lint"], solve-group = "py313" }
docs = { features = ["py313", "docs"], solve-group = "py313" }
tests = { features = ["py313", "tests"], solve-group = "py313" }
tests-py313 = { features = ["py313", "tests"], solve-group = "py313" } # alias of tests

# Some backends may pin numpy; use separate solve-group
dev = { features = ["py310", "lint", "tests", "docs", "dev", "backends"], solve-group = "backends" }
tests-backends = { features = ["py310", "tests", "backends"], solve-group = "backends" }

# CUDA not available on free github actions and on some developers' PCs
dev-cuda = { features = ["py310", "lint", "tests", "docs", "dev", "backends", "cuda-backends"], solve-group = "cuda" }
tests-cuda = { features = ["py310", "tests", "backends", "cuda-backends"], solve-group = "cuda" }

# Ungrouped environments
tests-numpy1 = ["py310", "tests", "numpy1"]
tests-py310 = ["py310", "tests"]
tests-py313 = ["py313", "tests"]
# CUDA not available on free github actions and on some developers' PCs
tests-backends = ["py310", "tests", "backends"]
tests-cuda = ["py310", "tests", "backends", "cuda-backends"]


# pytest
Expand Down
2 changes: 1 addition & 1 deletion src/array_api_extra/_lib/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def wrapper( # type: ignore[decorated-any,explicit-any]
if as_numpy:
import numpy as np

arg = cast(Array, np.asarray(arg)) # type: ignore[bad-cast] # noqa: PLW2901 # pyright: ignore[reportInvalidCast]
arg = cast(Array, np.asarray(arg)) # type: ignore[bad-cast] # noqa: PLW2901
args_list.append(arg)
assert device is not None

Expand Down
2 changes: 1 addition & 1 deletion src/array_api_extra/_lib/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
desired = desired.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]

# JAX uses `np.testing`
np.testing.assert_array_equal(actual, desired, err_msg=err_msg) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
np.testing.assert_array_equal(actual, desired, err_msg=err_msg) # pyright: ignore[reportUnknownArgumentType]


def xp_assert_close(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def test_copy_invalid():


def test_xp():
a = cast(Array, np.asarray([1, 2, 3])) # type: ignore[bad-cast] # pyright: ignore[reportInvalidCast]
a = cast(Array, np.asarray([1, 2, 3])) # type: ignore[bad-cast]
_ = at(a, 0).set(4, xp=np)
_ = at(a, 0).add(4, xp=np)
_ = at(a, 0).subtract(4, xp=np)
Expand Down
10 changes: 3 additions & 7 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from array_api_extra._lib import Backend
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
from array_api_extra._lib._utils._compat import device as get_device
from array_api_extra._lib._utils._helpers import asarrays, eager_shape, ndindex
from array_api_extra._lib._utils._helpers import eager_shape, ndindex
from array_api_extra._lib._utils._typing import Array, Device
from array_api_extra.testing import lazy_xp_function

Expand Down Expand Up @@ -193,7 +193,7 @@ def test_device(self, xp: ModuleType, device: Device):
assert get_device(y) == device

@pytest.mark.filterwarnings("ignore::RuntimeWarning") # overflows, etc.
@hypothesis.settings( # pyright: ignore[reportArgumentType]
@hypothesis.settings(
# The xp and library fixtures are not regenerated between hypothesis iterations
suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture],
# JAX can take a long time to initialize on the first call
Expand Down Expand Up @@ -262,11 +262,7 @@ def f2(*args: Array) -> Array:

ref1 = xp.where(cond, f1(*arrays), fill_value)
ref2 = xp.where(cond, f1(*arrays), f2(*arrays))
if library is Backend.ARRAY_API_STRICT:
# FIXME https://github.com/data-apis/array-api-strict/issues/131
ref3 = xp.where(cond, *asarrays(f1(*arrays), float_fill_value, xp=xp))
else:
ref3 = xp.where(cond, f1(*arrays), float_fill_value)
ref3 = xp.where(cond, f1(*arrays), float_fill_value)

xp_assert_close(res1, ref1, rtol=2e-16)
xp_assert_equal(res2, ref2)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def f(x: Array) -> Array:
xp = array_namespace(x)
return xp.sum(x, axis=0) + x

x_np = cast(Array, np.arange(15).reshape(5, 3)) # type: ignore[bad-cast] # pyright: ignore[reportInvalidCast]
x_np = cast(Array, np.arange(15).reshape(5, 3)) # type: ignore[bad-cast]
expect = da.asarray(f(x_np))
x_da = da.asarray(x_np).rechunk(3)

Expand Down Expand Up @@ -419,6 +419,6 @@ def f(x: Array) -> Array:
with pytest.raises(ValueError, match="multiple shapes but only one dtype"):
_ = lazy_apply(f, x, shape=[(1,), (2,)], dtype=np.int32) # type: ignore[call-overload] # pyright: ignore[reportCallIssue,reportArgumentType]
with pytest.raises(ValueError, match="single shape but multiple dtypes"):
_ = lazy_apply(f, x, shape=(1,), dtype=[np.int32, np.int64])
_ = lazy_apply(f, x, shape=(1,), dtype=[np.int32, np.int64]) # pyright: ignore[reportCallIssue,reportArgumentType]
with pytest.raises(ValueError, match="2 shapes and 1 dtypes"):
_ = lazy_apply(f, x, shape=[(1,), (2,)], dtype=[np.int32]) # type: ignore[arg-type] # pyright: ignore[reportCallIssue,reportArgumentType]