
Revert PR#3450 and use sparse_gather in gather #3566


Merged
merged 4 commits on May 16, 2022

Conversation

@yeounoh
Contributor

yeounoh commented May 13, 2022

Previously, we internally overrode the sparse_grad option in torch.gather (see #3450). However, always forcing sparse_grad=True can have serious performance implications, as in #3441. This PR addresses the issue and honors the sparse_grad option as passed in the original torch.gather call.

cc @ymwangg -- with this change you will need to pass sparse_grad=True explicitly to torch.gather, as described in the torch.gather documentation. Let us know if you run into any issues after we revert #3450.
cc @ronghanghu FYI
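
For illustration, a minimal sketch (shapes here are arbitrary) of the explicit opt-in that callers now need if they want the sparse gather lowering on an XLA device:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.rand(32, 196, 1024, device=device)
idx = torch.randint(0, 196, (32, 49, 1024), device=device)

# With this PR the flag is honored as passed; the sparse lowering is no longer
# forced, so pass sparse_grad=True explicitly if you relied on the old behavior.
out = torch.gather(x, dim=1, index=idx, sparse_grad=True)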

XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::gather(
Contributor Author

@yeounoh May 13, 2022

Ditto -- re-ran clang-format-7, nothing changed. Will leave it for now.

@ronghanghu
Collaborator

Thanks for fixing this!

Collaborator

@wonjoolee95 left a comment

Thank you, @yeounoh! Will let the CI do the rest.

@@ -1403,7 +1403,7 @@ XLATensor XLATensor::full_like(const XLATensor& input,
 }

 XLATensor XLATensor::gather(const XLATensor& input, int64_t dim,
-                            const XLATensor& index) {
+                            const XLATensor& index, bool sparse_grad) {
Collaborator

Don't you need to pass sparse_grad to line 1418?

Collaborator

return input.CreateFrom(torch::lazy::MakeNode<Gather>(
      input.GetIrValue(), canonical_dim, index.GetIrValue(), sparse_grad));

Contributor Author

It's being passed already?

Contributor Author

Actually, it's not picked up in this diff?! My local branch is up to date and shows the correct one -- but the GitHub file view doesn't.

Contributor Author

OK, VS Code didn't write out the change to the file; resyncing and recommitting. Thanks @miladm

@JackCaoG
Collaborator

Did we verify that this change ("using the sparse flag passed from PyTorch") has the same speed as when we use the IsSparse helper function?

@miladm
Collaborator

miladm commented May 13, 2022

+1
To augment our testing, @ronghanghu and @ymwangg, would you like to try this PR on your end?

@ymwangg
Contributor

ymwangg commented May 13, 2022

@yeounoh Thanks for fixing the bug. It looks like my previous PR failed to generalize to other models. I'll revisit it on my side.

@yeounoh
Contributor Author

yeounoh commented May 13, 2022

Thanks @ymwangg, let me know if you need any help. @ronghanghu, I can test your model on TPU for you.

@ronghanghu
Collaborator

Thanks, @yeounoh! (I guess I cannot try it on my end now since I don't have a way to build it locally)

@yeounoh
Contributor Author

yeounoh commented May 14, 2022

> Thanks, @yeounoh! (I guess I cannot try it on my end now since I don't have a way to build it locally)

Yeah, I just verified that the patch doesn't change the performance -- with the default sparse_grad=False.

@ymwangg
Contributor

ymwangg commented May 14, 2022

@yeounoh According to the PyTorch documentation, sparse_grad (bool, optional) – If True, gradient w.r.t. input will be a sparse tensor. It seems this sparse_grad option controls whether the gradient tensor w.r.t. the input should be sparse or not, and it is not directly tied to the sparse option in xla::TorchGather. My understanding is that the sparse option in xla::TorchGather controls which underlying gather algorithm is chosen, and sparse=True should pick the algorithm that is supposed to do better when the index tensor is small.
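
For reference, a small eager-mode (non-XLA) check of that documented behavior -- sparse_grad only changes the layout of the gradient w.r.t. the input (shapes here are arbitrary):

import torch

x = torch.rand(4, 8, requires_grad=True)
idx = torch.randint(0, 8, (4, 3))

torch.gather(x, 1, idx, sparse_grad=True).sum().backward()
print(x.grad.layout)   # torch.sparse_coo

x.grad = None
torch.gather(x, 1, idx, sparse_grad=False).sum().backward()
print(x.grad.layout)   # torch.strided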

I tested the script provided by @ronghanghu on an NVIDIA V100 GPU, and the speed is roughly the same whether sparse=True is used or not.
With sparse=False:

[14:20:14.043412] Epoch: [0]  [5000/5004]  eta: 0:00:02  lr: 0.000000  loss: 0.7167 (0.7169)  time: 0.4813  data: 0.1547
[14:20:15.670125] Epoch: [0]  [5003/5004]  eta: 0:00:00  lr: 0.000000  loss: 0.7168 (0.7169)  time: 0.4796  data: 0.1566
[14:20:16.230841] Epoch: [0] Total time: 0:46:24 (0.5565 s / it)

With sparse=True:

[15:19:29.002305] Epoch: [0]  [5000/5004]  eta: 0:00:02  lr: 0.000000  loss: 0.7167 (0.7169)  time: 0.4710  data: 0.1640
[15:19:30.502887] Epoch: [0]  [5003/5004]  eta: 0:00:00  lr: 0.000000  loss: 0.7168 (0.7169)  time: 0.4676  data: 0.1625
[15:19:31.105321] Epoch: [0] Total time: 0:45:38 (0.5473 s / it)

I examined the model and found the following three gather ops:

%32 = s64[32,49,1024]{2,1,0} aten::repeat(%31), repeats=(1, 1, 1024)
%44 = f32[32,196,1024]{2,1,0} aten::add(%43, %39)
%45 = f32[32,49,1024]{2,1,0} aten::gather(%44, %32), dim=1

%3515 = s64[32,196,512]{2,1,0} aten::repeat(%3514), repeats=(1, 1, 512)
%3532 = f32[32,196,512]{2,1,0} aten::cat(%3531, %3517), dim=1 
%3533 = f32[32,196,512]{2,1,0} aten::gather(%3532, %3515), dim=1 

%3513 = (s64[32,196]{1,0}, s64[32,196]{1,0}) aten::topk(%28.1), num_outputs=2, k=196, dim=1, largest=0, sorted=1, stable=0
%4742 = f32[32,196]{1,0} xla::unselect(%4739, %4741), dim=0, start=0, end=32, stride=1
%4743 = f32[32,196]{1,0} aten::gather(%4742, %3513.1), dim=1, ROOT=1585

I then profiled the impact of sparse=True on these three gather ops on both the CPU and GPU platforms using the xla/tools/replay_computation_cpu/gpu tool, and sparse=True consistently gives better performance in all the test cases:

Platform   gather0 sparse=False   gather0 sparse=True   gather1 sparse=False   gather1 sparse=True   gather2 sparse=False   gather2 sparse=True
GPU              2694                    47                   3285                    73                     22                    17
CPU            173476                   803                 348241                  1705                   1231                    22

unit: us

The HLO used for profiling is generated by the following script:

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import os

torch.manual_seed(1)
os.environ['GPU_NUM_DEVICES'] = '1'

def gather0(device):
    # mirrors the first gather op above: f32[32,196,1024] input, s64[32,49,1024] index, dim=1
    a = torch.randint(0, 10, size=(32, 49, 1024), dtype=torch.int64).to(device)
    b = torch.rand(32, 196, 1024, dtype=torch.float32).to(device)
    c = torch.gather(b, dim=1, index=a)
    print(torch_xla._XLAC._get_xla_tensors_hlo([c]))
    return c

def gather1(device):
    # mirrors the second gather op above: f32[32,196,512] input, s64[32,196,512] index, dim=1
    a = torch.randint(0, 10, size=(32, 196, 512), dtype=torch.int64).to(device)
    b = torch.rand(32, 196, 512, dtype=torch.float32).to(device)
    c = torch.gather(b, dim=1, index=a)
    print(torch_xla._XLAC._get_xla_tensors_hlo([c]))
    return c

def gather2(device):
    # mirrors the third gather op above: f32[32,196] input, s64[32,196] index, dim=1
    a = torch.randint(0, 10, size=(32, 196), dtype=torch.int64).to(device)
    b = torch.rand(32, 196, dtype=torch.float32).to(device)
    c = torch.gather(b, dim=1, index=a)
    print(torch_xla._XLAC._get_xla_tensors_hlo([c]))
    return c

device = xm.xla_device()
xla_result = gather0(device).cpu().numpy()
xla_result = gather1(device).cpu().numpy()
xla_result = gather2(device).cpu().numpy()

It's also odd that, according to @ronghanghu, this performance issue is more significant when using a TPU pod instead of a single TPU, as torch::gather seems unrelated to multi-node communication. It would be great if someone could shed some light on it.

As the data suggest, sparse=True is better on GPU/CPU while it makes things worse on TPU, so I would suggest making it the default only when the platform is GPU/CPU.
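
For illustration only (not part of this PR), a user-side sketch of that idea; it assumes the xm.xla_device_hw helper reports the platform string:

import torch
import torch_xla.core.xla_model as xm

def gather_platform_default(input, dim, index):
    # Illustrative heuristic: opt into the sparse lowering on CPU/GPU, where
    # the measurements above favor it, and stay on the dense path on TPU.
    hw = xm.xla_device_hw(input.device)  # e.g. 'TPU', 'GPU', 'CPU'
    return torch.gather(input, dim=dim, index=index,
                        sparse_grad=(hw in ('CPU', 'GPU')))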

@yeounoh self-assigned this May 16, 2022
@JackCaoG
Collaborator

I think we should merge this one first and follow up (and make it into the 1.12 release) on how to set the default sparse option. The previous one introduced a regression on gather, and there was no way for the user to override that flag.

@yeounoh
Contributor Author

yeounoh commented May 16, 2022

Hi @ymwangg, thanks for the follow-up and the detailed analysis. TL;DR: thanks for spotting that the two sparsity options need to be decoupled. Based on the following observations, I think this is something we want to address soon, in a separate PR:

  • on GPU/CPU, sparse=True works better for both sparse and dense tensors
  • on TPU, it's the opposite: sparse=False works better
  • the original gather had a performance issue with loss.backward, and thus the sparsity heuristic was used with a configurable threshold

In general, gathering a single float is very slow on TPU, and I suspect that's not the case on other platforms. Let us discuss this further internally and address your issue in a separate PR. Let's use all our findings to come up with a better heuristic than always setting sparse=True. Thank you @ymwangg -- I will follow up with you as soon as possible.

@ymwangg
Contributor

ymwangg commented May 16, 2022

@JackCaoG @yeounoh That sounds really good! Thanks.
