Skip to content

Make GPU CUDA plugin require JAX #8919

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions .github/workflows/_test_requiring_torch_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,6 @@ jobs:
uses: actions/checkout@v4
with:
path: pytorch/xla
- name: Extra CI deps
shell: bash
run: |
set -x
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install -U --pre jax-cuda12-pjrt jax-cuda12-plugin -f https://storage.googleapis.com/jax-releases/jax_cuda_plugin_nightly_releases.html
pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
if: ${{ matrix.run_triton_tests }}
- name: Install Triton
shell: bash
run: |
Expand Down
125 changes: 124 additions & 1 deletion build_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,129 @@
import subprocess
import sys
import shutil
from dataclasses import dataclass
import functools

BASE_DIR = os.path.dirname(os.path.abspath(__file__))


@functools.lru_cache
def get_pinned_packages():
  """Returns the pinned versions of torch_xla's important dependencies.

  The result is cached, so all callers share a single frozen
  PinnedPackages instance.
  """
  pinned = PinnedPackages(
      use_nightly=True,
      date='20250320',
      raw_libtpu_version='0.0.12',
      raw_jax_version='0.5.4',
      raw_jaxlib_version='0.5.4',
  )
  return pinned


@functools.lru_cache
def get_build_version():
  """Computes the torch_xla wheel version string.

  The base version comes from the TORCH_XLA_VERSION environment variable
  (default '2.8.0'). When GIT_VERSIONED_XLA_BUILD is enabled (the default),
  a '+git<sha7>' local-version suffix derived from the repo HEAD is appended.
  """
  xla_git_sha, _ = get_git_head_sha(BASE_DIR)
  version = os.getenv('TORCH_XLA_VERSION', '2.8.0')
  if check_env_flag('GIT_VERSIONED_XLA_BUILD', default='TRUE'):
    try:
      # Best effort: an unusable sha must not break the build.
      version = f'{version}+git{xla_git_sha[:7]}'
    except Exception:
      pass
  return version


@functools.lru_cache
def get_git_head_sha(base_dir):
  """Returns (xla_git_sha, torch_git_sha) for the checkout at base_dir.

  torch_git_sha is the HEAD of the parent directory's repo when one exists
  (i.e. when this repo is nested inside a pytorch checkout); otherwise ''.
  """

  def _head_sha(cwd):
    # `git rev-parse HEAD` prints the current commit hash.
    return subprocess.check_output(['git', 'rev-parse', 'HEAD'],
                                   cwd=cwd).decode('ascii').strip()

  xla_git_sha = _head_sha(base_dir)
  parent_dir = os.path.join(base_dir, '..')
  if os.path.isdir(os.path.join(parent_dir, '.git')):
    torch_git_sha = _head_sha(parent_dir)
  else:
    torch_git_sha = ''
  return xla_git_sha, torch_git_sha


def get_jax_cuda_requirements():
  """Get a list of JAX CUDA requirements for use in setup.py without extra package registries.

  Returns:
    A list of PEP 508 requirement strings. Stable versions are plain
    `name==version` pins installable from PyPI; nightly versions are pinned
    to direct wheel URLs on the JAX release bucket.
  """
  pinned_packages = get_pinned_packages()
  if not pinned_packages.use_nightly:
    # Stable versions of JAX can be directly installed from PyPI.
    return [
        f'jaxlib=={pinned_packages.jaxlib_version}',
        f'jax=={pinned_packages.jax_version}',
        f'jax[cuda12]=={pinned_packages.jax_version}',
    ]

  # Install nightly JAX libraries from the JAX package registries.
  jax = f'jax @ https://storage.googleapis.com/jax-releases/nightly/jax/jax-{pinned_packages.jax_version}-py3-none-any.whl'
  # jaxlib nightly wheels are built per Python minor version, so emit one
  # environment-marked requirement per supported interpreter.
  jaxlib = []
  for python_minor_version in [9, 10, 11]:
    jaxlib.append(
        f'jaxlib @ https://storage.googleapis.com/jax-releases/nightly/nocuda/jaxlib-{pinned_packages.jaxlib_version}-cp3{python_minor_version}-cp3{python_minor_version}-manylinux2014_x86_64.whl ; python_version == "3.{python_minor_version}"'
    )

  # Install nightly JAX CUDA libraries.
  #
  # NOTE: the project name in each direct-URL requirement must match the
  # distribution name of the wheel it points at, or pip rejects the install
  # with an "inconsistent name" error. The jax_cuda12_pjrt wheel is
  # pure-Python (py3-none); the jax_cuda12_plugin wheels are built per
  # Python minor version.
  jax_cuda = [
      f'jax-cuda12-pjrt @ https://storage.googleapis.com/jax-releases/nightly/wheels/jax_cuda12_pjrt-{pinned_packages.jax_version}-py3-none-manylinux2014_x86_64.whl'
  ]
  for python_minor_version in [9, 10, 11]:
    jax_cuda.append(
        f'jax-cuda12-plugin @ https://storage.googleapis.com/jax-releases/nightly/wheels/jax_cuda12_plugin-{pinned_packages.jaxlib_version}-cp3{python_minor_version}-cp3{python_minor_version}-manylinux2014_x86_64.whl ; python_version == "3.{python_minor_version}"'
    )

  return [jax] + jaxlib + jax_cuda


@dataclass(eq=True, frozen=True)
class PinnedPackages:
  """Immutable record of the pinned libtpu and JAX dependency versions."""

  # Whether to use nightly or stable libtpu and JAX.
  use_nightly: bool

  # Nightly build date such as '20250320'; unused for stable releases.
  date: str
  raw_libtpu_version: str
  raw_jax_version: str
  raw_jaxlib_version: str

  def _dev_version(self, raw_version: str) -> str:
    # Nightly releases carry a '.dev<date>' suffix on top of the raw version.
    return f'{raw_version}.dev{self.date}'

  @property
  def libtpu_version(self) -> str:
    if not self.use_nightly:
      return self.raw_libtpu_version
    return self._dev_version(self.raw_libtpu_version)

  @property
  def jax_version(self) -> str:
    if not self.use_nightly:
      return self.raw_jax_version
    return self._dev_version(self.raw_jax_version)

  @property
  def jaxlib_version(self) -> str:
    if not self.use_nightly:
      return self.raw_jaxlib_version
    return self._dev_version(self.raw_jaxlib_version)

  @property
  def libtpu_storage_directory(self) -> str:
    # Nightly and LTS wheels live in different GCS buckets.
    return 'libtpu-nightly-releases' if self.use_nightly else 'libtpu-lts-releases'

  @property
  def libtpu_wheel_name(self) -> str:
    name = f'libtpu-{self.libtpu_version}'
    # Nightly wheels carry a '+nightly' local-version tag.
    return f'{name}+nightly' if self.use_nightly else name

  @property
  def libtpu_storage_path(self) -> str:
    return (f'https://storage.googleapis.com/{self.libtpu_storage_directory}'
            f'/wheels/libtpu/{self.libtpu_wheel_name}-py3-none-linux_x86_64.whl')


def check_env_flag(name: str, default: str = '') -> bool:
Expand Down Expand Up @@ -60,7 +183,7 @@ def bazel_build(bazel_target: str,
]

# Remove duplicated flags because they confuse bazel
flags = set(bazel_options_from_env() + options)
flags = set(list(bazel_options_from_env()) + list(options))
bazel_argv.extend(flags)

print(' '.join(bazel_argv), flush=True)
Expand Down
7 changes: 3 additions & 4 deletions plugins/cuda/setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import datetime
import os
import sys

Expand All @@ -12,6 +11,6 @@
'torch_xla_cuda_plugin/lib', ['--config=cuda'])

setuptools.setup(
# TODO: Use a common version file
version=os.getenv('TORCH_XLA_VERSION',
f'2.8.0.dev{datetime.date.today().strftime("%Y%m%d")}'))
version=build_util.get_build_version(),
install_requires=build_util.get_jax_cuda_requirements(),
)
59 changes: 10 additions & 49 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,33 +55,14 @@
import os
import requests
import shutil
import subprocess
import sys
import tempfile
import zipfile

import build_util

base_dir = os.path.dirname(os.path.abspath(__file__))

USE_NIGHTLY = True # whether to use nightly or stable libtpu and jax

_date = '20250320'
_libtpu_version = '0.0.12'
_jax_version = '0.5.4'
_jaxlib_version = '0.5.4'

_libtpu_wheel_name = f'libtpu-{_libtpu_version}'
_libtpu_storage_directory = 'libtpu-lts-releases'

if USE_NIGHTLY:
_libtpu_version += f".dev{_date}"
_jax_version += f".dev{_date}"
_jaxlib_version += f".dev{_date}"
_libtpu_wheel_name += f".dev{_date}+nightly"
_libtpu_storage_directory = 'libtpu-nightly-releases'

_libtpu_storage_path = f'https://storage.googleapis.com/{_libtpu_storage_directory}/wheels/libtpu/{_libtpu_wheel_name}-py3-none-linux_x86_64.whl'
pinned_packages = build_util.get_pinned_packages()


def _get_build_mode():
Expand All @@ -90,29 +71,6 @@ def _get_build_mode():
return sys.argv[i]


def get_git_head_sha(base_dir):
xla_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
cwd=base_dir).decode('ascii').strip()
if os.path.isdir(os.path.join(base_dir, '..', '.git')):
torch_git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
cwd=os.path.join(
base_dir,
'..')).decode('ascii').strip()
else:
torch_git_sha = ''
return xla_git_sha, torch_git_sha


def get_build_version(xla_git_sha):
version = os.getenv('TORCH_XLA_VERSION', '2.8.0')
if build_util.check_env_flag('GIT_VERSIONED_XLA_BUILD', default='TRUE'):
try:
version += '+git' + xla_git_sha[:7]
except Exception:
pass
return version


def create_version_files(base_dir, version, xla_git_sha, torch_git_sha):
print('Building torch_xla version: {}'.format(version))
print('XLA Commit ID: {}'.format(xla_git_sha))
Expand Down Expand Up @@ -151,7 +109,7 @@ def maybe_bundle_libtpu(base_dir):
print('No installed libtpu found. Downloading...')

with tempfile.NamedTemporaryFile('wb') as whl:
resp = requests.get(_libtpu_storage_path)
resp = requests.get(pinned_packages.libtpu_storage_path)
resp.raise_for_status()

whl.write(resp.content)
Expand Down Expand Up @@ -194,8 +152,8 @@ def run(self):
distutils.command.clean.clean.run(self)


xla_git_sha, torch_git_sha = get_git_head_sha(base_dir)
version = get_build_version(xla_git_sha)
xla_git_sha, torch_git_sha = build_util.get_git_head_sha(base_dir)
version = build_util.get_build_version()

build_mode = _get_build_mode()
if build_mode not in ['clean']:
Expand Down Expand Up @@ -226,7 +184,7 @@ class BuildBazelExtension(build_ext.build_ext):
def run(self):
for ext in self.extensions:
self.bazel_build(ext)
command.build_ext.build_ext.run(self)
command.build_ext.build_ext.run(self) # type: ignore

def bazel_build(self, ext):
if not os.path.exists(self.build_temp):
Expand Down Expand Up @@ -328,11 +286,14 @@ def run(self):
# On Cloud TPU VM install with:
# pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html
'tpu': [
f'libtpu=={_libtpu_version}',
f'libtpu=={pinned_packages.libtpu_version}',
'tpu-info',
],
# pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
'pallas': [f'jaxlib=={_jaxlib_version}', f'jax=={_jax_version}'],
'pallas': [
f'jaxlib=={pinned_packages.jaxlib_version}',
f'jax=={pinned_packages.jax_version}'
],
},
cmdclass={
'build_ext': BuildBazelExtension,
Expand Down
Loading