The new `xm.optimization_barrier` API breaks the gradient flow #3486
Comments
Update: I tried digging into it further. I think the root cause of the problem above is that in `xla/torch_xla/csrc/init_python_bindings.cpp` (lines 284 to 287 at cf19c0c), `GetXlaTensor` and `make_variable` break the PyTorch autograd graph -- they essentially work similarly to `.detach()` in PyTorch and remove the autograd reference.
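For reference, a minimal PyTorch-only illustration of the same symptom (the variable names are just for illustration): rebuilding a tensor from raw values, like a detach, cuts the autograd history, so the original parameter never receives a gradient.

```python
import torch

w = torch.randn(3, requires_grad=True)
out = (w * 2).sum()

# Rebuilding a tensor from the raw value (similar to what a detach-like
# call does) drops the autograd history, so backward() never reaches w.
rebuilt = torch.tensor(out.item(), requires_grad=True)
rebuilt.backward()
print(w.grad)  # None -- the same symptom reported for the barrier output
```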
After I change it to
in `test_train_mp_mnist_with_optimization_barrier.py`, the gradient flow is good and I can get
which matches the case without using `xm.optimization_barrier`. (However, this seems to be an unintuitive API and error-prone. It would be better to provide some documentation and/or warnings, or a more intuitive interface.)
@bdhirsh What's the correct way of making a new PyTorch tensor in C++ while not breaking autograd?
Hmmm. If I'm understanding correctly, the "normal" way to do this would be through a native PyTorch op. For example, take
And it sounds like this, @JackCaoG, instead of having
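As a Python-level sketch of the same idea (not the C++ native-op route described above, just an analogue; `PassThrough` is a made-up name): routing the new tensor through a `torch.autograd.Function` keeps it attached to the autograd graph, whereas constructing it from raw data does not.

```python
import torch

class PassThrough(torch.autograd.Function):
    """Identity-like op that keeps the output connected to the autograd graph."""

    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the incoming gradient straight through to the input.
        return grad_output

x = torch.randn(3, requires_grad=True)
y = PassThrough.apply(x).sum()
y.backward()
print(x.grad)  # populated -- gradient flow preserved
```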
Good point. I think one thing I can do is to make it an in-place operation. I was referring to the JAX example, but then I realized JAX is purely functional, hence no in-place updates (I think).
Some updates and findings: I submitted #3493 to fix the gradient-not-being-computed issue, but it does not solve the memory usage. I think I am using the barrier incorrectly. Here is a JAX example
Note that in the above function
I need to spend some time to figure out how to do 1, since for PyTorch this re-computation is controlled by the framework, not the user.
I also need to change the current
@JackCaoG Thanks for the update! For (1) above, I think it's fine to accomplish gradient checkpointing in a way other than calling
For example, the FairScale library has its own custom wrapper
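For instance, a rough sketch of both routes (assuming FairScale is installed; the `checkpoint_wrapper` import path below is how I recall it, and the layer sizes are arbitrary):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from fairscale.nn.checkpoint import checkpoint_wrapper

block = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 128))
x = torch.randn(4, 128, requires_grad=True)

# Route 1: call torch.utils.checkpoint explicitly around a sub-graph.
loss = checkpoint(block, x).sum()
loss.backward()

# Route 2: wrap the module once; its forward is checkpointed transparently,
# so the training loop itself does not need to change.
wrapped = checkpoint_wrapper(nn.Sequential(nn.Linear(128, 128), nn.ReLU()))
out = wrapped(torch.randn(4, 128)).sum()
out.backward()
```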
Mostly done with the
With #3493 and a patch to the PyTorch file
I am able to see, with gradient checkpointing turned on,
and with it turned off,
With checkpointing, memory usage is lower but execution takes longer. This is expected. I will post my findings to the GitHub issue. Users do not need to call
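For anyone who wants to reproduce the comparison, a rough way to eyeball device memory on the TPU side (the exact keys in the returned dict depend on the torch_xla version, so treat this as a sketch):

```python
import torch_xla.core.xla_model as xm

device = xm.xla_device()
# ... run one training step with / without gradient checkpointing ...
xm.mark_step()
print(xm.get_memory_info(device))
```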
Let me also explain a bit about what is going on. I will use
The first optimization barrier is
right after this barrier there is a call
Note that this
The difference is that the repeated call is being done on the
compiler will not try to fuse these two identical calls.
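In other words, a hypothetical sketch of the pattern (using the list-in/list-out API shown in this issue; `recompute` and its arguments are made-up names): by passing the saved tensors through an optimization barrier before the repeated forward call, the recomputed call operates on distinct graph nodes, so XLA cannot fuse it with the original forward pass.

```python
import torch_xla.core.xla_model as xm

def recompute(forward_fn, *saved_inputs):
    # Barrier the saved activations/inputs first; the tensors fed to the
    # repeated forward call are now distinct nodes in the XLA graph, so the
    # compiler will not merge this call with the original forward pass.
    barriered = xm.optimization_barrier(list(saved_inputs))
    return forward_fn(*barriered)
```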
This is awesome! I confirm that gradient checkpointing now also works well on my end with this new optimization barrier API and the patched torch/utils/checkpoint.py above.
This is great to know. I'm now trying it out in my FSDP cases to prevent the fusion of full parameter gathering and freeing in #3431.
🐛 Bug
The new `xm.optimization_barrier` API introduced in #3482 provides a great feature to avoid XLA compiler fusion between different parts of the graph (e.g. forward pass and backward pass) -- very useful for gradient checkpointing applications such as in #3455. However, applying the `xm.optimization_barrier` API leads to incorrect results in many cases, so it seems that further inspection is needed here.

For example, it breaks the MNIST example. In a correct training case, MNIST is supposed to get 98%+ accuracy in 2 epochs. However, when calling `output, = xm.optimization_barrier([output])` on the model output with this API, the MNIST training does not converge. In fact, the training doesn't happen at all, as all the model parameters' `.grad` is always `None` in this case.
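A minimal sketch of the call site (the loop and loss details are assumed; only the barrier line differs from the stock MNIST script, and `train_step` is just an illustrative name):

```python
import torch_xla.core.xla_model as xm

def train_step(model, data, target, loss_fn, optimizer):
    optimizer.zero_grad()
    output = model(data)
    # The only change vs. test_train_mp_mnist.py: barrier the model output.
    output, = xm.optimization_barrier([output])
    loss = loss_fn(output, target)
    loss.backward()
    xm.optimizer_step(optimizer)
    return loss
```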
To Reproduce

1. Use the `tpu-vm-pt-1.10` runtime environment.
2. Download the official example `test_train_mp_mnist.py` and the new API example `test_train_mp_mnist_with_optimization_barrier.py`. Note: their only difference is that `test_train_mp_mnist_with_optimization_barrier.py` has `output, = xm.optimization_barrier([output])` on the model output.
3. Run both scripts with `--batch_size 16 --drop_last --num_epochs 2` and check their training accuracies.

The official PyTorch XLA MNIST example with 2 training epochs
gives
as expected.
The new API example with 2 training epochs
gives
which shows that the model doesn't converge.
It seems that this new `xm.optimization_barrier` API breaks the gradient flow -- the accuracies at epoch 1 and epoch 2 are both exactly 3.58. A further inspection shows that all the model parameters stayed the same as their initialized values and their `.grad` is always `None`.

Expected behavior
The training accuracy should be the same between the two cases, since `xm.optimization_barrier` should not change the computational results.

Environment
cc: @JackCaoG @ultrons