
Seeing a roughly 25% slowdown with torch-xla=1.13 because of the addcdiv_ and addcmul_ ops #4213


Closed
aws-rhsoln opened this issue Nov 17, 2022 · 4 comments · Fixed by #4447

aws-rhsoln (Contributor) commented Nov 17, 2022

## 🐛 Bug

For the BERT-large model, we are seeing around a 25% slowdown with the latest torch-xla=1.13. The slowdown appears to be caused by the addcdiv_ and addcmul_ operations transferring their scalar arguments to the device synchronously. These are the operations that cause the slowdown:

  1. exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
  2. p.data.addcdiv_(exp_avg, denom, value=-step_size)
     (Taken from: https://github.com/huggingface/transformers/blob/v4.6.0/src/transformers/optimization.py#L346)
     Note that these are not special scalars, since the values are not 0 or 1 as in "Fix special scalar handling for addcdiv and addcmul" #3953; see the sketch after this list.
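For context, both of these calls sit inside the AdamW update, where the `value` arguments are recomputed from the hyperparameters on every step and are generally neither 0 nor 1. A simplified sketch of that step (loosely based on the linked Hugging Face optimizer; the function name and defaults here are illustrative, and bias correction and weight decay are omitted):

```python
import torch

def adamw_step(p, grad, exp_avg, exp_avg_sq,
               lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-6):
    # Simplified AdamW update (bias correction and weight decay omitted).
    exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)  # scalar 1 - beta2
    denom = exp_avg_sq.sqrt().add_(eps)
    p.data.addcdiv_(exp_avg, denom, value=-lr)  # scalar -step_size (just -lr here)
```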

As the model size grows, the number of scalars transferred to the device also grows, and we end up with a profile that looks like this:

[Profile screenshot, 2022-11-17]

## To Reproduce

Run the following example:
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
t = torch.randn(3, 3, device=device)  # (3, 3) so the in-place op matches the broadcast of t1 / t2
t1 = torch.randn(3, 1, device=device)
t2 = torch.randn(1, 3, device=device)
t.addcdiv_(t1, t2, value=0.1)
xm.mark_step()

This should produce a synchronous scalar transfer to the device.
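One way to confirm the behavior is to dump torch_xla's debug metrics after the mark_step and look at the transfer-related counters. A minimal sketch, assuming the torch_xla.debug.metrics module; exact counter names can differ between releases:

```python
import torch_xla.debug.metrics as met

# After running the snippet above (including xm.mark_step()), print the
# metrics report and inspect the transfer counters; an extra host-to-device
# transfer per non-special scalar points at the behavior described here.
print(met.metrics_report())
```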


## Expected behavior

We didn't see such a transfer with torch-xla 1.12.

## Environment

  • Reproducible on XLA backend [CPU/TPU]:
  • torch_xla version: 1.13


JackCaoG (Collaborator) commented:

Oh, OK. I think the problem is that I added the option to transfer the data in https://github.com/pytorch/xla/blob/master/torch_xla/csrc/tensor_util.cpp#L657 and a couple of other places. Since we migrated to LTC, control of the scalar data transfer sits upstream, and our async logic might get ignored.

@alanwaketan, can you follow up? @aws-rhsoln, the quick fix might be to just revert my addcmul_ and addcdiv_ codegen change on your end; we can fix it properly in nightly.

alanwaketan (Collaborator) commented Jan 11, 2023

It looks like the way we generate the IR for the scalar is different.
LTC uses:
https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/core/lazy_graph_executor.cpp#L485

And XLA uses:
https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_graph_executor.cpp#L227

The key difference is that XLA uses the DataCacheArena at the end (https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/core/lazy_graph_executor.cpp#L288), while LTC delegates this behavior to the backend, and our backend implementation doesn't use the cache (https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_backend_impl.cpp#L69).

I think a good fix is to align upstream with our behavior, i.e., use the DataCacheArena when generating the IR for the scalar.
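Conceptually, the fix amounts to memoizing the device data created for each scalar, so a repeated scalar (e.g. 1.0 - beta2 on every optimizer step) reuses the cached handle instead of triggering a new host-to-device transfer. A rough Python sketch of the idea only; the real logic is the C++ DataCacheArena in the graph executor, and the helper below is purely illustrative:

```python
# Illustrative only: the real cache lives in the C++ graph executor.
_scalar_cache = {}  # (value, dtype, device) -> previously created device data

def ir_value_for_scalar(value, dtype, device, transfer_scalar_to_device):
    """Return cached device data for a scalar, transferring it only once."""
    key = (value, str(dtype), str(device))
    if key not in _scalar_cache:
        # Cache miss: this is the only place a transfer happens.
        _scalar_cache[key] = transfer_scalar_to_device(value, dtype, device)
    return _scalar_cache[key]
```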

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this issue Jan 14, 2023
Summary:
XLA expects GetIrValueForScalarFromCodegen to use the DataCache so that not every scalar triggers a data transfer to the backend device.

This needs pytorch/xla#4447 to verify.

Test Plan:
PJRT_DEVICE=CPU python xla/test/test_operations.py -v -k test_cached_addcdiv

Fixes pytorch/xla#4213.

Pull Request resolved: #92066
Approved by: https://github.com/JackCaoG
alanwaketan added a commit that referenced this issue Jan 14, 2023
Summary:
This pull request redoes the addcdiv and addcmul code-gen, and adds a test case to verify that we reuse the DataCache for scalars.

This needs pytorch/pytorch#92066 to function.

Test Plan:
PJRT_DEVICE=CPU python test/test_operations.py -v -k test_cached_addcdiv

Fixes #4213.
alanwaketan (Collaborator) commented:

@aws-rhsoln The problem should be fixed. Can you verify?

aws-rhsoln (Contributor, Author) commented:

Will verify and update. Thanks for the fix!

ManfeiBai pushed two commits that referenced this issue Jan 19, 2023, with the same summary and test plan as the commit above.