Skip to content

Commit 7dd3a54

Browse files
crusaderkyNeilGirdhar
authored andcommitted
DEV: pin CUDA variant for PyTorch (data-apis#186)
* DEV: remove redundant CUDA pins for JAX * Revert JAX and make torch explicit too
1 parent 35d9cf1 commit 7dd3a54

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

pixi.lock

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+11-1
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ python = "~=3.13.0"
132132
numpy = "=1.22.0"
133133

134134
# Backends that can run on CPU-only hosts
135+
# Note: JAX and PyTorch will install CPU variants.
135136
[tool.pixi.feature.backends.dependencies]
136137
pytorch = "*"
137138
dask = "*"
@@ -153,25 +154,34 @@ jax = "*"
153154
[tool.pixi.feature.backends.target.win-64.dependencies]
154155
# jax = "*" # unavailable
155156

156-
# Backends that require a GPU host and a CUDA driver
157+
# Backends that require a GPU host and a CUDA driver.
158+
# Note that JAX and PyTorch automatically prefer CUDA variants
159+
# thanks to the `system-requirements` below, *if available*.
160+
# We request them explicitly below to ensure that we don't
161+
# quietly revert to CPU-only in the future, e.g. when CUDA 13
162+
# is released and CUDA 12 builds are dropped upstream.
157163
[tool.pixi.feature.cuda-backends]
158164
system-requirements = { cuda = "12" }
159165

160166
[tool.pixi.feature.cuda-backends.target.linux-64.dependencies]
161167
cupy = "*"
162168
jaxlib = { version = "*", build = "cuda12*" }
169+
pytorch = { version = "*", build = "cuda12*" }
163170

164171
[tool.pixi.feature.cuda-backends.target.osx-64.dependencies]
165172
# cupy = "*" # unavailable
166173
# jaxlib = { version = "*", build = "cuda12*" } # unavailable
174+
# pytorch = { version = "*", build = "cuda12*" } # unavailable
167175

168176
[tool.pixi.feature.cuda-backends.target.osx-arm64.dependencies]
169177
# cupy = "*" # unavailable
170178
# jaxlib = { version = "*", build = "cuda12*" } # unavailable
179+
# pytorch = { version = "*", build = "cuda12*" } # unavailable
171180

172181
[tool.pixi.feature.cuda-backends.target.win-64.dependencies]
173182
cupy = "*"
174183
# jaxlib = { version = "*", build = "cuda12*" } # unavailable
184+
pytorch = { version = "*", build = "cuda12*" }
175185

176186
[tool.pixi.environments]
177187
default = { features = ["py313"], solve-group = "py313" }

0 commit comments

Comments
 (0)