Revert PR#3450 and use sparse_gather in gather #3566
Conversation
```cpp
  XLA_FN_COUNTER("xla::");
  return bridge::AtenFromXlaTensor(XLATensor::gather(
```
ditto --> re-ran clang-format-7, nothing changed. Will leave it for now.
Thanks for fixing this!
Thank you, @yeounoh! Will let the CI do the rest.
```diff
@@ -1403,7 +1403,7 @@ XLATensor XLATensor::full_like(const XLATensor& input,
 }

 XLATensor XLATensor::gather(const XLATensor& input, int64_t dim,
-                            const XLATensor& index) {
+                            const XLATensor& index, bool sparse_grad) {
```
Don't you need to pass `sparse_grad` to line 1418?
```cpp
return input.CreateFrom(torch::lazy::MakeNode<Gather>(
    input.GetIrValue(), canonical_dim, index.GetIrValue(), sparse_grad));
```
It's being passed already?
Actually, it's not picked up in this diff?! My local branch is up to date and shows the correct one -- but the GitHub file view doesn't.
OK, VS Code didn't write the change out to the file; resyncing and recommitting. Thanks @miladm
Did we verify that this change ("using the `sparse_grad` passed from PyTorch") has the same speed as the previous behavior?
+1
@yeounoh Thanks for fixing the bug. It looks like my previous PR failed to generalize to other models. I'll revisit it on my side.
Thanks @ymwangg, let me know if you need any help. @ronghanghu I can test your model on TPU for you.
Thanks, @yeounoh! (I guess I cannot try it on my end now since I don't have a way to build it locally)
Yea, I just verified that the patch doesn't change the performance -- and that's with the default `sparse_grad`.
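As a side note, one way such a before/after check could be scripted (a sketch under assumptions, not the actual verification that was run; `xm.mark_step()` and `met.metrics_report()` are standard torch_xla debugging helpers, and the shapes are arbitrary):

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
index = torch.randint(0, 196, (32, 196, 512), device=device)
src = torch.rand(32, 196, 512, device=device)

# Run the same gather a few times so the one-off compilation cost is amortized.
for _ in range(10):
    out = torch.gather(src, dim=1, index=index)
    xm.mark_step()  # flush and execute the pending lazy graph

# Compare CompileTime/ExecuteTime between builds with and without the patch.
print(met.metrics_report())
```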
@yeounoh According to the PyTorch documentation, `sparse_grad` only affects the gradient w.r.t. `input`. I tested the script provided by @ronghanghu on an Nvidia V100 GPU and the speed is roughly the same whether `sparse_grad` is used or not.

Without `sparse_grad`:

```
[14:20:14.043412] Epoch: [0] [5000/5004] eta: 0:00:02 lr: 0.000000 loss: 0.7167 (0.7169) time: 0.4813 data: 0.1547
[14:20:15.670125] Epoch: [0] [5003/5004] eta: 0:00:00 lr: 0.000000 loss: 0.7168 (0.7169) time: 0.4796 data: 0.1566
[14:20:16.230841] Epoch: [0] Total time: 0:46:24 (0.5565 s / it)
```

With `sparse_grad`:

```
[15:19:29.002305] Epoch: [0] [5000/5004] eta: 0:00:02 lr: 0.000000 loss: 0.7167 (0.7169) time: 0.4710 data: 0.1640
[15:19:30.502887] Epoch: [0] [5003/5004] eta: 0:00:00 lr: 0.000000 loss: 0.7168 (0.7169) time: 0.4676 data: 0.1625
[15:19:31.105321] Epoch: [0] Total time: 0:45:38 (0.5473 s / it)
```

I examined the model and found the following three gather ops:

```
%32 = s64[32,49,1024]{2,1,0} aten::repeat(%31), repeats=(1, 1, 1024)
%44 = f32[32,196,1024]{2,1,0} aten::add(%43, %39)
%45 = f32[32,49,1024]{2,1,0} aten::gather(%44, %32), dim=1

%3515 = s64[32,196,512]{2,1,0} aten::repeat(%3514), repeats=(1, 1, 512)
%3532 = f32[32,196,512]{2,1,0} aten::cat(%3531, %3517), dim=1
%3533 = f32[32,196,512]{2,1,0} aten::gather(%3532, %3515), dim=1

%3513 = (s64[32,196]{1,0}, s64[32,196]{1,0}) aten::topk(%28.1), num_outputs=2, k=196, dim=1, largest=0, sorted=1, stable=0
%4742 = f32[32,196]{1,0} xla::unselect(%4739, %4741), dim=0, start=0, end=32, stride=1
%4743 = f32[32,196]{1,0} aten::gather(%4742, %3513.1), dim=1, ROOT=1585
```

I then profiled the impact of `sparse_grad` on these three gather ops (unit: us). The HLO used for profiling is generated by the following script:

```python
import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import os

torch.manual_seed(1)
os.environ['GPU_NUM_DEVICES'] = '1'

def gather0(device):
    a = torch.randint(0, 10, size=(32, 49, 1024), dtype=torch.int64).to(device)
    b = torch.rand(32, 196, 1024, dtype=torch.float32).to(device)
    c = torch.gather(b, dim=1, index=a)
    print(torch_xla._XLAC._get_xla_tensors_hlo([c]))
    return c

def gather1(device):
    a = torch.randint(0, 10, size=(32, 196, 512), dtype=torch.int64).to(device)
    b = torch.rand(32, 196, 512, dtype=torch.float32).to(device)
    c = torch.gather(b, dim=1, index=a)
    print(torch_xla._XLAC._get_xla_tensors_hlo([c]))
    return c

def gather2(device):
    a = torch.randint(0, 10, size=(32, 196), dtype=torch.int64).to(device)
    b = torch.rand(32, 196, dtype=torch.float32).to(device)
    c = torch.gather(b, dim=1, index=a)
    print(torch_xla._XLAC._get_xla_tensors_hlo([c]))
    return c

device = xm.xla_device()
xla_result = gather0(device).cpu().numpy()
xla_result = gather1(device).cpu().numpy()
xla_result = gather2(device).cpu().numpy()
```

It's also odd that this performance issue is more significant when using a TPU pod instead of a single TPU, according to @ronghanghu, as … As the data suggest, …
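For context on what the flag means in stock (non-XLA) PyTorch, a small CPU-only illustration, separate from the script above: `sparse_grad` only changes the layout of the gradient that `torch.gather` produces for its input.

```python
import torch

src = torch.rand(4, 5, requires_grad=True)
idx = torch.randint(0, 4, (4, 5))

# Default: the gradient w.r.t. `src` is a dense (strided) tensor.
torch.gather(src, 0, idx).sum().backward()
print(src.grad.layout)  # torch.strided

src.grad = None

# With sparse_grad=True the gradient comes back as a sparse COO tensor.
torch.gather(src, 0, idx, sparse_grad=True).sum().backward()
print(src.grad.layout)  # torch.sparse_coo
```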
I think we should merge this one first and follow up (and make it into the 1.12 release) on how to set the default `sparse_grad`.
Hi @ymwangg, thanks for the follow-up and the detailed analysis. TL;DR, thanks for spotting that the two sparsity options need to be detached. Based on the following observations, I think this issue is something we would want to address soon, with a separate PR:

In general, gather of a single float is very slow on TPU, and I suspect that's not the case on other platforms. Let us discuss this further internally and address your issue in a separate PR. Let's use all our findings to come up with better heuristics than unconditionally setting `sparse_grad`.
Previously, we've internally overridden the `sparse_grad` option in `torch.gather` (see #3450). However, forcing `sparse_grad=true` always can have serious performance implications, like #3441. This PR addresses the issue and honors the `sparse_grad` option as passed with the original `torch.gather` calls.

cc @ymwangg: hi, with this change you would need to pass `sparse_grad=true` explicitly to `torch.gather`, as described in the `torch.gather` documentation. Let us know if you run into any issues after we revert #3450.

cc @ronghanghu FYI
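For reference, a minimal sketch (not taken from the PR; the tensor names and shapes are made up, but `sparse_grad` is the standard `torch.gather` keyword argument) of what this means for callers:

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
src = torch.rand(32, 196, 512, device=device, requires_grad=True)
idx = torch.randint(0, 196, (32, 49, 512), device=device)

# After this PR, nothing is overridden internally: the default dense-gradient
# path of torch.gather is used unless the caller asks otherwise.
dense = torch.gather(src, dim=1, index=idx)

# To keep the behavior that #3450 used to force, opt in explicitly.
sparse = torch.gather(src, dim=1, index=idx, sparse_grad=True)
```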