the new xm.optimization_barrier API breaks the gradient flow #3486


Closed
ronghanghu opened this issue Apr 9, 2022 · 11 comments

@ronghanghu
Collaborator

ronghanghu commented Apr 9, 2022

🐛 Bug

The new xm.optimization_barrier API introduced in #3482 provides a great feature to avoid XLA compiler fusion between different parts of the graph (e.g. the forward pass and the backward pass) -- very useful for gradient checkpointing applications such as #3455.

However, applying the xm.optimization_barrier API leads to incorrect results in many cases, so further inspection seems needed here.

For example, it breaks the MNIST example. In a correct training run, MNIST should reach 98%+ accuracy in 2 epochs. However, when calling output, = xm.optimization_barrier([output]) on the model output with this API, MNIST training does not converge. In fact, no training happens at all, since every model parameter's .grad stays None in this case.
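A minimal way to see the symptom (a sketch of the relevant part of the training step; model, loss_fn, data, and target are from the MNIST example):

output = model(data)
output, = xm.optimization_barrier([output])
loss = loss_fn(output, target)
loss.backward()
# every parameter's .grad is still None, so the optimizer has nothing to apply
print(all(p.grad is None for p in model.parameters()))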

To Reproduce

  1. Get a v3-8 TPU VM with tpu-vm-pt-1.10 runtime environment.
  2. Install the nightly PyTorch XLA build containing this API:
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220408-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220408-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220408-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220408-py3-none-any.whl
  3. Clone the PyTorch XLA repo, which contains the official PyTorch XLA MNIST example test_train_mp_mnist.py, and download the new API example test_train_mp_mnist_with_optimization_barrier.py:
git clone https://github.com/pytorch/xla.git
cd xla/test

# test_train_mp_mnist_with_optimization_barrier.py -- new API example
wget https://gist.githubusercontent.com/ronghanghu/74f103f79df3c5c6df2807d12506d6c7/raw/265f4377fb051ba7f799184e9a70bcffac8a1cde/test_train_mp_mnist_with_optimization_barrier.py

Note: their only difference is that test_train_mp_mnist_with_optimization_barrier.py has output, = xm.optimization_barrier([output]) on the model output.

ronghanghu@t1v-n-d5308e1f-w-0:~$ diff -p test_train_mp_mnist.py test_train_mp_mnist_with_optimization_barrier.py
*** test_train_mp_mnist.py      2022-04-09 06:49:39.712446418 +0000
--- test_train_mp_mnist_with_optimization_barrier.py    2022-04-09 06:49:44.728872258 +0000
*************** def train_mnist(flags, **kwargs):
*** 125,130 ****
--- 125,131 ----
      for step, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        output = model(data)
+       output, = xm.optimization_barrier([output])
        loss = loss_fn(output, target)
        loss.backward()
        xm.optimizer_step(optimizer)
  4. Run the two examples above with --batch_size 16 --drop_last --num_epochs 2 and check their training accuracies.

The official PyTorch XLA MNIST example with 2 training epochs

python3 -u test_train_mp_mnist.py --batch_size 16 --drop_last --num_epochs 2

gives

...
Epoch 1 test end 07:03:06, Accuracy=98.68
...
Epoch 2 test end 06:59:32, Accuracy=98.94                                                                                
Max Accuracy: 98.94%

as expected.

The new API example with 2 training epochs

python3 -u test_train_mp_mnist_with_optimization_barrier.py --batch_size 16 --drop_last --num_epochs 2

gives

...
Epoch 1 test end 07:00:38, Accuracy=3.58
...
Epoch 2 test end 07:00:48, Accuracy=3.58
Max Accuracy: 3.58%                                                                                                      
Accuracy 3.58375 is below target 98.0

which shows that the model doesn't converge.

It seems that this new xm.optimization_barrier API breaks the gradient flow -- the accuracies at epoch 1 and epoch 2 are both exactly 3.58. Further inspection shows that all the model parameters stayed at their initialized values and their .grad is always None.

Expected behavior

The training accuracy should be the same between the two cases since xm.optimization_barrier should not change the computational results.

Environment

  • Reproducible on XLA backend [CPU/TPU]: v3-8 TPU VM
  • torch_xla version: nightly+20220408 (see details above)

cc: @JackCaoG @ultrons

@ronghanghu
Collaborator Author

ronghanghu commented Apr 9, 2022

Update: I tried digging into it further. I think the root cause of the problem above is this code:

result.push_back(torch::autograd::make_variable(
    bridge::AtenFromXlaTensor(
        XLATensor::optimization_barrier(bridge::GetXlaTensor(tensor))),
    /*requires_grad=*/tensor.requires_grad()));

Here, GetXlaTensor and make_variable break the PyTorch autograd graph -- they essentially work like .detach() in PyTorch and drop the autograd reference.
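A quick way to see this detach-like effect from Python (sketch only):

y = model(data)
print(y.grad_fn is None)                 # False: y is part of the autograd graph
y_b, = xm.optimization_barrier([y])
print(y_b.grad_fn is None)               # True: same effect as y.detach(), so
                                         # loss.backward() never reaches the model parameters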


After I change it to

output.data, = xm.optimization_barrier([output])

in test_train_mp_mnist_with_optimization_barrier.py, then the gradient flow is good and I can get

Epoch 2 test end 07:33:09, Accuracy=98.94                                                                                
Max Accuracy: 98.94% 

which matches the case without using xm.optimization_barrier.

(However, this seems to be an unintuitive and error-prone API. It would be better to provide some documentation and/or warnings, or a more intuitive interface.)
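For now, one way to hide this behind a slightly less error-prone interface is a small autograd wrapper on top of the current detaching behavior (a sketch only, not a proper fix; _BarrierIdentity is a hypothetical helper):

import torch
import torch_xla.core.xla_model as xm

class _BarrierIdentity(torch.autograd.Function):
    # Apply the XLA barrier in forward while letting autograd treat the op
    # as an identity, so gradients keep flowing through it.
    @staticmethod
    def forward(ctx, t):
        barriered, = xm.optimization_barrier([t])  # detached by the current API
        return barriered

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output  # the barrier does not change any values

# usage: output = _BarrierIdentity.apply(model(data))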

ronghanghu changed the title from "the new xm.optimization_barrier API produces incorrect results" to "the new xm.optimization_barrier API breaks the gradient flow" on Apr 9, 2022
@JackCaoG
Collaborator

@bdhirsh What's the correct way to make a new PyTorch tensor in C++ without breaking autograd?

@bdhirsh
Collaborator

bdhirsh commented Apr 11, 2022

Hmmm. If I'm understanding correctly, the "normal" way to do this would be through a native pytorch op. For example, take at::add:

b = torch.add(a, ...)
-> XLA kernel: create a tensor "b", give it an "add" node in the XLA graph
-> autograd kernel: take the "b" created from the XLA kernel, give it an "add" node in the autograd graph

And it sounds like this optimization_barrier node is like an operator from the perspective of XLA, since it creates a new tensor IR node with an "optimization_barrier" reference.

@JackCaoG instead of having optimization_barrier generate a new tensor, do you think it would be reasonable to mutate the IR of the input tensor to add the optimization_barrier node? Or, if it's important to return a new tensor, you could clone() the input tensors (giving you new tensors with proper autograd tracking) and mutate the new XLA IR directly onto them.

@JackCaoG
Collaborator

Good point. I think one thing I can do is make it an in-place operation. I was referring to the JAX example, but then I realized JAX is purely functional, hence no in-place updates (I think).
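Usage would then look something like this (just a sketch; the in-place name here is hypothetical):

output = model(data)
xm.optimization_barrier_([output])  # in-place: mutate the XLA IR of output,
                                    # leaving its autograd history untouched
loss = loss_fn(output, target)
loss.backward()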

JackCaoG self-assigned this on Apr 11, 2022
@JackCaoG
Collaborator

Some updates and findings: I submitted #3493 to fix the gradients not being computed, but it does not solve the memory usage. I think I was using the barrier incorrectly.

Here is a JAX example

y_primal = f(x)
... compute forward pass ...
x2, cotangent_in2 = optimization_barrier(x, cotangent_in)
y_remat = f(x2)
return gradient(y_remat, cotangent_in2)

Note that in the above, f is called twice to save peak memory. To make something similar work for pytorch/xla, I think we need to do the following:

  1. When rematerializing the forward pass, use the input that has gone through optimization_barrier.
  2. The input needs to go through optimization_barrier together with some value that is the result of the first forward pass.

(2) is needed because it binds the creation of x2 and cotangent_in2 together; cotangent_in2 won't be available until the first forward pass has completed, which guarantees the order:

first forward
barrier
gradient computation
backward

(1) serves a similar purpose: it forces the second forward pass to happen after the barrier, which should prevent CSE from fusing the repeated computation.

I need to spend some time figuring out how to do (1), since for pytorch this recomputation is controlled by the framework, not the user.
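Roughly, the shape I'm aiming for on the pytorch/xla side is (sketch only, using the in-place barrier):

y = f(x)                          # first forward pass, without saving intermediate activations
xm.optimization_barrier_([x, y])  # bind input and output: x's IR now goes through the barrier,
                                  # and the barrier depends on y, i.e. on the first forward pass
y_remat = f(x)                    # rematerialized forward during backward; the ops are identical
                                  # to the first forward, but CSE cannot fuse them because this
                                  # one consumes the barriered x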

@JackCaoG
Collaborator

I also need to change the current optimization_barrier API to tuplify the inputs before passing them to xla::OptimizationBarrier and untuplify the output before returning it to the user, which will force the xla::OptimizationBarrier to happen after all inputs are ready.

@ronghanghu
Collaborator Author

ronghanghu commented Apr 14, 2022

@JackCaoG Thanks for the update! For (1) above, I think it's fine to accomplish gradient checkpointing in a way other than calling torch.utils.checkpoint.checkpoint.

For example, the FairScale library has its own custom wrapper fairscale.nn.checkpoint.checkpoint_wrapper and we can have a separate wrapper for XLA as well if the current PyTorch torch.utils.checkpoint.checkpoint cannot handle the XLA case.

@JackCaoG
Collaborator

Mostly done with the tuplify-the-inputs part of optimization_barrier. Will try to wrap it up tomorrow and modify torch.utils.checkpoint.checkpoint directly to see if I can get the expected HLO.

@JackCaoG
Collaborator

With #3493 and a patch to pytorch's torch/utils/checkpoint.py

--- a/torch/utils/checkpoint.py
+++ b/torch/utils/checkpoint.py
@@ -1,5 +1,9 @@
 import torch
 import warnings
+import torch_xla.core.xla_model as xm
+from collections.abc import Iterable
+
 from typing import Any, Iterable, List, Tuple, Union
 
 
@@ -82,6 +86,8 @@ class CheckpointFunction(torch.autograd.Function):
         ctx.inputs = []
         ctx.tensor_indices = []
         tensor_inputs = []
+        tensor_outputs = []
         for i, arg in enumerate(args):
             if torch.is_tensor(arg):
                 tensor_inputs.append(arg)
@@ -90,10 +96,23 @@ class CheckpointFunction(torch.autograd.Function):
             else:
                 ctx.inputs.append(arg)
 
-        ctx.save_for_backward(*tensor_inputs)
-
         with torch.no_grad():
             outputs = run_function(*args)
+
+        if torch.is_tensor(outputs):
+            tensor_outputs.append(outputs)
+        # tensor is Iterable so we need to avoid iterating through tensor
+        elif isinstance(outputs, Iterable):
+            for output in outputs:
+                if torch.is_tensor(output):
+                    tensor_outputs.append(output)
+                    
+        xm.optimization_barrier_(tensor_inputs + tensor_outputs)
+        ctx.save_for_backward(*tensor_inputs)
+
+        # with torch.no_grad():
+        #     outputs = run_function(*args)
         return outputs

with gradient checkpointing turned on, I am able to see

step 18, free memory after forward = 12901312
step 18, free memory = 12901312
step 19, free memory after forward = 12901312
step 19, free memory = 12901312
Metric: ExecuteTime
  TotalSamples: 20
  Accumulator: 06s542ms653.719us
  ValueRate: 154ms085.837us / second
  Rate: 0.556101 / second
  Percentiles: 1%=267ms809.712us; 5%=267ms813.576us; 10%=267ms031.572us; 20%=267ms153.407us; 50%=268ms569.989us; 80%=282ms100.887us; 90%=313ms612.029us; 95%=347ms956.749us; 99%=347ms956.749us
Metric: InputOutputAliasCount

and with it turned off

step 18, free memory after forward = 11514304
step 18, free memory = 11514304
step 19, free memory after forward = 11514304
step 19, free memory = 11514304
Metric: ExecuteTime
  TotalSamples: 20
  Accumulator: 04s355ms649.055us
  ValueRate: 132ms943.536us / second
  Rate: 0.605989 / second
  Percentiles: 1%=205ms534.556us; 5%=205ms631.425us; 10%=205ms028.855us; 20%=205ms278.535us; 50%=207ms046.007us; 80%=222ms446.996us; 90%=265ms024.504us; 95%=288ms697.969us; 99%=288ms697.969us

With checkpointing, memory usage is lower but execution takes longer, as expected. I will post my findings to the github issue.

The user does not need to call optimization_barrier explicitly since it is done by the CheckpointFunction.
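From the user's side it is just the usual checkpoint call (sketch; block, x, loss_fn, target and optimizer are placeholders):

from torch.utils.checkpoint import checkpoint

out = checkpoint(block, x)     # block is e.g. a couple of conv + relu layers
loss = loss_fn(out, target)
loss.backward()                # the patched CheckpointFunction applies the barrier internally
xm.optimizer_step(optimizer)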

@JackCaoG
Collaborator

Let me also explain a bit about what is going on. I will use IR text instead of HLO since it is cleaner.

%14 = f32[64,1024,14,14]{3,2,1,0} aten::convolution_overrideable(%7, %8, %13), [email protected]:443, stride=(1, 1), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %15 = f32[64,1024,14,14]{3,2,1,0} aten::relu(%14), [email protected]:1406
  %16 = f32[64,1024,14,14]{3,2,1,0} aten::convolution_overrideable(%15, %12, %11), [email protected]:443, stride=(1, 1), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %17 = f32[64,1024,14,14]{3,2,1,0} aten::relu(%16), [email protected]:1406
  %18 = f32[64,1024,14,14]{3,2,1,0} aten::convolution_overrideable(%17, %10, %9), [email protected]:443, stride=(1, 1), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %19 = f32[64,1024,14,14]{3,2,1,0} aten::relu(%18), [email protected]:1406
  %20 = (f32[64,1024,14,14]{3,2,1,0}, f32[64,1024,14,14]{3,2,1,0}) xla::optimization_barrier(%15, %19), num_outputs=2, location=optimization_barrier_@xla_model.py:1038
  %21 = f32[64,1024,14,14]{3,2,1,0} aten::convolution_overrideable(%20.0, %12, %11), [email protected]:443, stride=(1, 1), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %22 = f32[64,1024,14,14]{3,2,1,0} aten::relu(%21), [email protected]:1406
  %23 = f32[64,1024,14,14]{3,2,1,0} aten::convolution_overrideable(%22, %10, %9), [email protected]:443, stride=(1, 1), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %24 = f32[64,1024,14,14]{3,2,1,0} aten::relu(%23), [email protected]:1406
  %25 = f32[1024,1024,1,1]{1,0,3,2} xla::device_data(), [email protected]:905, device=TPU:0
  %26 = f32[1024]{0} xla::device_data(), [email protected]:905, device=TPU:0
  %27 = f32[1024,1024,1,1]{1,0,3,2} xla::device_data(), [email protected]:905, device=TPU:0
  %28 = f32[1024]{0} xla::device_data(), [email protected]:905, device=TPU:0
  %29 = f32[64,1024,14,14]{3,2,1,0} aten::convolution_overrideable(%20.1, %25, %28), [email protected]:443, stride=(1, 1), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %30 = f32[64,1024,14,14]{3,2,1,0} aten::relu(%29), [email protected]:1406
  %31 = f32[64,1024,14,14]{3,2,1,0} aten::convolution_overrideable(%30, %27, %26), [email protected]:443, stride=(1, 1), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %32 = f32[64,1024,14,14]{3,2,1,0} aten::relu(%31), [email protected]:1406
  %33 = (f32[64,1024,14,14]{3,2,1,0}, f32[64,1024,14,14]{3,2,1,0}) xla::optimization_barrier(%20.1, %32), num_outputs=2, location=optimization_barrier_@xla_model.py:1038
  %34 = f32[64,1024,14,14]{3,2,1,0} aten::convolution_overrideable(%33.0, %25, %28), [email protected]:443, stride=(1, 1), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %35 = f32[64,1024,14,14]{3,2,1,0} aten::relu(%34), [email protected]:1406
  %36 = f32[64,1024,14,14]{3,2,1,0} aten::convolution_overrideable(%35, %27, %26), [email protected]:443, stride=(1, 1), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %37 = f32[64,1024,14,14]{3,2,1,0} aten::relu(%36), [email protected]:1406
  %38 = f32[] prim::Constant(), location=_make_grads@__init__.py:68, value=1
  %39 = f32[64,1024,14,14]{3,2,1,0} aten::expand(%38), size=(64, 1024, 14, 14)
  %40 = f32[64,1024,14,14]{3,2,1,0} aten::threshold_backward(%39, %37), location=backward@__init__.py:173, threshold=0
  %41 = (f32[64,1024,14,14]{3,2,1,0}, f32[1024,1024,1,1]{0,1,3,2}, f32[1024]{0}) aten::convolution_backward_overrideable(%40, %35, %27), num_outputs=3, location=backward@__init__.py:173, stride=(1, 1), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=16

The first optimization barrier is

  %20 = (f32[64,1024,14,14]{3,2,1,0}, f32[64,1024,14,14]{3,2,1,0}) xla::optimization_barrier(%15, %19), num_outputs=2, location=optimization_barrier_@xla_model.py:1038

%15 is the input to the checkpointed function and %19 is its output; the checkpointed function in this case is 2 convs + 2 relus. What xm.optimization_barrier_ does is bind the input and output of the function together and apply a barrier to them.

Right after this barrier there is a call

  %21 = f32[64,1024,14,14]{3,2,1,0} aten::convolution_overrideable(%20.0, %12, %11), [email protected]:443, stride=(1, 1), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1

Note that this convolution_overrideable is a repeat of the earlier call

  %16 = f32[64,1024,14,14]{3,2,1,0} aten::convolution_overrideable(%15, %12, %11), [email protected]:443, stride=(1, 1), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1

The difference is that the repeated call is done on the barriered input %20.0 instead of the original input %15. Without the barrier these calls are identical and CSE would fuse them together. However, now that there is a barrier

%15 -> barrier -> %20.0

the compiler will not try to fuse these two identical calls.
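For reference, an illustrative module shape consistent with the IR above (1x1 convs over f32[64,1024,14,14] activations; the exact model is not important):

block = torch.nn.Sequential(
    torch.nn.Conv2d(1024, 1024, kernel_size=1),
    torch.nn.ReLU(),
    torch.nn.Conv2d(1024, 1024, kernel_size=1),
    torch.nn.ReLU(),
)

out = torch.utils.checkpoint.checkpoint(block, x)  # produces the first barrier, %20
# a second, similar checkpointed block produces the second barrier, %33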

@ronghanghu
Collaborator Author

ronghanghu commented Apr 15, 2022

This is awesome! I confirm that gradient checkpointing now also works well on my end with this new optimization barrier API and the patched torch/utils/checkpoint.py above.

What xm.optimization_barrier_ does is bind the input and output of the function together and apply a barrier to them.

This is great to know. I'm now trying it out in my FSDP cases to prevent the fusion of full parameter gathering and freeing in #3431.
