Skip to content

Commit bb5c6a0

Browse files
authored
Include torchax in torch_xla (#8895)
1 parent 6400e16 commit bb5c6a0

13 files changed

+121
-72
lines changed

.circleci/common.sh

+3-11
Original file line numberDiff line numberDiff line change
@@ -106,18 +106,10 @@ function install_pre_deps_pytorch_xla() {
106106
}
107107

108108

109+
# TODO(https://github.com/pytorch/xla/issues/8934): Remove PyTorch usage of this function, then
110+
# remove this function from the script.
109111
function install_post_deps_pytorch_xla() {
110-
# Install dependencies after we built torch_xla. This is due to installing
111-
# those packages can potentially trigger `pip install torch_xla` if torch_xla
112-
# is not detected in the system.
113-
114-
# Install JAX dependency since a few tests depend on it.
115-
pip install 'torch_xla[pallas]' \
116-
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
117-
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
118-
119-
# TODO(https://github.com/pytorch/xla/issues/8831): Remove this when torchax is part of torch_xla.
120-
pip install xla/torchax
112+
true
121113
}
122114

123115
function build_torch_xla() {

.github/workflows/_test.yml

-7
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,6 @@ jobs:
128128
set -x
129129
130130
pip install expecttest unittest-xml-reporting
131-
pip install torch_xla[pallas] \
132-
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
133-
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
134-
135-
# Install torchax
136-
# TODO(https://github.com/pytorch/xla/issues/8831): Remove this when torchax is part of torch_xla.
137-
pip install pytorch/xla/torchax
138131
139132
if [[ ! -z "$RUN_BENCHMARK_TESTS" ]]; then
140133
pip install -r pytorch/xla/benchmarks/requirements.txt

.github/workflows/_test_requiring_torch_cuda.yml

+1-3
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,7 @@ jobs:
9191
shell: bash
9292
run: |
9393
set -x
94-
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
95-
pip install -U --pre jax-cuda12-pjrt jax-cuda12-plugin -f https://storage.googleapis.com/jax-releases/jax_cuda_plugin_nightly_releases.html
96-
pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
94+
pip install -U --pre jax jaxlib "jax-cuda12-plugin[with_cuda]" jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
9795
if: ${{ matrix.run_triton_tests }}
9896
- name: Install Triton
9997
shell: bash

.github/workflows/_tpu_ci.yml

+1-6
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,9 @@ jobs:
2323
pip install --upgrade pip
2424
pip install fsspec
2525
pip install rich
26-
# Jax nightly is needed for pallas tests.
27-
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
26+
# libtpu is needed for pallas tests.
2827
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html
2928
pip install --upgrade protobuf
30-
31-
# torchax is needed for call_jax tests.
32-
# TODO(https://github.com/pytorch/xla/issues/8831): Remove this when torchax is part of torch_xla.
33-
pip install pytorch/xla/torchax
3429
- name: Run Tests
3530
env:
3631
PJRT_DEVICE: TPU

.github/workflows/build_and_test.yml

-6
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,10 @@ on:
44
branches:
55
- master
66
- r[0-9]+.[0-9]+
7-
paths-ignore:
8-
- 'experimental/**'
9-
- 'torchax/**'
107
push:
118
branches:
129
- master
1310
- r[0-9]+.[0-9]+
14-
paths-ignore:
15-
- 'experimental/**'
16-
- 'torchax/**'
1711
workflow_dispatch:
1812

1913
concurrency:

.github/workflows/torch_xla2.yml

-4
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,10 @@ on:
33
branches:
44
- master
55
- r[0-9]+.[0-9]+
6-
paths:
7-
- 'torchax/**'
86
push:
97
branches:
108
- master
119
- r[0-9]+.[0-9]+
12-
paths:
13-
- 'torchax/**'
1410
workflow_dispatch:
1511

1612
concurrency:

CONTRIBUTING.md

-4
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,6 @@ using either VS Code or a local container:
8787
pip install torch_xla[tpu] \
8888
-f https://storage.googleapis.com/libtpu-wheels/index.html \
8989
-f https://storage.googleapis.com/libtpu-releases/index.html
90-
# Optional: if you're using custom kernels, install pallas dependencies
91-
pip install torch_xla[pallas] \
92-
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
93-
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
9490
```
9591

9692
* If you are running on a TPU VM, ensure `torch` and `torch_xla` were built and

README.md

-10
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,6 @@ To install PyTorch/XLA stable build in a new TPU VM:
2929
pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
3030
-f https://storage.googleapis.com/libtpu-releases/index.html \
3131
-f https://storage.googleapis.com/libtpu-wheels/index.html
32-
33-
# Optional: if you're using custom kernels, install pallas dependencies
34-
pip install 'torch_xla[pallas]' \
35-
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
36-
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
3732
```
3833

3934
To install PyTorch/XLA nightly build in a new TPU VM:
@@ -43,11 +38,6 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
4338
pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev-cp310-cp310-linux_x86_64.whl' \
4439
-f https://storage.googleapis.com/libtpu-releases/index.html \
4540
-f https://storage.googleapis.com/libtpu-wheels/index.html
46-
47-
# Optional: if you're using custom kernels, install pallas dependencies
48-
pip install 'torch_xla[pallas]' \
49-
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
50-
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
5141
```
5242

5343
### C++11 ABI builds

scripts/build_developer.sh

+5-9
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,17 @@ cd ..
2121
if [ -d "vision" ]; then
2222
cd vision
2323
python3 setup.py develop
24-
cd ..
24+
cd ..
2525
fi
2626

2727
# Install torch_xla
2828
cd pytorch/xla
29-
pip uninstall torch_xla -y
29+
pip uninstall torch_xla torchax torch_xla2 -y
3030

31-
# Optional: build the wheel.
31+
# Build the wheel too, which is useful for other testing purposes.
3232
python3 setup.py bdist_wheel
3333

34+
# Link the source files for local development.
3435
python3 setup.py develop
3536

3637
# libtpu is needed to talk to the TPUs. If TPUs are not present,
@@ -39,10 +40,5 @@ pip install torch_xla[tpu] \
3940
-f https://storage.googleapis.com/libtpu-wheels/index.html \
4041
-f https://storage.googleapis.com/libtpu-releases/index.html
4142

42-
# Install Pallas dependencies
43-
pip install torch_xla[pallas] \
44-
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
45-
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
46-
4743
# Test that the library is installed correctly.
48-
python3 -c 'import torch_xla as xla; print(xla.device())'
44+
python3 -c 'import torch_xla; print(torch_xla.devices()); import torchax; torchax.enable_globally()'

setup.py

+102-6
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import distutils.ccompiler
5454
import distutils.command.clean
5555
import os
56+
import re
5657
import requests
5758
import shutil
5859
import subprocess
@@ -226,7 +227,7 @@ class BuildBazelExtension(build_ext.build_ext):
226227
def run(self):
227228
for ext in self.extensions:
228229
self.bazel_build(ext)
229-
command.build_ext.build_ext.run(self)
230+
command.build_ext.build_ext.run(self) # type: ignore
230231

231232
def bazel_build(self, ext):
232233
if not os.path.exists(self.build_temp):
@@ -260,17 +261,107 @@ def bazel_build(self, ext):
260261
shutil.copyfile(ext_bazel_bin_path, ext_dest_path)
261262

262263

264+
# Read in README.md for our long_description
265+
cwd = os.path.dirname(os.path.abspath(__file__))
266+
with open(os.path.join(cwd, "README.md"), encoding="utf-8") as f:
267+
long_description = f.read()
268+
269+
# Finds torch_xla and its subpackages
270+
packages_to_include = find_packages(include=['torch_xla*'])
271+
# Explicitly add torchax
272+
packages_to_include.extend(find_packages(where='torchax', include=['torchax*']))
273+
274+
# Map the top-level 'torchax' package name to its source location
275+
torchax_dir = os.path.join(cwd, 'torchax')
276+
package_dir_mapping = {'torch_xla': os.path.join(cwd, 'torch_xla')}
277+
package_dir_mapping['torchax'] = os.path.join(torchax_dir, 'torchax')
278+
279+
263280
class Develop(develop.develop):
264281

265282
def run(self):
283+
# Build the C++ extension
266284
self.run_command("build_ext")
285+
286+
# Run the standard develop process first
287+
# This installs dependencies, scripts, and importantly, creates an `.egg-link` file
267288
super().run()
268289

290+
# Replace the `.egg-link` with a `.pth` file.
291+
self.link_packages()
292+
293+
def link_packages(self):
294+
"""
295+
There are two mechanisms to install an "editable" package in Python: `.egg-link`
296+
and `.pth` files. setuptools uses `.egg-link` by default. However, `.egg-link`
297+
only supports linking a single directory containg one editable package.
298+
This function removes the `.egg-link` file and generates a `.pth` file that can
299+
be used to link multiple packages, in particular, `torch_xla` and `torchax`.
300+
301+
Note that this function is only relevant in the editable package development path
302+
(`python setup.py develop`). Nightly and release wheel builds work out of the box
303+
without egg-link/pth.
304+
"""
305+
# Ensure paths like self.install_dir are set
306+
self.ensure_finalized()
307+
308+
# Get the site-packages directory
309+
target_dir = self.install_dir
310+
311+
# Remove the standard .egg-link file
312+
# It's usually named based on the distribution name
313+
dist_name = self.distribution.get_name()
314+
egg_link_file = os.path.join(target_dir, dist_name + '.egg-link')
315+
if os.path.exists(egg_link_file):
316+
print(f"Removing default egg-link file: {egg_link_file}")
317+
try:
318+
os.remove(egg_link_file)
319+
except OSError as e:
320+
print(f"Warning: Could not remove {egg_link_file}: {e}")
321+
322+
# Create our custom .pth file with specific paths
323+
cwd = os.path.dirname(__file__)
324+
# Path containing 'torch_xla' package source: ROOT
325+
path_for_torch_xla = os.path.abspath(cwd)
326+
# Path containing 'torchax' package source: ROOT/torchax
327+
path_for_torchax = os.path.abspath(os.path.join(cwd, 'torchax'))
328+
329+
paths_to_add = {path_for_torch_xla, path_for_torchax}
330+
331+
# Construct a suitable .pth filename (PEP 660 style is good practice)
332+
version = self.distribution.get_version()
333+
# Sanitize name and version for filename (replace runs of non-alphanumeric chars with '-')
334+
sanitized_name = re.sub(r"[^a-zA-Z0-9.]+", "_", dist_name)
335+
sanitized_version = re.sub(r"[^a-zA-Z0-9.]+", "_", version)
336+
pth_filename = os.path.join(
337+
target_dir, f"__editable_{sanitized_name}_{sanitized_version}.pth")
338+
339+
# Ensure site-packages exists
340+
os.makedirs(target_dir, exist_ok=True)
341+
342+
# Write the paths to the .pth file, one per line
343+
with open(pth_filename, "w", encoding='utf-8') as f:
344+
for path in sorted(paths_to_add):
345+
f.write(path + "\n")
346+
347+
348+
def _get_jax_install_requirements():
349+
if not USE_NIGHTLY:
350+
# Stable versions of JAX can be directly installed from PyPI.
351+
return [
352+
f'jaxlib=={_jaxlib_version}',
353+
f'jax=={_jax_version}',
354+
]
355+
356+
# Install nightly JAX libraries from the JAX package registries.
357+
jax = f'jax @ https://storage.googleapis.com/jax-releases/nightly/jax/jax-{_jax_version}-py3-none-any.whl'
358+
jaxlib = []
359+
for python_minor_version in [9, 10, 11]:
360+
jaxlib.append(
361+
f'jaxlib @ https://storage.googleapis.com/jax-releases/nightly/nocuda/jaxlib-{_jaxlib_version}-cp3{python_minor_version}-cp3{python_minor_version}-manylinux2014_x86_64.whl ; python_version == "3.{python_minor_version}"'
362+
)
363+
return [jax] + jaxlib
269364

270-
# Read in README.md for our long_description
271-
cwd = os.path.dirname(os.path.abspath(__file__))
272-
with open(os.path.join(cwd, "README.md"), encoding="utf-8") as f:
273-
long_description = f.read()
274365

275366
setup(
276367
name=os.environ.get('TORCH_XLA_PACKAGE_NAME', 'torch_xla'),
@@ -297,7 +388,8 @@ def run(self):
297388
"Programming Language :: Python :: 3",
298389
],
299390
python_requires=">=3.8.0",
300-
packages=find_packages(include=['torch_xla*']),
391+
packages=packages_to_include,
392+
package_dir=package_dir_mapping,
301393
ext_modules=[
302394
BazelExtension('//:_XLAC.so'),
303395
BazelExtension('//:_XLAC_cuda_functions.so'),
@@ -310,6 +402,8 @@ def run(self):
310402
# importlib.metadata backport required for PJRT plugin discovery prior
311403
# to Python 3.10
312404
'importlib_metadata>=4.6;python_version<"3.10"',
405+
# Some torch operations are lowered to HLO via JAX.
406+
*_get_jax_install_requirements(),
313407
],
314408
package_data={
315409
'torch_xla': ['lib/*.so*',],
@@ -331,6 +425,8 @@ def run(self):
331425
f'libtpu=={_libtpu_version}',
332426
'tpu-info',
333427
],
428+
# As of https://github.com/pytorch/xla/pull/8895, jax is always a dependency of torch_xla.
429+
# However, this no-op extras_require entrypoint is left here for backwards compatibility.
334430
# pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
335431
'pallas': [f'jaxlib=={_jaxlib_version}', f'jax=={_jax_version}'],
336432
},

test/tpu/xla_test_job.yaml

-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ spec:
4343
- |
4444
pip install expecttest==0.1.6
4545
pip install rich
46-
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
4746
4847
cd /src/pytorch/xla
4948
volumeMounts:

torchax/README.md

+8-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# torchax: Running PyTorch on TPU
22

3-
**torchax!** is a backend for PyTorch, allowing users to run
4-
PyTorch on Google CloudTPUs. **torchax!** is also a library for providing
3+
**torchax** is a backend for PyTorch, allowing users to run
4+
PyTorch on Google CloudTPUs. **torchax** is also a library for providing
55
graph-level interoperability between PyTorch and Jax.
66

77
This means, with **torchax** you can:
@@ -133,7 +133,7 @@ Then, a `jax` device will be available to use
133133

134134
```python
135135
inputs = torch.randn(3, 3, 28, 28, device='jax')
136-
m = MyModel()
136+
m = MyModel().to('jax')
137137
res = m(inputs)
138138
print(type(res)) # outputs torchax.tensor.Tensor
139139
```
@@ -220,13 +220,15 @@ then the subsequent computation with inputs of same shape will be fast.
220220

221221
# Citation:
222222

223+
```
223224
@software{torchax,
224225
author = {Han Qi, Chun-nien Chan, Will Cromar, Manfei Bai, Kevin Gleanson},
225226
title = {torchax: PyTorch on TPU and Jax interoperability},
226227
url = {https://github.com/pytorch/xla/tree/master/torchax}
227228
version = {0.0.4},
228229
date = {2025-02-24},
229230
}
231+
```
230232

231233
# Maintainers & Contributors:
232234

@@ -238,6 +240,7 @@ fellow Googlers using [Google's 20% project policy](https://ebsedu.org/blog/goog
238240

239241
Here is the full list of contributors by 2025-02-25.
240242

243+
```
241244
Han Qi (qihqi), Pytorch / XLA
242245
Manfei Bai (manfeibai), Pytorch / XLA
243246
Will Cromar (will-cromar), Meta
@@ -273,4 +276,5 @@ Tianqi Fan (tqfan28), Google(20%)
273276
Jim Lin (jimlinntu), Google(20%)
274277
Fanhai Lu (FanhaiLu1), Google Cloud
275278
DeWitt Clinton (dewitt), Google PyTorch
276-
Aman Gupta (aman2930) , Google(20%)
279+
Aman Gupta (aman2930) , Google(20%)
280+
```

torchax/pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ path = "torchax/__init__.py"
4141

4242
[project.optional-dependencies]
4343
cpu = ["jax[cpu]>=0.4.30", "jax[cpu]"]
44-
# Add libtpu index `-f https://storage.googleapis.com/libtpu-releases/index.html`
44+
# Add libtpu index `-f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html`
4545
tpu = ["jax[cpu]>=0.4.30", "jax[tpu]"]
4646
cuda = ["jax[cpu]>=0.4.30", "jax[cuda12]"]
4747
odml = ["jax[cpu]>=0.4.30", "jax[cpu]"]

0 commit comments

Comments
 (0)