
Commit f4a612c

Training example for llama3; misc changes to make the training work. (#7194)
1 parent 56ddd5d commit f4a612c

File tree: 14 files changed, +1748 -9 lines changed
Lines changed: 194 additions & 0 deletions
# Dynamo backend for torchxla2

## Goal

Have a dynamo backend backed by torch_xla2.

Users should be able to do the following:

```python
m = model ...
m_compiled = torch.compile(m, backend='torch_xla2_compile')  # backend name TBD
result = m_compiled(*inputs)
```

The above should run on TPU with low overhead.

## Challenge

Usually the challenge of a dynamo backend is the compiler that
transforms an fx graph with torch (or ATen) ops into a compiled executable.
However, in our case, that piece is solved.

For every `call_function` node, we look up the implementation of the
corresponding ATen op in a dictionary that maps it to its Jax counterpart,
and we just call it.

This is illustrated here: https://github.com/pytorch/xla/blob/master/experimental/torch_xla2/torch_xla2/export.py#L23

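As a rough sketch of that lookup-and-call step (not the actual `export.py` code; the `ATEN_TO_JAX` table and its entries are illustrative assumptions), an `fx.Interpreter` subclass could do the dispatch like this:

```python
import torch
from torch import fx
import jax.numpy as jnp

# Illustrative table mapping ATen ops to Jax implementations; torch_xla2
# keeps a much larger registry of this shape internally.
ATEN_TO_JAX = {
    torch.ops.aten.add.Tensor: jnp.add,
    torch.ops.aten.mul.Tensor: jnp.multiply,
}

class JaxInterpreter(fx.Interpreter):
  """Runs an fx graph by dispatching each `call_function` node to Jax."""

  def call_function(self, target, args, kwargs):
    # Look up the Jax implementation of this ATen op and call it directly
    # on jax.Array inputs instead of running the torch kernel.
    if target in ATEN_TO_JAX:
      return ATEN_TO_JAX[target](*args, **kwargs)
    return super().call_function(target, args, kwargs)
```
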
Now, the challenge is for dynamo to be able to 1. produce the graph, and
2. not incur any data copies in this process.

Consider the following pseudocode:

```python
class XLATensor2:
  _data: jax.Array

  def __torch_dispatch__(...):
    # do stuff with _data, get new data
    return XLATensor2(new_data)

def dynamo_backend(fx, sample):
  compiled = ...  # compile fx into a graph that manipulates jax.Array
  def returned_callable(inputs):
    datas = [i._data for i in inputs]
    res = compiled(*datas)
    return TensorSubclass(res)
  return returned_callable

model = torch.compile(model, backend=dynamo_backend)
inputs = ...  # a list of TensorSubclass, or a list of torch.Tensor?
model(*inputs)
```

What would be the type of inputs?
If inputs are of type `TensorSubclass`, then dynamo
will attempt to trace through the `__torch_dispatch__` method
and throw an error, because it doesn't know what `_data` is or what the
operations on it mean.

If `inputs` are of type `torch.Tensor`, then it works: dynamo
calls the backend, and the backend can produce the correct result.
But `inputs` would need to be converted to `TensorSubclass` first inside of
the backend, which usually means a data copy. This happens every time
the compiled callable is executed, and is therefore not desirable.

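A minimal runnable sketch of the `torch.Tensor` case, assuming a toy backend that round-trips through numpy purely to make the per-call copy visible (a real backend would run a compiled jax computation instead of calling `gm` back):

```python
import numpy as np
import torch
import jax.numpy as jnp

def toy_jax_backend(gm: torch.fx.GraphModule, sample_inputs):
  def returned_callable(*inputs):
    # This host copy happens on every invocation of the compiled function.
    jax_inputs = [jnp.asarray(t.detach().cpu().numpy()) for t in inputs]
    # Placeholder: convert back and reuse gm; a real backend would execute
    # a jax computation built from the fx graph here.
    torch_inputs = [torch.from_numpy(np.asarray(a)) for a in jax_inputs]
    return gm(*torch_inputs)
  return returned_callable

def f(x, y):
  return x + y

compiled = torch.compile(f, backend=toy_jax_backend)
print(compiled(torch.randn(2), torch.randn(2)))
```
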
## The Desired behavior

When *tracing*, dynamo treats `TensorSubclass` as if it were a regular tensor
without a dispatch override; and when executing the compiled callable,
`TensorSubclass` is passed in as-is. We know that dynamo can do this for
some tensor subclasses, namely `FakeTensor`.

Let's list out the possible ways we could accomplish this behavior.

# Option 1. Have the jax.Array object held in C++

Roughly, we would have a `Tensor` subclass in C++; this is very
similar to the `LazyTensor` subclass that is the current `XLATensor`.
This tensor can hold its own state in C++. In our case, that would
be a `PyObject*` that happens to point to either a `jnp.ndarray` or
jax's `Traced<ShapedArray>` during jax.jit. We might further reuse the
`XLA` dispatch key to route the operators to the jax implementation,
emulating what `__torch_dispatch__` does.

This way, eager mode will continue to work, and dynamo would work
because the Python class is still `torch.Tensor` (not a subclass), and
there is no Python logic in dispatching, so dynamo cannot trace into it.

## Pros:
* Very clear that this will work.
* Recommended by ezyang.

## Cons:
We now need to deal with C++ builds. In particular, `torch` becomes a source
dependency instead of a pip dependency; meaning, again, we need to build
torch first and then build torch_xla2. This might be mitigated if
that subclass can be upstreamed.

# Option 2. Modify dynamo to do the desired behavior

We have one instance where a `torch.Tensor` dispatch subclass
just works with dynamo, without dynamo making a fuss when it traces
`__torch_dispatch__`. This is `FakeTensor`. (https://github.com/pytorch/pytorch/pull/100017/files)

The idea is to make dynamo trace as if the inputs were `FakeTensor` and
not `XLATensor`, and only after the creation of the fx graph and backend would
dynamo call the compiled callable with `XLATensor`.

Pros:
* Likely pure Python changes.

Cons:
* We also need to design a mechanism to distinguish tensor subclasses that
  dynamo should trace through from those it should not.
* Likely a significant amount of work.

# Option 3. Register all the ops as custom_ops

So currently dynamo traces `__torch_dispatch__`, and we don't like that
because it will find the operations on Jax arrays and doesn't understand those.

What if we make dynamo **able** to understand what is inside?
The [Black box python functions](https://docs.google.com/document/d/1ZuCVyMfibExwvtzhd9cfMWk5zXT3Dhy1b3kuvAIkBoU/edit#heading=h.56tggsazyrkh) doc
points out the possibility of registering things that we don't want dynamo
to go into as custom ops. So we could, theoretically, do the following:

1. Register the jax impl of an ATen op as a custom op,
   i.e. register `jaten.add` for `aten.add`.
2. For meta kernels, just call the meta kernel of `aten.add`.
3. In `__torch_dispatch__`, we forward the call from `aten.add` to `jaten.add`.

When dynamo attempts to go inside `__torch_dispatch__`, it will find
`jaten.add` instead. Then it will record that in the `fx.Graph`.

Our backend will see the same ops, but in a different namespace (`jaten`).
That is fine as long as we know how to look up their implementations.

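A hedged sketch of steps 1-3 using `torch.library` (the `jaten` namespace comes from the text above, but the numpy round-trip and the exact registration calls are illustrative assumptions, not the torch_xla2 implementation):

```python
import jax
import jax.numpy as jnp
import torch

# Step 1: define a `jaten` namespace with an `add` custom op whose
# implementation calls into Jax.
jaten = torch.library.Library("jaten", "DEF")
jaten.define("add(Tensor a, Tensor b) -> Tensor")

def jaten_add(a, b):
  # In torch_xla2 the inputs would already wrap jax.Arrays; here we round-trip
  # through host memory just to have something runnable.
  res = jnp.add(jnp.asarray(a.detach().numpy()), jnp.asarray(b.detach().numpy()))
  return torch.from_numpy(jax.device_get(res))

def jaten_add_meta(a, b):
  # Step 2: "just call the meta kernel of aten.add" -- the inputs here are
  # meta tensors, so this only runs shape/dtype propagation.
  return torch.ops.aten.add.Tensor(a, b)

jaten.impl("add", jaten_add, "CPU")
jaten.impl("add", jaten_add_meta, "Meta")

# Step 3 (inside __torch_dispatch__): forward aten.add to the custom op, e.g.
#   return torch.ops.jaten.add(x, y)
```
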
Note: we probably also need to hook up gradients of custom ops via `torch.autograd.Function`.

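For example, the gradient hookup could look roughly like this (a sketch under the assumption that we can obtain the backward rule from Jax; none of these names come from torch_xla2):

```python
import jax
import jax.numpy as jnp
import torch

class JatenAdd(torch.autograd.Function):
  """Illustrative autograd wrapper around a Jax-backed op."""

  @staticmethod
  def forward(ctx, a, b):
    res = jnp.add(jnp.asarray(a.detach().numpy()), jnp.asarray(b.detach().numpy()))
    return torch.from_numpy(jax.device_get(res))

  @staticmethod
  def backward(ctx, grad_out):
    # For add, the VJP passes the incoming gradient to both inputs;
    # in general we would ask Jax (e.g. via jax.vjp) for the backward rule.
    return grad_out, grad_out

# usage: out = JatenAdd.apply(a, b)
```
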
Pros / Cons:
Haven't tried it; don't know whether it will work or not.

# Appendix, Failed attempts:

## Attempt 1: move dispatch to a mode (i.e. the subclass has no dispatch override)

```python
class Subclass(torch.Tensor):

  @staticmethod
  def __new__(cls, elem):
    dtype = tensor.j2t_dtype(elem.dtype)
    shape = list(elem.shape)
    for i, s in enumerate(shape):
      if not isinstance(s, int):
        shape[i] = 1
    if dtype is None:
      dtype = torch.float32

    # Wrapper subclass: the torch-visible shell lives on the meta device,
    # while the real data is the jax.Array stored in self._elem.
    self = torch.Tensor._make_wrapper_subclass(
        cls,
        shape,
        dtype=dtype,
        device='meta',
        requires_grad=False,
    )
    self._meta = torch.empty(
        shape, dtype=dtype, device='meta', requires_grad=False
    )
    self._elem = elem
    return self

  def __init__(self, elem: jax.Array):
    super().__init__()
    self._elem = elem

  def __str__(self):
    return "Subclass({} {})".format(str(type(self._elem)), str(self._elem))
```

This fails with an error saying that it exhausted all subclasses and all of the
`__torch_dispatch__` handlers returned `NotImplemented`.
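
For context, the "mode" in this attempt refers to something along the lines of a `TorchDispatchMode` (a hedged sketch of the intent, not the actual torch_xla2 mode):

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class JaxDispatchMode(TorchDispatchMode):
  """Sketch: intercept ATen calls in a mode instead of on the subclass."""

  def __torch_dispatch__(self, func, types, args=(), kwargs=None):
    kwargs = kwargs or {}
    # Here we would unwrap Subclass._elem from args and call the Jax
    # implementation of `func`; this placeholder just runs the normal kernel.
    return func(*args, **kwargs)
```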

experimental/torch_xla2/examples/__init__.py

Whitespace-only changes.
Lines changed: 73 additions & 0 deletions

import jax.numpy as jnp
import jax
from jax.experimental.pallas.ops.tpu import flash_attention

import torch_xla2
from jax.experimental import mesh_utils
from torch_xla2.ops.jtorch import _tpu_flash_attention

# Set up the torch_xla2 environment: a 1-D device mesh named "fsdp" over
# 4 devices, with flash attention enabled.
env = torch_xla2.default_env()
jax.config.update('jax_enable_x64', False)
env._mesh = jax.sharding.Mesh(
    mesh_utils.create_device_mesh((4, )),
    axis_names=("fsdp", ),
)
env.use_flash_attention = True


from torch.nn import functional as F


def attn(q, k, v):
  # Run torch's scaled_dot_product_attention on jax arrays by viewing them
  # as torch tensors inside the torch_xla2 environment, then reduce to a scalar.
  q, k, v = env.j2t_iso((q, k, v))
  with env:
    x = F.scaled_dot_product_attention(q, k, v, is_causal=True)
  x = env.t2j_iso(x)
  return jnp.sum(x)


import torch

class M(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.a = torch.nn.Linear(10, 10)

  def forward(self, x):
    return self.a(x)

m = M()
from torch_xla2.interop import JittableModule

mjit = JittableModule(m)

from torch.nn.utils import stateless

def f(weights, x):
  # Functional (stateless) call of the module so that jax.grad can
  # differentiate with respect to `weights`.
  res = mjit.functional_call('forward', weights, {}, (x, ))
  return torch.sum(res)


def crossent(x, y):
  x, y = env.j2t_iso((x, y))
  res = torch.func.functional_call(m, x, (y, ))
  return env.t2j_iso(res)

graded = jax.value_and_grad(attn)

shape = (4, 32, 128, 32)
q = jnp.ones(shape, dtype='bfloat16')
v = jnp.ones(shape, dtype='bfloat16')
k = jnp.ones(shape, dtype='bfloat16')


env = torch_xla2.default_env()
weights = env.t2j_iso(env.to_xla(mjit.params))

from torch_xla2.interop import jax_view

#print(jax.jit(graded).lower(q, v, k).as_text())
print(jax.jit(jax.grad(jax_view(f))).lower(
    weights, jax.ShapeDtypeStruct((10, ), 'float32')
).as_text())
