Skip to content

Commit 78e71f8

Browse files
committed
Revert JAX and make torch explicit too
1 parent dd5772f commit 78e71f8

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

pixi.lock

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+14-3
Original file line numberDiff line numberDiff line change
@@ -154,23 +154,34 @@ jax = "*"
154154
[tool.pixi.feature.backends.target.win-64.dependencies]
155155
# jax = "*" # unavailable
156156

157-
# Backends that require a GPU host and a CUDA driver
158-
# Note: JAX and PyTorch automatically install CUDA variants
159-
# thanks to the `system-requirements` below.
157+
# Backends that require a GPU host and a CUDA driver.
158+
# Note that JAX and PyTorch automatically prefer CUDA variants
159+
# thanks to the `system-requirements` below, *if available*.
160+
# We request them explicitly below to we ensure that we don't
161+
# quietly revert to CPU-only in the future, e.g. when CUDA 13
162+
# is released and CUDA 12 builds are dropped upstream.
160163
[tool.pixi.feature.cuda-backends]
161164
system-requirements = { cuda = "12" }
162165

163166
[tool.pixi.feature.cuda-backends.target.linux-64.dependencies]
164167
cupy = "*"
168+
jaxlib = { version = "*", build = "cuda12*" }
169+
pytorch = { version = "*", build = "cuda12*" }
165170

166171
[tool.pixi.feature.cuda-backends.target.osx-64.dependencies]
167172
# cupy = "*" # unavailable
173+
# jaxlib = { version = "*", build = "cuda12*" } # unavailable
174+
# pytorch = { version = "*", build = "cuda12*" } # unavailable
168175

169176
[tool.pixi.feature.cuda-backends.target.osx-arm64.dependencies]
170177
# cupy = "*" # unavailable
178+
# jaxlib = { version = "*", build = "cuda12*" } # unavailable
179+
# pytorch = { version = "*", build = "cuda12*" } # unavailable
171180

172181
[tool.pixi.feature.cuda-backends.target.win-64.dependencies]
173182
cupy = "*"
183+
# jaxlib = { version = "*", build = "cuda12*" } # unavailable
184+
pytorch = { version = "*", build = "cuda12*" }
174185

175186
[tool.pixi.environments]
176187
default = { features = ["py313"], solve-group = "py313" }

0 commit comments

Comments
 (0)