Skip to content

Commit 60001c7

Browse files
qihqilsy323
andauthored
Misc changes to make torchax runnable on GPU. (#8756)
Co-authored-by: Siyuan Liu <[email protected]>
1 parent 2feb0ac commit 60001c7

File tree

6 files changed

+194
-118
lines changed

6 files changed

+194
-118
lines changed

.github/workflows/torch_xla2.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ jobs:
3939
run: |
4040
pip install -r test-requirements.txt
4141
pip install -e .[cpu]
42-
pip install tensorflow-cpu # for TF integrations tests
4342
- name: Run tests
4443
working-directory: torchax
4544
shell: bash
@@ -52,7 +51,7 @@ jobs:
5251
pytest test/test_context.py
5352
pytest test/test_train.py
5453
pytest test/test_mutations.py
55-
pytest test/test_tf_integration.py
54+
# pytest test/test_tf_integration.py # TODO(8770)
5655
pytest test/gemma/test_gemma.py
5756
pytest test/llama/test_llama.py
5857
pytest test/test_core_aten_ops.py

torchax/README.md

Lines changed: 158 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,90 @@
1-
# torchxla2
1+
# torchax: Running PyTorch on TPU
22

3-
## Install
4-
5-
Currently this is only source-installable. Requires Python version >= 3.10.
3+
**torchax!** is a backend for PyTorch, allowing users to run
4+
PyTorch on Google CloudTPUs. **torchax!** is also a library for providing
5+
graph-level interoperability between PyTorch and Jax.
66

7-
### NOTE:
7+
This means, with **torchax** you can:
8+
* Run PyTorch code on TPU with as little as 2 lines of code change.
9+
* Call a jax function from a pytorch function, passing in `jax.Array`s
10+
* Call a pytorch function from a jax function, passing in a `torch.Tensor` subclass.
11+
* Use jax features such as `jax.grad`, `optax` and `GSMPD` to train a Pytorch model.
12+
* Use a Pytorch model as feature extractor and use it with a Jax model.
13+
etc etc.
814

9-
Please don't install torch-xla from instructions in
10-
https://github.com/pytorch/xla/blob/master/CONTRIBUTING.md .
11-
In particular, the following are not needed:
15+
## Install
1216

13-
* There is no need to build pytorch/pytorch from source.
14-
* There is no need to clone pytorch/xla project inside of pytorch/pytorch
15-
git checkout.
1617

18+
### On Google Cloud TPU:
19+
First install torch CPU:
1720

18-
TorchXLA2 and torch-xla have different installation instructions, please follow
19-
the instructions below from scratch (fresh venv / conda environment.)
21+
```bash
22+
pip install torch --index-url https://download.pytorch.org/whl/cpu
23+
```
2024

25+
Then install jax TPU:
2126

22-
### 1. Installing `torchax`
27+
```bash
28+
pip install -U jax[tpu]
29+
```
2330

24-
The following instructions assume you are in the `torchax` directory:
31+
Finally install torchax
2532

26-
```
27-
Fork the repository
28-
$ git clone https://github.com/<github_username>/xla.git
29-
$ cd xla/torchax
33+
```bash
34+
pip install torchax
3035
```
3136

37+
### On GPU machines:
38+
First install torch CPU:
3239

33-
#### 1.0 (recommended) Make a virtualenv / conda env
40+
```bash
41+
pip install torch --index-url https://download.pytorch.org/whl/cpu
42+
```
3443

35-
If you are using VSCode, then [you can create a new environment from
36-
UI](https://code.visualstudio.com/docs/python/environments). Select the
37-
`dev-requirements.txt` when asked to install project dependencies.
44+
Then install jax CUDA:
3845

39-
Otherwise create a new environment from the command line.
46+
```bash
47+
pip install -U jax[cuda12]
48+
```
49+
50+
Finally install torchax
4051

4152
```bash
42-
# Option 1: venv
43-
python -m venv my_venv
44-
source my_venv/bin/activate
53+
pip install torchax
54+
```
55+
56+
### On CPU machines (mac included)
57+
First install torch CPU:
4558

46-
# Option 2: conda
47-
conda create --name <your_name> python=3.10
48-
conda activate <your_name>
59+
```bash
60+
# Linux
61+
pip install torch --index-url https://download.pytorch.org/whl/cpu
4962

50-
# Either way, install the dev requirements.
51-
pip install -r dev-requirements.txt
63+
# OR Mac:
64+
pip install torch
5265
```
5366

54-
Note: `dev-requirements.txt` will install the CPU-only version of PyTorch.
67+
Then install jax CPU:
68+
69+
```bash
70+
pip install -U jax
71+
```
5572

56-
#### 1.1 Install this package
73+
Finally install torchax
5774

58-
Install `torchax` from source for your platform:
5975
```bash
60-
pip install -e .[cpu]
61-
pip install -e .[cuda]
62-
pip install -e .[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
76+
pip install torchax
6377
```
6478

65-
#### 1.2 (optional) verify installation by running tests
79+
NOTE: if you like metal support for Apple devices then install the
80+
metal version of jax: https://developer.apple.com/metal/jax/
81+
82+
### Installing `torchax` from source
83+
84+
Still need to install `torch` CPU and `Jax` of your accelerator (GPU, TPU or None).
6685

6786
```bash
68-
pip install -r test-requirements.txt
69-
pytest test
87+
pip install git+https://github.com/pytorch/xla.git#subdirectory=torchax
7088
```
7189

7290
## Run a model
@@ -104,75 +122,47 @@ print(m(inputs))
104122
This model `m` contains 2 parts: the weights that is stored inside of the model
105123
and it's submodules (`nn.Linear`).
106124

107-
To execute this model with `torchax`; we need construct and run the model
108-
under an `environment` that captures pytorch ops and swaps them with TPU equivalent.
109-
110-
To create this environment: use
125+
To execute this model with `torchax`; we need to enable torchax to capture pytorch ops.
126+
To enable this, use:
111127

112128
```python
113129
import torchax
114-
115-
env = torchax.default_env()
130+
torchax.enable_globally()
116131
```
117-
Then, execute the instantiation of the model, as well as evaluation of model,
118-
using `env` as a context manager:
132+
Then, a `jax` device will be available to use
119133

120134
```python
121-
with env:
122-
inputs = torch.randn(3, 3, 28, 28)
123-
m = MyModel()
124-
res = m(inputs)
125-
print(type(res)) # outputs Tensor
135+
inputs = torch.randn(3, 3, 28, 28, device='jax')
136+
m = MyModel()
137+
res = m(inputs)
138+
print(type(res)) # outputs torchax.tensor.Tensor
126139
```
127140

128-
You can also enable the environment globally with
129-
```python
130-
import torchax
131-
132-
torchax.enable_globally()
133-
```
141+
`torchax.tensor.Tensor` is a `torch.Tensor` subclass that holds
142+
a `jax.Array`. You can inspect that jax array with `res.jax()`
134143

135-
Then everything afterwards is run with XLA.
136144

137145
## What is happening behind the scene:
138146

139-
When a torch op is executed inside of `env` context manager, we can swap out the
140-
implementation of that op with a version that runs on TPU.
141-
When a model's constructor runs, it will call some tensor constructor, such as
142-
`torch.rand`, `torch.ones` or `torch.zeros` etc to create its weights. Those
143-
ops are captured by `env` too and placed directly on TPU.
144-
145-
See more at [how_it_works](docs/how_it_works.md) and [ops registry](docs/ops_registry.md).
146-
147-
### What if I created model outside of `env`.
148-
149-
So if you have
150-
151-
```
152-
m = MyModel()
153-
```
154-
outside of env, then regular torch ops will run when creating this model.
155-
Then presumably the model's weights will be on CPU (as instances of `torch.Tensor`).
147+
We took the approach detailed in [new device](https://github.com/albanD/subclass_zoo/blob/main/new_device.py) recipe by Alban (@albanD); using `jax.Array` for the `raw_data`.
156148

157-
To move this model into XLA device, one can use `env.to_xla()` function.
149+
In other words, When a torch op is executed inside of `env` context manager (which is enabled with `torchax.enable_globally()`), we can swap out the
150+
implementation of that op written in Jax.
158151

159-
i.e.
160-
```
161-
m2 = env.to_xla(m)
162-
inputs = env.to_xla(inputs)
152+
When a model's constructor runs, it will call some tensor constructor, such as
153+
`torch.rand`, `torch.ones` or `torch.zeros` etc to create its weights. The constructor
154+
will create an `torch.Tensor` subclass that contains a `jax.Array`.
163155

164-
with env:
165-
res = m2(inputs)
166-
```
156+
Then, each subsequent op can unpack the `jax.Array`, call the op implementation,
157+
and wraps it back into `torch.Tensor` subclass.
167158

168-
NOTE that we also need to move inputs to xla using `.to_xla`.
169-
`to_xla` works with all pytrees of `torch.Tensor`.
159+
See more at [how_it_works](docs/how_it_works.md) and [ops registry](docs/ops_registry.md).
170160

171161

172162
### Executing with jax.jit
173163

174-
The above script will execute the model using eager mode Jax as backend. This
175-
does allow executing torch models on TPU, but is often slower than what we can
164+
The above script will execute the model using eager mode Jax as backend. This
165+
does allow executing torch models on TPU, but is often slower than what we can
176166
achieve with `jax.jit`.
177167

178168
`jax.jit` is a function that takes a Jax function (i.e. a function that takes jax array
@@ -190,9 +180,9 @@ def model_func(param, inputs):
190180
return torch.func.functional_call(m, param, inputs)
191181

192182
```
193-
Here we use [torch.func.functional_call](https://pytorch.org/docs/stable/generated/torch.func.functional_call.html)
183+
Here we use [torch.func.functional_call](https://pytorch.org/docs/stable/generated/torch.func.functional_call.html)
194184
from PyTorch to replace the model
195-
weights with `param`, then call the model. This is equivalent to:
185+
weights with `param`, then call the model. This is roughly equivalent to:
196186

197187
```python
198188
def model_func(param, inputs):
@@ -208,4 +198,79 @@ model_func_jitted = jax_jit(model_func)
208198
print(model_func_jitted(new_state_dict, inputs))
209199
```
210200

211-
See more examples at [eager_mode.py](examples/eager_mode.py) and the (examples folder)[examples/]
201+
See more examples at [eager_mode.py](examples/eager_mode.py) and the (examples folder)[examples/]
202+
203+
However, to ease the idiom of creating functional model and calling it with parameters,
204+
we also created the `JittableModule` helper class.
205+
206+
So the above can be written as:
207+
208+
```python
209+
210+
from torchax.interop import JittableModule
211+
212+
m_jitted = JittableModule(m)
213+
res = m_jitted(...)
214+
```
215+
216+
The first time that `m_jitted` is called , it will trigger `jax.jit`
217+
then the subsequent computation with inputs of same shape will be fast.
218+
219+
220+
221+
# Citation:
222+
223+
@software{torchax,
224+
author = {Han Qi, Chun-nien Chan, Will Cromar, Manfei Bai, Kevin Gleanson},
225+
title = {torchax: PyTorch on TPU and Jax interoperability},
226+
url = {https://github.com/pytorch/xla/tree/master/torchax}
227+
version = {0.0.4},
228+
date = {2025-02-24},
229+
}
230+
231+
# Maintainers & Contributors:
232+
233+
This library is created and maintained by the PyTorch/XLA team at Google Cloud.
234+
235+
However, it benefitted from many direct and indirect
236+
contributions outside of the team. Many of them done by
237+
fellow Googlers using [Google's 20% project policy](https://ebsedu.org/blog/google-tapping-workplace-actualization-20-time-rule), others by partner teams.
238+
239+
Here is the full list of contributors by 2025-02-25.
240+
241+
Han Qi (qihqi), Pytorch / XLA
242+
Manfei Bai (manfeibai), Pytorch / XLA
243+
Will Cromar (will-cromar), Meta
244+
Milad Mohammadi (miladm), Pytorch / XLA
245+
Siyuan Liu (lsy323), Pytorch / XLA
246+
Bhavya Bahl (bhavya01), Pytorch / XLA
247+
Pei Zhang (zpcore), Pytorch / XLA
248+
Yifei Teng (tengyifei), Pytorch / XLA
249+
Chunnien Chan (chunnienc), Google, ODML
250+
Alban Desmaison (albanD), Meta, Pytorch
251+
Simon Teo (simonteozw), Google(20%)
252+
David Huang (dvhg), Google(20%)
253+
Barni Seetharaman (barney-s), Google(20%)
254+
Anish Karthik (anishfish2) , Google(20%)
255+
Yao Gu (guyao) , Google(20%)
256+
Yenkai Wang (yenkwang) , Google(20%)
257+
Greg Shikhman (commander) , Google(20%)
258+
Matin Akhlaghinia (matinehAkhlaghinia), Google(20%)
259+
Tracy Chen (tracych477), Google(20%)
260+
Matthias Guenther (mrguenther) , Google(20%)
261+
WenXin Dong (wenxindongwork), Google(20%)
262+
Kevin Gleason (GleasonK) , Google, StableHLO
263+
Nupur Baghel (nupurbaghel), Google(20%)
264+
Gwen Mittertreiner (gmittert), Google(20%)
265+
Zeev Melumian (zmelumian), Lightricks
266+
Vyom Sharma (vyom1611), Google(20%)
267+
Shitong Wang (ShitongWang), Adobe
268+
Rémi Doreau (ayshiff), Google(20%)
269+
Lance Wang (wang2yn84), Google, CoreML
270+
Hossein Sarshar (hosseinsarshar) , Google(20%)
271+
Daniel Vega-Myhre (danielvegamyhre) , Google(20%)
272+
Tianqi Fan (tqfan28), Google(20%)
273+
Jim Lin (jimlinntu), Google(20%)
274+
Fanhai Lu (FanhaiLu1), Google Cloud
275+
DeWitt Clinton (dewitt), Google PyTorch
276+
Aman Gupta (aman2930) , Google(20%)

torchax/pyproject.toml

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,46 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "torchax"
7-
dependencies = [
8-
"absl-py",
9-
"immutabledict",
10-
"pytest",
11-
# Developers should install `dev-requirements.txt` first
12-
"torch>=2.3.0",
13-
]
7+
dependencies = []
148
requires-python = ">=3.10"
159
license = {file = "LICENSE"}
1610
dynamic = ["version"]
11+
authors = [
12+
{name = "Han Qi", email = "[email protected]"},
13+
{name = "Pytorch/XLA team", email = "[email protected]"},
14+
]
15+
description = "torchax is a library for running PyTorch on TPU"
16+
readme = "README.md"
17+
classifiers = [
18+
"Development Status :: 3 - Alpha",
19+
"Intended Audience :: Developers",
20+
"Intended Audience :: Education",
21+
"Intended Audience :: Science/Research",
22+
"License :: OSI Approved :: BSD License",
23+
"Topic :: Scientific/Engineering",
24+
"Topic :: Scientific/Engineering :: Mathematics",
25+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
26+
"Topic :: Software Development",
27+
"Topic :: Software Development :: Libraries",
28+
"Topic :: Software Development :: Libraries :: Python Modules",
29+
"Programming Language :: Python :: 3.10",
30+
"Programming Language :: Python :: 3.11",
31+
"Programming Language :: Python :: 3.12",
32+
"Programming Language :: Python :: 3.13",
33+
]
34+
35+
[project.urls]
36+
"Homepage" = "https://github.com/pytorch/xla/tree/master/torchax"
37+
1738

1839
[tool.hatch.version]
1940
path = "torchax/__init__.py"
2041

2142
[project.optional-dependencies]
22-
cpu = ["jax[cpu]>=0.4.30", "jax[cpu]", "tensorflow-cpu"]
43+
cpu = ["jax[cpu]>=0.4.30", "jax[cpu]"]
2344
# Add libtpu index `-f https://storage.googleapis.com/libtpu-releases/index.html`
24-
tpu = ["jax[cpu]>=0.4.30", "jax[tpu]", "tensorflow-cpu"]
25-
cuda = ["jax[cpu]>=0.4.30", "jax[cuda12]", "tensorflow-cpu"]
45+
tpu = ["jax[cpu]>=0.4.30", "jax[tpu]"]
46+
cuda = ["jax[cpu]>=0.4.30", "jax[cuda12]"]
2647
odml = ["jax[cpu]>=0.4.30", "jax[cpu]"]
2748

2849
[tool.hatch.build.targets.wheel]

0 commit comments

Comments
 (0)