|
| 1 | +# Dynamo backend for torch_xla2 |
| 2 | + |
| 3 | +## Goal |
| 4 | + |
| 5 | +Have a dynamo backend backed by torch_xla2. |
| 6 | + |
| 7 | +The users should be able to do the following: |
| 8 | + |
| 9 | +```python |
| 10 | +m = model ... |
| 11 | +m_compiled = torch.compile(m, backend='torch_xla2_compile') # backend name TBD |
| 12 | +result = m_compiled(*inputs) |
| 13 | +``` |
| 14 | + |
| 15 | +The above should run on TPU with low overhead. |
| 16 | + |
| 17 | +## Challenge |
| 18 | + |
| 19 | +Usually the challenge of a dynamo backend is the compiler that |
| 20 | +transforms an fx graph of torch (or ATen) ops into a compiled executable. |
| 21 | +However, in our case, that piece is solved. |
| 22 | + |
| 23 | +For every `call_function` node, we look up the JAX implementation |
| 24 | +of the corresponding ATen op in a dictionary, |
| 25 | +and we just call it. |
| 26 | + |
| 27 | +This is illustrated here: https://github.com/pytorch/xla/blob/master/experimental/torch_xla2/torch_xla2/export.py#L23 |
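
As a minimal sketch of that lookup (not the code in `export.py`), the toy interpreter below walks a list of nodes and resolves each `call_function` target through a dictionary. `Node`, `LOWERINGS`, and `run_graph` are hypothetical names introduced for illustration, and plain Python floats stand in for `jax.Array` so the example runs without torch or jax installed.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple


@dataclass
class Node:
    """Toy stand-in for torch.fx.Node (hypothetical, for illustration only)."""
    name: str
    op: str                      # 'placeholder', 'call_function', or 'output'
    target: str = ""             # op name, e.g. 'aten.add'
    args: Tuple[Any, ...] = ()   # names of earlier nodes, or constants


# Hypothetical lowering table: ATen op name -> implementation.
# In torch_xla2 the values would be functions that operate on jax.Array.
LOWERINGS: Dict[str, Callable[..., Any]] = {
    "aten.add": lambda a, b: a + b,
    "aten.mul": lambda a, b: a * b,
}


def run_graph(nodes: List[Node], *inputs: Any) -> Any:
    env: Dict[str, Any] = {}
    remaining = iter(inputs)

    def resolve(arg: Any) -> Any:
        # A string argument that names an earlier node refers to its value.
        return env[arg] if isinstance(arg, str) and arg in env else arg

    for node in nodes:
        if node.op == "placeholder":
            env[node.name] = next(remaining)
        elif node.op == "call_function":
            impl = LOWERINGS[node.target]  # KeyError: the op has no lowering
            env[node.name] = impl(*(resolve(a) for a in node.args))
        elif node.op == "output":
            return resolve(node.args[0])
    raise ValueError("graph has no output node")


# Graph for (x + y) * 2
graph = [
    Node("x", "placeholder"),
    Node("y", "placeholder"),
    Node("s", "call_function", "aten.add", ("x", "y")),
    Node("r", "call_function", "aten.mul", ("s", 2.0)),
    Node("out", "output", args=("r",)),
]
result = run_graph(graph, 1.0, 3.0)
```

The point of the sketch is that no compiler is involved: once the graph exists, executing it is a table lookup per node, which is why the remaining difficulty lies in producing the graph.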
| 28 | + |
| 29 | +Now, the challenge is for dynamo to be able to 1. produce the graph; and 2. |
| 30 | +not incur any data copies in this process. |
| 31 | + |
| 32 | + |
| 33 | +Consider the following pseudocode: |
| 34 | + |
| 35 | +```python |
| 36 | +class TensorSubclass(torch.Tensor): |
| 37 | +  _data: jax.Array  # the wrapped JAX array |
| 38 | +  def __torch_dispatch__(...): |
| 39 | +    # do stuff with _data, get new data |
| 40 | +    return TensorSubclass(new_data) |
| 41 | + |
| 42 | +def dynamo_backend(fx, sample): |
| 43 | +  compiled = ...  # compile fx into a callable that manipulates jax.Array |
| 44 | +  def returned_callable(*inputs): |
| 45 | +    datas = [i._data for i in inputs] |
| 46 | +    res = compiled(*datas) |
| 47 | +    return TensorSubclass(res) |
| 48 | +  return returned_callable |
| 49 | + |
| 50 | +model = torch.compile(model, backend=dynamo_backend) |
| 51 | +inputs = ...  # a list of TensorSubclass, or a list of torch.Tensor? |
| 52 | +model(*inputs) |
| 53 | +``` |
| 54 | + |
| 55 | +What would be the type of `inputs`? |
| 56 | +If `inputs` are of type `TensorSubclass`, then dynamo |
| 57 | +will attempt to trace through the `__torch_dispatch__` method |
| 58 | +and throw an error, because it does not know what `_data` is or |
| 59 | +how to handle the operations on it. |
| 60 | + |
| 61 | +If `inputs` are of type `torch.Tensor`, then it works: dynamo |
| 62 | +calls the backend, and the backend can produce the correct result. |
| 63 | +However, `inputs` must first be converted to `TensorSubclass` inside |
| 64 | +the backend, which usually means a data copy. This happens every time |
| 65 | +the compiled callable is executed, and is therefore undesirable. |
| 66 | + |
| 67 | +## The desired behavior |
| 68 | + |
| 69 | +When *tracing*, dynamo treats `TensorSubclass` as if it were a regular tensor |
| 70 | +without a dispatch override; when executing the compiled callable, |
| 71 | +`TensorSubclass` is passed in as-is. We know that dynamo can do this with |
| 72 | +at least one tensor subclass, namely `FakeTensor`. |
| 73 | + |
| 74 | + |
| 75 | +Let's list out the possible ways we could accomplish this behavior. |
| 76 | + |
| 77 | + |
| 78 | +# Option 1. Have the jax.Array object held in C++ |
| 79 | + |
| 80 | +Roughly, we would have a `Tensor` subclass in C++, very |
| 81 | +similar to the `LazyTensor` subclass that is the current `XLATensor`. |
| 82 | +This tensor can hold its own state in C++. In our case, that would |
| 83 | +be a `PyObject*` that happens to point to either a `jnp.ndarray` or |
| 84 | +jax's `Traced<ShapedArray>` during `jax.jit`. We might further reuse the |
| 85 | +`XLA` dispatch key to route the operators to the jax implementation, |
| 86 | +emulating what `__torch_dispatch__` does. |
| 87 | + |
| 88 | +This way, eager mode will continue to work, and dynamo would work |
| 89 | +because the Python class is still `torch.Tensor` (not a subclass), and |
| 90 | +there is no Python logic in the dispatch path for dynamo to trace through. |
| 91 | + |
| 92 | +## Pros: |
| 93 | +* Very clear that this will work. |
| 94 | +* Recommended by ezyang |
| 95 | + |
| 96 | +## Cons: |
| 97 | +* We now need to deal with C++ builds. In particular, `torch` becomes a source |
| 98 | +  dependency instead of a pip dependency; that is, we would again need to |
| 99 | +  build torch first and then build torch_xla2. This might be mitigated if |
| 100 | +  that subclass can be upstreamed. |
| 101 | + |
| 102 | + |
| 103 | +# Option 2. Modify dynamo to do the desired behavior |
| 104 | + |
| 105 | +We have one instance where a `torch.Tensor` dispatch subclass |
| 106 | +just works with dynamo, without dynamo making a fuss when it traces |
| 107 | +`__torch_dispatch__`. This is `FakeTensor`. (https://github.com/pytorch/pytorch/pull/100017/files) |
| 108 | + |
| 109 | +The idea is to make dynamo trace as if the inputs were `FakeTensor` and |
| 110 | +not `XLATensor`; only after the fx graph has been created and handed to the |
| 111 | +backend does dynamo call the compiled callable with `XLATensor`. |
| 112 | + |
| 113 | +## Pros: |
| 114 | +* Likely pure Python changes. |
| 115 | + |
| 116 | +## Cons: |
| 117 | +* We also need to design a mechanism to distinguish tensor subclasses that |
| 118 | +  are desirable for dynamo to trace through from those that are not. |
| 119 | +* Likely a significant amount of work. |
| 120 | + |
| 121 | + |
| 122 | +# Option 3. Register All the ops as custom_ops |
| 123 | + |
| 124 | +Currently dynamo traces `__torch_dispatch__`, and we don't like that |
| 125 | +because it will find the operations on JAX arrays and doesn't understand them. |
| 126 | + |
| 127 | +What if we make dynamo **able** to understand what is inside? |
| 128 | +The [Black box python functions](https://docs.google.com/document/d/1ZuCVyMfibExwvtzhd9cfMWk5zXT3Dhy1b3kuvAIkBoU/edit#heading=h.56tggsazyrkh) doc |
| 129 | +points out the possibility of registering things that we don't want dynamo |
| 130 | +to go into as custom ops. So we could, theoretically, do the following: |
| 131 | + |
| 132 | +1. Register the jax impl of an ATen op as a custom op. |
| 133 | + i.e. register `jaten.add` for `aten.add`. |
| 134 | +2. For meta kernels, just call the meta kernel of `aten.add`. |
| 135 | +3. In `__torch_dispatch__`, we forward the call from `aten.add` to `jaten.add`. |
| 136 | + |
| 137 | +When dynamo attempts to go inside of `__torch_dispatch__`, it will find |
| 138 | +`jaten.add`. Then it will record that in the `fx.Graph`. |
| 139 | + |
| 140 | +Our backend will see the same ops but in a different namespace (`jaten`). |
| 141 | +That is fine as long as we know how to look up its implementation. |
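
The forwarding in steps 1 and 3 can be modeled with a toy registry. This is a sketch under stated assumptions, not the torch custom-op API: `JATEN_OPS`, `register_jaten`, and `forward_to_jaten` are hypothetical names, plain floats stand in for `jax.Array`, the meta kernels of step 2 are not modeled, and the `recorded` list only imitates what a tracer would put into the `fx.Graph`.

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical registry: op name in the `jaten` namespace -> implementation.
JATEN_OPS: Dict[str, Callable[..., Any]] = {}


def register_jaten(aten_name: str):
    """Step 1: register `fn` as the `jaten` counterpart of an `aten` op."""
    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        JATEN_OPS[aten_name.replace("aten.", "jaten.", 1)] = fn
        return fn
    return decorator


@register_jaten("aten.add")
def _add(a, b):
    return a + b  # would be a jax.numpy call in torch_xla2


def forward_to_jaten(aten_name: str) -> str:
    """Step 3: the op name that __torch_dispatch__ would forward to."""
    jaten_name = aten_name.replace("aten.", "jaten.", 1)
    if jaten_name not in JATEN_OPS:
        raise NotImplementedError(f"no jaten op registered for {aten_name}")
    return jaten_name


# What a tracer would record: opaque calls in the `jaten` namespace.
recorded: List[Tuple[str, Tuple[Any, ...]]] = [
    (forward_to_jaten("aten.add"), (1.0, 2.0)),
]

# The backend sees `jaten` names and looks up their implementations directly.
outputs = [JATEN_OPS[name](*args) for name, args in recorded]
```

Because the backend's lookup table is keyed by the same `jaten` names that end up in the graph, the change of namespace costs nothing at execution time.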
| 142 | + |
| 143 | +Note: we probably also need to hook up gradients of custom ops via `autograd.Function`. |
| 144 | + |
| 145 | + |
| 146 | +Pros / Cons: |
| 147 | +Haven't tried this yet, so we don't know whether it will work. |
| 148 | + |
| 149 | + |
| 150 | + |
| 151 | + |
| 152 | + |
| 153 | + |
| 154 | +# Appendix, Failed attempts: |
| 155 | + |
| 156 | +## Attempt 1: move dispatch to a mode (i.e. the subclass has no dispatch override) |
| 157 | + |
| 158 | +```python |
| 159 | +class Subclass(torch.Tensor): |
| 160 | + |
| 161 | +  @staticmethod |
| 162 | +  def __new__(cls, elem): |
| 163 | +    dtype = tensor.j2t_dtype(elem.dtype) |
| 164 | +    shape = list(elem.shape) |
| 165 | +    for i, s in enumerate(shape): |
| 166 | +      if not isinstance(s, int): |
| 167 | +        shape[i] = 1 |
| 168 | +    if dtype is None: |
| 169 | +      dtype = torch.float32 |
| 170 | + |
| 171 | +    self = torch.Tensor._make_wrapper_subclass( |
| 172 | +        cls, |
| 173 | +        shape, |
| 174 | +        dtype=dtype, |
| 175 | +        device='meta', |
| 176 | +        requires_grad=False, |
| 177 | +    ) |
| 178 | +    self._meta = torch.empty( |
| 179 | +        shape, dtype=dtype, device='meta', requires_grad=False |
| 180 | +    ) |
| 181 | +    self._elem = elem |
| 182 | +    return self |
| 183 | + |
| 184 | +  def __init__(self, elem: jax.Array): |
| 185 | +    super().__init__() |
| 186 | +    self._elem = elem |
| 187 | + |
| 188 | +  def __str__(self): |
| 189 | +    return "Subclass({} {})".format(str(type(self._elem)), str(self._elem)) |
| 190 | + |
| 191 | +``` |
| 192 | + |
| 193 | +This fails with an error saying that all subclasses were exhausted and every `__torch_dispatch__` returned `NotImplemented`. |
| 194 | + |