File tree 10 files changed +10
-38
lines changed
10 files changed +10
-38
lines changed Original file line number Diff line number Diff line change @@ -110,11 +110,6 @@ function install_post_deps_pytorch_xla() {
110
110
# Install dependencies after we built torch_xla. This is due to installing
111
111
# those packages can potentially trigger `pip install torch_xla` if torch_xla
112
112
# is not detected in the system.
113
-
114
- # Install JAX dependency since a few tests depend on it.
115
- pip install ' torch_xla[pallas]' \
116
- -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
117
- -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
118
113
}
119
114
120
115
function build_torch_xla() {
Original file line number Diff line number Diff line change @@ -128,10 +128,7 @@ jobs:
128
128
set -x
129
129
130
130
pip install expecttest unittest-xml-reporting
131
- pip install torch_xla[pallas] \
132
- -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
133
- -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
134
-
131
+
135
132
if [[ ! -z "$RUN_BENCHMARK_TESTS" ]]; then
136
133
pip install -r pytorch/xla/benchmarks/requirements.txt
137
134
fi
Original file line number Diff line number Diff line change 23
23
pip install --upgrade pip
24
24
pip install fsspec
25
25
pip install rich
26
- # Jax nightly is needed for pallas tests.
27
- pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
26
+ # libtpu is needed for pallas tests.
28
27
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html
29
28
pip install --upgrade protobuf
30
29
- name : Run Tests
Original file line number Diff line number Diff line change @@ -87,10 +87,6 @@ using either VS Code or a local container:
87
87
pip install torch_xla[tpu] \
88
88
-f https://storage.googleapis.com/libtpu-wheels/index.html \
89
89
-f https://storage.googleapis.com/libtpu-releases/index.html
90
- # Optional: if you're using custom kernels, install pallas dependencies
91
- pip install torch_xla[pallas] \
92
- -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
93
- -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
94
90
```
95
91
96
92
* If you are running on a TPU VM, ensure ` torch ` and ` torch_xla ` were built and
Original file line number Diff line number Diff line change @@ -29,11 +29,6 @@ To install PyTorch/XLA stable build in a new TPU VM:
29
29
pip install torch~=2.6.0 ' torch_xla[tpu]~=2.6.0' \
30
30
-f https://storage.googleapis.com/libtpu-releases/index.html \
31
31
-f https://storage.googleapis.com/libtpu-wheels/index.html
32
-
33
- # Optional: if you're using custom kernels, install pallas dependencies
34
- pip install ' torch_xla[pallas]' \
35
- -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
36
- -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
37
32
```
38
33
39
34
To install PyTorch/XLA nightly build in a new TPU VM:
@@ -43,11 +38,6 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
43
38
pip install ' torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev-cp310-cp310-linux_x86_64.whl' \
44
39
-f https://storage.googleapis.com/libtpu-releases/index.html \
45
40
-f https://storage.googleapis.com/libtpu-wheels/index.html
46
-
47
- # Optional: if you're using custom kernels, install pallas dependencies
48
- pip install ' torch_xla[pallas]' \
49
- -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
50
- -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
51
41
```
52
42
53
43
### C++11 ABI builds
Original file line number Diff line number Diff line change @@ -104,9 +104,5 @@ for effective memory management with KV cache.
104
104
## Dependencies
105
105
106
106
The Pallas integration depends on JAX to function. However, not every
107
- JAX version is compatible with your installed PyTorch/XLA. To install
108
- the proper JAX:
109
-
110
- ``` bash
111
- pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
112
- ```
107
+ JAX version is compatible with your installed PyTorch/XLA. PyTorch/XLA therefore
108
+ pins specific versions of JAX in its setup script.
Original file line number Diff line number Diff line change @@ -40,10 +40,5 @@ pip install torch_xla[tpu] \
40
40
-f https://storage.googleapis.com/libtpu-wheels/index.html \
41
41
-f https://storage.googleapis.com/libtpu-releases/index.html
42
42
43
- # Install Pallas dependencies
44
- pip install torch_xla[pallas] \
45
- -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
46
- -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
47
-
48
43
# Test that the library is installed correctly.
49
44
python3 -c ' import torch_xla; print(torch_xla.devices()); import torchax; torchax.enable_globally()'
Original file line number Diff line number Diff line change @@ -369,6 +369,9 @@ def run(self):
369
369
# importlib.metadata backport required for PJRT plugin discovery prior
370
370
# to Python 3.10
371
371
'importlib_metadata>=4.6;python_version<"3.10"' ,
372
+ # Some torch operations are lowered to HLO via JAX.
373
+ f'jaxlib=={ _jaxlib_version } ' ,
374
+ f'jax=={ _jax_version } ' ,
372
375
],
373
376
package_data = {
374
377
'torch_xla' : ['lib/*.so*' ,],
@@ -390,6 +393,8 @@ def run(self):
390
393
f'libtpu=={ _libtpu_version } ' ,
391
394
'tpu-info' ,
392
395
],
396
+ # As of https://github.com/pytorch/xla/pull/8895, jax is always a dependency of torch_xla.
397
+ # However, this no-op extras_require entrypoint is left here for backwards compatibility.
393
398
# pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
394
399
'pallas' : [f'jaxlib=={ _jaxlib_version } ' , f'jax=={ _jax_version } ' ],
395
400
},
Original file line number Diff line number Diff line change 43
43
- |
44
44
pip install expecttest==0.1.6
45
45
pip install rich
46
- pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
47
46
48
47
cd /src/pytorch/xla
49
48
volumeMounts :
Original file line number Diff line number Diff line change @@ -41,7 +41,7 @@ path = "torchax/__init__.py"
41
41
42
42
[project .optional-dependencies ]
43
43
cpu = [" jax[cpu]>=0.4.30" , " jax[cpu]" ]
44
- # Add libtpu index `-f https://storage.googleapis.com/libtpu-releases/index.html`
44
+ # Add libtpu index `-f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu- releases/index.html`
45
45
tpu = [" jax[cpu]>=0.4.30" , " jax[tpu]" ]
46
46
cuda = [" jax[cpu]>=0.4.30" , " jax[cuda12]" ]
47
47
odml = [" jax[cpu]>=0.4.30" , " jax[cpu]" ]
You can’t perform that action at this time.
0 commit comments