Skip to content

Commit dd5772f

Browse files
committed
DEV: remove redundant CUDA pins for JAX
1 parent d28c2a2 commit dd5772f

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

Diff for: pixi.lock

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: pyproject.toml

+3-4
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ python = "~=3.13.0"
132132
numpy = "=1.22.0"
133133

134134
# Backends that can run on CPU-only hosts
135+
# Note: JAX and PyTorch will install CPU variants.
135136
[tool.pixi.feature.backends.dependencies]
136137
pytorch = "*"
137138
dask = "*"
@@ -154,24 +155,22 @@ jax = "*"
154155
# jax = "*" # unavailable
155156

156157
# Backends that require a GPU host and a CUDA driver
158+
# Note: JAX and PyTorch automatically install CUDA variants
159+
# thanks to the `system-requirements` below.
157160
[tool.pixi.feature.cuda-backends]
158161
system-requirements = { cuda = "12" }
159162

160163
[tool.pixi.feature.cuda-backends.target.linux-64.dependencies]
161164
cupy = "*"
162-
jaxlib = { version = "*", build = "cuda12*" }
163165

164166
[tool.pixi.feature.cuda-backends.target.osx-64.dependencies]
165167
# cupy = "*" # unavailable
166-
# jaxlib = { version = "*", build = "cuda12*" } # unavailable
167168

168169
[tool.pixi.feature.cuda-backends.target.osx-arm64.dependencies]
169170
# cupy = "*" # unavailable
170-
# jaxlib = { version = "*", build = "cuda12*" } # unavailable
171171

172172
[tool.pixi.feature.cuda-backends.target.win-64.dependencies]
173173
cupy = "*"
174-
# jaxlib = { version = "*", build = "cuda12*" } # unavailable
175174

176175
[tool.pixi.environments]
177176
default = { features = ["py313"], solve-group = "py313" }

0 commit comments

Comments
 (0)