Skip to content

Commit 8eb87ca

Browse files
committed
wip
1 parent 5388c8c commit 8eb87ca

File tree

10 files changed

+10
-38
lines changed

10 files changed

+10
-38
lines changed

.circleci/common.sh

-5
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,6 @@ function install_post_deps_pytorch_xla() {
110110
# Install dependencies after we built torch_xla. This is due to installing
111111
# those packages can potentially trigger `pip install torch_xla` if torch_xla
112112
# is not detected in the system.
113-
114-
# Install JAX dependency since a few tests depend on it.
115-
pip install 'torch_xla[pallas]' \
116-
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
117-
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
118113
}
119114

120115
function build_torch_xla() {

.github/workflows/_test.yml

+1-4
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,7 @@ jobs:
128128
set -x
129129
130130
pip install expecttest unittest-xml-reporting
131-
pip install torch_xla[pallas] \
132-
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
133-
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
134-
131+
135132
if [[ ! -z "$RUN_BENCHMARK_TESTS" ]]; then
136133
pip install -r pytorch/xla/benchmarks/requirements.txt
137134
fi

.github/workflows/_tpu_ci.yml

+1-2
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ jobs:
2323
pip install --upgrade pip
2424
pip install fsspec
2525
pip install rich
26-
# Jax nightly is needed for pallas tests.
27-
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
26+
# libtpu is needed for pallas tests.
2827
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html
2928
pip install --upgrade protobuf
3029
- name: Run Tests

CONTRIBUTING.md

-4
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,6 @@ using either VS Code or a local container:
8787
pip install torch_xla[tpu] \
8888
-f https://storage.googleapis.com/libtpu-wheels/index.html \
8989
-f https://storage.googleapis.com/libtpu-releases/index.html
90-
# Optional: if you're using custom kernels, install pallas dependencies
91-
pip install torch_xla[pallas] \
92-
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
93-
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
9490
```
9591

9692
* If you are running on a TPU VM, ensure `torch` and `torch_xla` were built and

README.md

-10
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,6 @@ To install PyTorch/XLA stable build in a new TPU VM:
2929
pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
3030
-f https://storage.googleapis.com/libtpu-releases/index.html \
3131
-f https://storage.googleapis.com/libtpu-wheels/index.html
32-
33-
# Optional: if you're using custom kernels, install pallas dependencies
34-
pip install 'torch_xla[pallas]' \
35-
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
36-
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
3732
```
3833

3934
To install PyTorch/XLA nightly build in a new TPU VM:
@@ -43,11 +38,6 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
4338
pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev-cp310-cp310-linux_x86_64.whl' \
4439
-f https://storage.googleapis.com/libtpu-releases/index.html \
4540
-f https://storage.googleapis.com/libtpu-wheels/index.html
46-
47-
# Optional: if you're using custom kernels, install pallas dependencies
48-
pip install 'torch_xla[pallas]' \
49-
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
50-
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
5141
```
5242

5343
### C++11 ABI builds

docs/source/features/pallas.md

+2-6
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,5 @@ for effective memory management with KV cache.
104104
## Dependencies
105105

106106
The Pallas integration depends on JAX to function. However, not every
107-
JAX version is compatible with your installed PyTorch/XLA. To install
108-
the proper JAX:
109-
110-
``` bash
111-
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
112-
```
107+
JAX version is compatible with your installed PyTorch/XLA. PyTorch/XLA therefore
108+
pins specific versions of JAX in its setup script.

scripts/build_developer.sh

-5
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,5 @@ pip install torch_xla[tpu] \
4040
-f https://storage.googleapis.com/libtpu-wheels/index.html \
4141
-f https://storage.googleapis.com/libtpu-releases/index.html
4242

43-
# Install Pallas dependencies
44-
pip install torch_xla[pallas] \
45-
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
46-
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
47-
4843
# Test that the library is installed correctly.
4944
python3 -c 'import torch_xla; print(torch_xla.devices()); import torchax; torchax.enable_globally()'

setup.py

+5
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,9 @@ def run(self):
369369
# importlib.metadata backport required for PJRT plugin discovery prior
370370
# to Python 3.10
371371
'importlib_metadata>=4.6;python_version<"3.10"',
372+
# Some torch operations are lowered to HLO via JAX.
373+
f'jaxlib=={_jaxlib_version}',
374+
f'jax=={_jax_version}',
372375
],
373376
package_data={
374377
'torch_xla': ['lib/*.so*',],
@@ -390,6 +393,8 @@ def run(self):
390393
f'libtpu=={_libtpu_version}',
391394
'tpu-info',
392395
],
396+
# As of https://github.com/pytorch/xla/pull/8895, jax is always a dependency of torch_xla.
397+
# However, this no-op extras_require entrypoint is left here for backwards compatibility.
393398
# pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
394399
'pallas': [f'jaxlib=={_jaxlib_version}', f'jax=={_jax_version}'],
395400
},

test/tpu/xla_test_job.yaml

-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ spec:
4343
- |
4444
pip install expecttest==0.1.6
4545
pip install rich
46-
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
4746
4847
cd /src/pytorch/xla
4948
volumeMounts:

torchax/pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ path = "torchax/__init__.py"
4141

4242
[project.optional-dependencies]
4343
cpu = ["jax[cpu]>=0.4.30", "jax[cpu]"]
44-
# Add libtpu index `-f https://storage.googleapis.com/libtpu-releases/index.html`
44+
# Add libtpu index `-f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html`
4545
tpu = ["jax[cpu]>=0.4.30", "jax[tpu]"]
4646
cuda = ["jax[cpu]>=0.4.30", "jax[cuda12]"]
4747
odml = ["jax[cpu]>=0.4.30", "jax[cpu]"]

0 commit comments

Comments
 (0)