This model `m` contains two parts: the weights stored inside of the model, and its submodules (`nn.Linear`).

To execute this model with `torchax`, we need to enable torchax to capture pytorch ops. To enable this, use:

```python
import torchax

torchax.enable_globally()
```

Then a `jax` device will be available to use:

```python
inputs = torch.randn(3, 3, 28, 28, device='jax')
m = MyModel()
res = m(inputs)
print(type(res))  # outputs torchax.tensor.Tensor
```

`torchax.tensor.Tensor` is a `torch.Tensor` subclass that holds a `jax.Array`. You can inspect that jax array with `res.jax()`.

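As a toy illustration of that wrapper pattern (purely illustrative, not torchax's actual code — the `Boxed` class and its `raw()` accessor are invented for this sketch), a subclass can carry a raw payload alongside its normal behavior, much like `torchax.tensor.Tensor` carries the `jax.Array` that `res.jax()` exposes:

```python
# Toy stand-in for a subclass that holds raw data (illustrative only;
# `Boxed` and `raw()` are hypothetical names, not part of torchax).
class Boxed(list):
    def __init__(self, data):
        super().__init__(data)
        # the held payload, playing the role of the wrapped jax.Array
        self._raw = tuple(data)

    def raw(self):
        # analogous to res.jax(): expose the raw array held inside
        return self._raw

b = Boxed([1.0, 2.0, 3.0])
print(type(b).__name__)  # Boxed
print(b.raw())           # (1.0, 2.0, 3.0)
```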
## What is happening behind the scenes

We took the approach detailed in the [new device](https://github.com/albanD/subclass_zoo/blob/main/new_device.py) recipe by Alban (@albanD), using `jax.Array` as the `raw_data`.

In other words, when a torch op is executed inside the `env` context manager (which is enabled with `torchax.enable_globally()`), we can swap out the implementation of that op with a version written in Jax.

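A minimal sketch of that op swapping, assuming nothing about torchax's real registry (the `OP_TABLE` dict and `dispatch` helper here are hypothetical): keep a table from op names to substitute implementations, and look the op up at call time instead of running the original:

```python
import math

# Hypothetical registry mapping op names to substitute implementations,
# standing in for torchax swapping torch ops for Jax-backed versions.
OP_TABLE = {
    "add": lambda a, b: [x + y for x, y in zip(a, b)],
    "exp": lambda a: [math.exp(x) for x in a],
}

def dispatch(op_name, *args):
    # look up the swapped-in implementation and run it
    return OP_TABLE[op_name](*args)

print(dispatch("add", [1, 2], [3, 4]))  # [4, 6]
```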
When a model's constructor runs, it will call some tensor constructors, such as `torch.rand`, `torch.ones`, or `torch.zeros`, to create its weights. The constructor will create a `torch.Tensor` subclass that contains a `jax.Array`.

Then, each subsequent op can unpack the `jax.Array`, call the op's implementation, and wrap the result back into a `torch.Tensor` subclass.

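That unpack/call/wrap round trip can be sketched in plain Python (a toy model only; the `Wrapped` class is invented for illustration and is not torchax code):

```python
# Toy version of the round trip each op performs on a wrapped tensor:
# unpack the raw data, run the implementation, wrap the result back.
class Wrapped:
    def __init__(self, raw):
        self.raw = raw  # stands in for the held jax.Array

    def __add__(self, other):
        a, b = self.raw, other.raw           # 1. unpack the raw arrays
        out = [x + y for x, y in zip(a, b)]  # 2. call the implementation
        return Wrapped(out)                  # 3. wrap the result back

res = Wrapped([1, 2]) + Wrapped([10, 20])
print(res.raw)  # [11, 22]
```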
See more at [how_it_works](docs/how_it_works.md) and [ops registry](docs/ops_registry.md).

### Executing with jax.jit

The above script will execute the model using eager mode Jax as the backend. This does allow executing torch models on TPU, but it is often slower than what we can achieve with `jax.jit`.

`jax.jit` is a function that takes a Jax function (i.e. a function that takes jax array