Include torchax in torch_xla #8895

Merged: 13 commits, Apr 4, 2025
14 changes: 3 additions & 11 deletions .circleci/common.sh
@@ -106,18 +106,10 @@ function install_pre_deps_pytorch_xla() {
}


# TODO(https://github.com/pytorch/xla/issues/8934): Remove PyTorch usage of this function, then
# remove this function from the script.
function install_post_deps_pytorch_xla() {
# Install dependencies after we have built torch_xla. This is because installing
# those packages can potentially trigger `pip install torch_xla` if torch_xla
# is not detected in the system.

# Install JAX dependency since a few tests depend on it.
pip install 'torch_xla[pallas]' \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

# TODO(https://github.com/pytorch/xla/issues/8831): Remove this when torchax is part of torch_xla.
pip install xla/torchax
true
}

function build_torch_xla() {
7 changes: 0 additions & 7 deletions .github/workflows/_test.yml
@@ -128,13 +128,6 @@ jobs:
set -x

pip install expecttest unittest-xml-reporting
pip install torch_xla[pallas] \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

# Install torchax
# TODO(https://github.com/pytorch/xla/issues/8831): Remove this when torchax is part of torch_xla.
pip install pytorch/xla/torchax

if [[ ! -z "$RUN_BENCHMARK_TESTS" ]]; then
pip install -r pytorch/xla/benchmarks/requirements.txt
4 changes: 1 addition & 3 deletions .github/workflows/_test_requiring_torch_cuda.yml
@@ -91,9 +91,7 @@ jobs:
shell: bash
run: |
set -x
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install -U --pre jax-cuda12-pjrt jax-cuda12-plugin -f https://storage.googleapis.com/jax-releases/jax_cuda_plugin_nightly_releases.html
pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
pip install -U --pre jax jaxlib "jax-cuda12-plugin[with_cuda]" jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
if: ${{ matrix.run_triton_tests }}
- name: Install Triton
shell: bash
7 changes: 1 addition & 6 deletions .github/workflows/_tpu_ci.yml
@@ -23,14 +23,9 @@ jobs:
pip install --upgrade pip
pip install fsspec
pip install rich
# Jax nightly is needed for pallas tests.
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
# libtpu is needed for pallas tests.
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html
pip install --upgrade protobuf

# torchax is needed for call_jax tests.
# TODO(https://github.com/pytorch/xla/issues/8831): Remove this when torchax is part of torch_xla.
pip install pytorch/xla/torchax
- name: Run Tests
env:
PJRT_DEVICE: TPU
6 changes: 0 additions & 6 deletions .github/workflows/build_and_test.yml
@@ -4,16 +4,10 @@ on:
branches:
- master
- r[0-9]+.[0-9]+
paths-ignore:
- 'experimental/**'
- 'torchax/**'
push:
branches:
- master
- r[0-9]+.[0-9]+
paths-ignore:
- 'experimental/**'
- 'torchax/**'
workflow_dispatch:

concurrency:
4 changes: 0 additions & 4 deletions .github/workflows/torch_xla2.yml
@@ -3,14 +3,10 @@ on:
branches:
- master
- r[0-9]+.[0-9]+
paths:
- 'torchax/**'
push:
branches:
- master
- r[0-9]+.[0-9]+
paths:
- 'torchax/**'
workflow_dispatch:

concurrency:
4 changes: 0 additions & 4 deletions CONTRIBUTING.md
@@ -87,10 +87,6 @@ using either VS Code or a local container:
pip install torch_xla[tpu] \
-f https://storage.googleapis.com/libtpu-wheels/index.html \
-f https://storage.googleapis.com/libtpu-releases/index.html
# Optional: if you're using custom kernels, install pallas dependencies
pip install torch_xla[pallas] \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
```

* If you are running on a TPU VM, ensure `torch` and `torch_xla` were built and
10 changes: 0 additions & 10 deletions README.md
@@ -29,11 +29,6 @@ To install PyTorch/XLA stable build in a new TPU VM:
pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html

# Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
```

To install PyTorch/XLA nightly build in a new TPU VM:
@@ -43,11 +38,6 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev-cp310-cp310-linux_x86_64.whl' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html

# Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
```

### C++11 ABI builds
14 changes: 5 additions & 9 deletions scripts/build_developer.sh
@@ -21,16 +21,17 @@ cd ..
if [ -d "vision" ]; then
cd vision
python3 setup.py develop
cd ..
cd ..
fi

# Install torch_xla
cd pytorch/xla
pip uninstall torch_xla -y
pip uninstall torch_xla torchax torch_xla2 -y

# Optional: build the wheel.
# Build the wheel too, which is useful for other testing purposes.
python3 setup.py bdist_wheel

# Link the source files for local development.
python3 setup.py develop

# libtpu is needed to talk to the TPUs. If TPUs are not present,
@@ -39,10 +40,5 @@ pip install torch_xla[tpu] \
-f https://storage.googleapis.com/libtpu-wheels/index.html \
-f https://storage.googleapis.com/libtpu-releases/index.html

# Install Pallas dependencies
pip install torch_xla[pallas] \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

# Test that the library is installed correctly.
python3 -c 'import torch_xla as xla; print(xla.device())'
python3 -c 'import torch_xla; print(torch_xla.devices()); import torchax; torchax.enable_globally()'
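
A quick way to confirm that the single `python3 setup.py develop` above exposed both packages from source — an illustrative sketch, not part of this PR; it assumes only the two imports used in the smoke test above:

```python
# Sketch: check that the editable installs of torch_xla and torchax both
# resolve to the source checkout rather than a copy in site-packages.
import os

import torch_xla
import torchax

for mod in (torch_xla, torchax):
    print(mod.__name__, "->", os.path.dirname(os.path.abspath(mod.__file__)))

# Expected (roughly): torch_xla under <repo>/torch_xla and torchax under
# <repo>/torchax/torchax, i.e. both inside the pytorch/xla checkout.
```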
108 changes: 102 additions & 6 deletions setup.py
@@ -53,6 +53,7 @@
import distutils.ccompiler
import distutils.command.clean
import os
import re
import requests
import shutil
import subprocess
@@ -226,7 +227,7 @@ class BuildBazelExtension(build_ext.build_ext):
def run(self):
for ext in self.extensions:
self.bazel_build(ext)
command.build_ext.build_ext.run(self)
command.build_ext.build_ext.run(self) # type: ignore

def bazel_build(self, ext):
if not os.path.exists(self.build_temp):
@@ -260,17 +261,107 @@ def bazel_build(self, ext):
shutil.copyfile(ext_bazel_bin_path, ext_dest_path)


# Read in README.md for our long_description
cwd = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(cwd, "README.md"), encoding="utf-8") as f:
long_description = f.read()

# Finds torch_xla and its subpackages
packages_to_include = find_packages(include=['torch_xla*'])
# Explicitly add torchax
packages_to_include.extend(find_packages(where='torchax', include=['torchax*']))

# Map the top-level 'torchax' package name to its source location
torchax_dir = os.path.join(cwd, 'torchax')
package_dir_mapping = {'torch_xla': os.path.join(cwd, 'torch_xla')}
package_dir_mapping['torchax'] = os.path.join(torchax_dir, 'torchax')


class Develop(develop.develop):

def run(self):
# Build the C++ extension
self.run_command("build_ext")

# Run the standard develop process first
# This installs dependencies, scripts, and importantly, creates an `.egg-link` file
super().run()

# Replace the `.egg-link` with a `.pth` file.
self.link_packages()

def link_packages(self):
"""
There are two mechanisms to install an "editable" package in Python: `.egg-link`
and `.pth` files. setuptools uses `.egg-link` by default. However, `.egg-link`
only supports linking a single directory containing one editable package.
This function removes the `.egg-link` file and generates a `.pth` file that can
be used to link multiple packages, in particular, `torch_xla` and `torchax`.

Note that this function is only relevant in the editable package development path
(`python setup.py develop`). Nightly and release wheel builds work out of the box
without egg-link/pth.
"""
# Ensure paths like self.install_dir are set
self.ensure_finalized()

# Get the site-packages directory
target_dir = self.install_dir

# Remove the standard .egg-link file
# It's usually named based on the distribution name
dist_name = self.distribution.get_name()
egg_link_file = os.path.join(target_dir, dist_name + '.egg-link')
if os.path.exists(egg_link_file):
print(f"Removing default egg-link file: {egg_link_file}")
try:
os.remove(egg_link_file)
except OSError as e:
print(f"Warning: Could not remove {egg_link_file}: {e}")

# Create our custom .pth file with specific paths
cwd = os.path.dirname(__file__)
# Path containing 'torch_xla' package source: ROOT
path_for_torch_xla = os.path.abspath(cwd)
# Path containing 'torchax' package source: ROOT/torchax
path_for_torchax = os.path.abspath(os.path.join(cwd, 'torchax'))

paths_to_add = {path_for_torch_xla, path_for_torchax}

# Construct a suitable .pth filename (PEP 660 style is good practice)
version = self.distribution.get_version()
# Sanitize name and version for the filename (replace runs of characters other than letters, digits, and dots with '_')
sanitized_name = re.sub(r"[^a-zA-Z0-9.]+", "_", dist_name)
sanitized_version = re.sub(r"[^a-zA-Z0-9.]+", "_", version)
pth_filename = os.path.join(
target_dir, f"__editable_{sanitized_name}_{sanitized_version}.pth")

# Ensure site-packages exists
os.makedirs(target_dir, exist_ok=True)

# Write the paths to the .pth file, one per line
with open(pth_filename, "w", encoding='utf-8') as f:
for path in sorted(paths_to_add):
f.write(path + "\n")


def _get_jax_install_requirements():
if not USE_NIGHTLY:
# Stable versions of JAX can be directly installed from PyPI.
return [
f'jaxlib=={_jaxlib_version}',
f'jax=={_jax_version}',
]

# Install nightly JAX libraries from the JAX package registries.
jax = f'jax @ https://storage.googleapis.com/jax-releases/nightly/jax/jax-{_jax_version}-py3-none-any.whl'
jaxlib = []
for python_minor_version in [9, 10, 11]:
jaxlib.append(
f'jaxlib @ https://storage.googleapis.com/jax-releases/nightly/nocuda/jaxlib-{_jaxlib_version}-cp3{python_minor_version}-cp3{python_minor_version}-manylinux2014_x86_64.whl ; python_version == "3.{python_minor_version}"'
)
return [jax] + jaxlib

# Read in README.md for our long_description
cwd = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(cwd, "README.md"), encoding="utf-8") as f:
long_description = f.read()

setup(
name=os.environ.get('TORCH_XLA_PACKAGE_NAME', 'torch_xla'),
@@ -297,7 +388,8 @@ def run(self):
"Programming Language :: Python :: 3",
],
python_requires=">=3.8.0",
packages=find_packages(include=['torch_xla*']),
packages=packages_to_include,
package_dir=package_dir_mapping,
ext_modules=[
BazelExtension('//:_XLAC.so'),
BazelExtension('//:_XLAC_cuda_functions.so'),
@@ -310,6 +402,8 @@ def run(self):
# importlib.metadata backport required for PJRT plugin discovery prior
# to Python 3.10
'importlib_metadata>=4.6;python_version<"3.10"',
# Some torch operations are lowered to HLO via JAX.
*_get_jax_install_requirements(),
],
package_data={
'torch_xla': ['lib/*.so*',],
@@ -331,6 +425,8 @@ def run(self):
f'libtpu=={_libtpu_version}',
'tpu-info',
],
# As of https://github.com/pytorch/xla/pull/8895, jax is always a dependency of torch_xla.
# However, this no-op extras_require entrypoint is left here for backwards compatibility.
# pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
'pallas': [f'jaxlib=={_jaxlib_version}', f'jax=={_jax_version}'],
},
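
For context on the `Develop.link_packages` change above: Python's `site` machinery reads every `*.pth` file in site-packages and appends each existing directory listed in it to `sys.path`, which is why one `.pth` file can expose both the repository root (for `torch_xla`) and `<repo>/torchax` (for `torchax`), whereas an `.egg-link` can only point at a single directory. A minimal sketch of that behavior, using temporary stand-in directories and a hypothetical file name rather than a real checkout:

```python
# Sketch of the .pth mechanism that Develop.link_packages() relies on.
import os
import site
import sys
import tempfile

site_dir = tempfile.mkdtemp()    # stand-in for site-packages
repo_root = tempfile.mkdtemp()   # stand-in for the pytorch/xla checkout
torchax_dir = os.path.join(repo_root, "torchax")
os.makedirs(torchax_dir)

# One .pth file, one path per line -- the same shape link_packages() writes.
with open(os.path.join(site_dir, "__editable_torch_xla_0.0.0.pth"), "w") as f:
    f.write(repo_root + "\n")
    f.write(torchax_dir + "\n")

# site.addsitedir() processes .pth files the same way interpreter startup does:
# each existing directory listed in the file is appended to sys.path.
site.addsitedir(site_dir)
print(repo_root in sys.path, torchax_dir in sys.path)  # True True
```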
1 change: 0 additions & 1 deletion test/tpu/xla_test_job.yaml
@@ -43,7 +43,6 @@ spec:
- |
pip install expecttest==0.1.6
pip install rich
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

cd /src/pytorch/xla
volumeMounts:
12 changes: 8 additions & 4 deletions torchax/README.md
@@ -1,7 +1,7 @@
# torchax: Running PyTorch on TPU

**torchax!** is a backend for PyTorch, allowing users to run
PyTorch on Google CloudTPUs. **torchax!** is also a library for providing
**torchax** is a backend for PyTorch, allowing users to run
PyTorch on Google Cloud TPUs. **torchax** is also a library for providing
graph-level interoperability between PyTorch and Jax.

This means, with **torchax** you can:
@@ -133,7 +133,7 @@ Then, a `jax` device will be available to use

```python
inputs = torch.randn(3, 3, 28, 28, device='jax')
m = MyModel()
m = MyModel().to('jax')
res = m(inputs)
print(type(res)) # outputs torchax.tensor.Tensor
```
@@ -220,13 +220,15 @@ then the subsequent computation with inputs of same shape will be fast.

# Citation:

```
@software{torchax,
author = {Han Qi, Chun-nien Chan, Will Cromar, Manfei Bai, Kevin Gleanson},
title = {torchax: PyTorch on TPU and Jax interoperability},
url = {https://github.com/pytorch/xla/tree/master/torchax},
version = {0.0.4},
date = {2025-02-24},
}
```

# Maintainers & Contributors:

@@ -238,6 +240,7 @@ fellow Googlers using [Google's 20% project policy](https://ebsedu.org/blog/goog

Here is the full list of contributors by 2025-02-25.

```
Han Qi (qihqi), Pytorch / XLA
Manfei Bai (manfeibai), Pytorch / XLA
Will Cromar (will-cromar), Meta
@@ -273,4 +276,5 @@ Tianqi Fan (tqfan28), Google(20%)
Jim Lin (jimlinntu), Google(20%)
Fanhai Lu (FanhaiLu1), Google Cloud
DeWitt Clinton (dewitt), Google PyTorch
Aman Gupta (aman2930) , Google(20%)
Aman Gupta (aman2930) , Google(20%)
```
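
Putting the README snippets above together, a minimal end-to-end sketch; the model is a stand-in, and the only APIs assumed are `torchax.enable_globally()`, the `'jax'` device, and `torchax.tensor.Tensor`, all shown in this diff:

```python
import torch
import torch.nn as nn
import torchax

torchax.enable_globally()  # make the 'jax' device available to PyTorch

class MyModel(nn.Module):  # stand-in model, not from this PR
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        return torch.relu(self.conv(x))

inputs = torch.randn(3, 3, 28, 28, device='jax')
m = MyModel().to('jax')
res = m(inputs)
print(type(res))  # torchax.tensor.Tensor
```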
2 changes: 1 addition & 1 deletion torchax/pyproject.toml
@@ -41,7 +41,7 @@ path = "torchax/__init__.py"

[project.optional-dependencies]
cpu = ["jax[cpu]>=0.4.30", "jax[cpu]"]
# Add libtpu index `-f https://storage.googleapis.com/libtpu-releases/index.html`
# Add libtpu index `-f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html`
tpu = ["jax[cpu]>=0.4.30", "jax[tpu]"]
cuda = ["jax[cpu]>=0.4.30", "jax[cuda12]"]
odml = ["jax[cpu]>=0.4.30", "jax[cpu]"]