@@ -154,23 +154,34 @@ jax = "*"
154
154
[tool .pixi .feature .backends .target .win-64 .dependencies ]
155
155
# jax = "*" # unavailable
156
156
157
- # Backends that require a GPU host and a CUDA driver
158
- # Note: JAX and PyTorch automatically install CUDA variants
159
- # thanks to the `system-requirements` below.
157
+ # Backends that require a GPU host and a CUDA driver.
158
+ # Note that JAX and PyTorch automatically prefer CUDA variants
159
+ # thanks to the `system-requirements` below, *if available*.
160
+ # We request them explicitly below to ensure that we don't
161
+ # quietly revert to CPU-only in the future, e.g. when CUDA 13
162
+ # is released and CUDA 12 builds are dropped upstream.
160
163
[tool .pixi .feature .cuda-backends ]
161
164
system-requirements = { cuda = " 12" }
162
165
163
166
[tool .pixi .feature .cuda-backends .target .linux-64 .dependencies ]
164
167
cupy = " *"
168
+ jaxlib = { version = " *" , build = " cuda12*" }
169
+ pytorch = { version = " *" , build = " cuda12*" }
165
170
166
171
[tool .pixi .feature .cuda-backends .target .osx-64 .dependencies ]
167
172
# cupy = "*" # unavailable
173
+ # jaxlib = { version = "*", build = "cuda12*" } # unavailable
174
+ # pytorch = { version = "*", build = "cuda12*" } # unavailable
168
175
169
176
[tool .pixi .feature .cuda-backends .target .osx-arm64 .dependencies ]
170
177
# cupy = "*" # unavailable
178
+ # jaxlib = { version = "*", build = "cuda12*" } # unavailable
179
+ # pytorch = { version = "*", build = "cuda12*" } # unavailable
171
180
172
181
[tool .pixi .feature .cuda-backends .target .win-64 .dependencies ]
173
182
cupy = " *"
183
+ # jaxlib = { version = "*", build = "cuda12*" } # unavailable
184
+ pytorch = { version = " *" , build = " cuda12*" }
174
185
175
186
[tool .pixi .environments ]
176
187
default = { features = [" py313" ], solve-group = " py313" }
0 commit comments