From dd5772f4e6cb5186cb3133f8acf26a103525b7d2 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 26 Mar 2025 11:47:41 +0000 Subject: [PATCH 1/2] DEV: remove redundant CUDA pins for JAX --- pixi.lock | 2 +- pyproject.toml | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/pixi.lock b/pixi.lock index 4bab4907..d5c67694 100644 --- a/pixi.lock +++ b/pixi.lock @@ -5232,7 +5232,7 @@ packages: - pypi: . name: array-api-extra version: 0.7.1.dev0 - sha256: 54fed5ddb6a0d325790f34dd81b796833e3772fcb2b9a81511f26b1ebf32df25 + sha256: ccfb7d0c525ec08f6a1c18835e1a0db1d50294781451b3ae8630817fbb9431f6 requires_dist: - array-api-compat>=1.11,<2 requires_python: '>=3.10' diff --git a/pyproject.toml b/pyproject.toml index 4394b4cb..bf8a780b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -132,6 +132,7 @@ python = "~=3.13.0" numpy = "=1.22.0" # Backends that can run on CPU-only hosts +# Note: JAX and PyTorch will install CPU variants. [tool.pixi.feature.backends.dependencies] pytorch = "*" dask = "*" @@ -154,24 +155,22 @@ jax = "*" # jax = "*" # unavailable # Backends that require a GPU host and a CUDA driver +# Note: JAX and PyTorch automatically install CUDA variants +# thanks to the `system-requirements` below. [tool.pixi.feature.cuda-backends] system-requirements = { cuda = "12" } [tool.pixi.feature.cuda-backends.target.linux-64.dependencies] cupy = "*" -jaxlib = { version = "*", build = "cuda12*" } [tool.pixi.feature.cuda-backends.target.osx-64.dependencies] # cupy = "*" # unavailable -# jaxlib = { version = "*", build = "cuda12*" } # unavailable [tool.pixi.feature.cuda-backends.target.osx-arm64.dependencies] # cupy = "*" # unavailable -# jaxlib = { version = "*", build = "cuda12*" } # unavailable [tool.pixi.feature.cuda-backends.target.win-64.dependencies] cupy = "*" -# jaxlib = { version = "*", build = "cuda12*" } # unavailable [tool.pixi.environments] default = { features = ["py313"], solve-group = "py313" } From a1e28859854f9f8a8ccc69ab12ba7234586f4165 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 26 Mar 2025 13:40:57 +0000 Subject: [PATCH 2/2] Revert JAX and make torch explicit too --- pixi.lock | 2 +- pyproject.toml | 17 ++++++++++++++--- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/pixi.lock b/pixi.lock index d5c67694..8c0a6b3b 100644 --- a/pixi.lock +++ b/pixi.lock @@ -5232,7 +5232,7 @@ packages: - pypi: . name: array-api-extra version: 0.7.1.dev0 - sha256: ccfb7d0c525ec08f6a1c18835e1a0db1d50294781451b3ae8630817fbb9431f6 + sha256: 676a791c66366ceb58f64f5bff8010d4f3c1077846f7b9c411883b46eb55fd38 requires_dist: - array-api-compat>=1.11,<2 requires_python: '>=3.10' diff --git a/pyproject.toml b/pyproject.toml index bf8a780b..0aabf345 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -154,23 +154,34 @@ jax = "*" [tool.pixi.feature.backends.target.win-64.dependencies] # jax = "*" # unavailable -# Backends that require a GPU host and a CUDA driver -# Note: JAX and PyTorch automatically install CUDA variants -# thanks to the `system-requirements` below. +# Backends that require a GPU host and a CUDA driver. +# Note that JAX and PyTorch automatically prefer CUDA variants +# thanks to the `system-requirements` below, *if available*. +# We request them explicitly below to ensure that we don't +# quietly revert to CPU-only in the future, e.g. when CUDA 13 +# is released and CUDA 12 builds are dropped upstream. [tool.pixi.feature.cuda-backends] system-requirements = { cuda = "12" } [tool.pixi.feature.cuda-backends.target.linux-64.dependencies] cupy = "*" +jaxlib = { version = "*", build = "cuda12*" } +pytorch = { version = "*", build = "cuda12*" } [tool.pixi.feature.cuda-backends.target.osx-64.dependencies] # cupy = "*" # unavailable +# jaxlib = { version = "*", build = "cuda12*" } # unavailable +# pytorch = { version = "*", build = "cuda12*" } # unavailable [tool.pixi.feature.cuda-backends.target.osx-arm64.dependencies] # cupy = "*" # unavailable +# jaxlib = { version = "*", build = "cuda12*" } # unavailable +# pytorch = { version = "*", build = "cuda12*" } # unavailable [tool.pixi.feature.cuda-backends.target.win-64.dependencies] cupy = "*" +# jaxlib = { version = "*", build = "cuda12*" } # unavailable +pytorch = { version = "*", build = "cuda12*" } [tool.pixi.environments] default = { features = ["py313"], solve-group = "py313" }