# Seeing a roughly 25% slow down with torch-xla=1.13 because of the addcdiv_ and addcmul_ ops #4213
## Comments
Oh ok.. I think the problem is that I added the option to transfer the data in https://github.com/pytorch/xla/blob/master/torch_xla/csrc/tensor_util.cpp#L657 and a couple of other places. Since we migrated to LTC, control of the scalar data transfer lives upstream. @alanwaketan can you follow up? @aws-rhsoln the quick fix might be to just revert my change.
It looks like the way we generate the IR for the scalar differs between upstream LTC and XLA; the key difference is in how the scalar IR value is produced. I think a good fix is to change upstream to match our behavior, i.e., use the DataCacheArena while generating the IR for the scalar.
Summary: XLA expects GetIrValueForScalarFromCodegen to use DataCache such that not every scalar will request a data transfer to the backend device. This needs pytorch/xla#4447 to verify.
Test Plan: PJRT_DEVICE=CPU python xla/test/test_operations.py -v -k test_cached_addcdiv
Fixes pytorch/xla#4213.
Pull Request resolved: #92066
Approved by: https://github.com/JackCaoG
Summary: This pull request redoes the addcdiv and addcmul code-gen, and adds a test case to verify that we reuse the DataCache for scalars. This needs pytorch/pytorch#92066 to function.
Test Plan: PJRT_DEVICE=CPU python test/test_operations.py -v -k test_cached_addcdiv
Fixes #4213.
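For context, a test along these lines can assert that repeating addcdiv_ with the same scalar value does not add host-to-device transfers. The sketch below is only illustrative: the `transfer_count` helper and the `TransferToServerTime` metric name are assumptions and may differ between torch_xla releases.

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met


def transfer_count():
    # Number of recorded host-to-device transfers; the metric name is an
    # assumption and may differ across torch_xla releases.
    data = met.metric_data("TransferToServerTime")
    return 0 if data is None else data[0]


device = xm.xla_device()
t = torch.rand(2, 2, device=device)
a = torch.rand(2, 2, device=device)
b = torch.rand(2, 2, device=device) + 0.1

# Warm up once so the scalar 0.4 gets uploaded and (ideally) cached.
t.addcdiv_(a, b, value=0.4)
xm.mark_step()
baseline = transfer_count()

# With the fix, repeating the op with the same scalar value should not
# add new transfers on top of the baseline.
for _ in range(5):
    t.addcdiv_(a, b, value=0.4)
xm.mark_step()
assert transfer_count() == baseline, "scalar was re-uploaded instead of cached"
```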
@aws-rhsoln The problem should be fixed. Can you verify?
Will verify and update. Thanks for the fix!
## 🐛 Bug
For a BERT-large model, we are seeing around a 25% slowdown with the latest torch-xla=1.13. The slowdown seems to be caused by the addcdiv_ and addcmul_ operations transferring scalars synchronously to the device. These are the operations that cause the slowdown:
```python
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
p.data.addcdiv_(exp_avg, denom, value=-step_size)
```
(Taken from: https://github.com/huggingface/transformers/blob/v4.6.0/src/transformers/optimization.py#L346)
Note that the scalars here are not special scalars, since the values are not 0 or 1 as in: Fix special scalar handling for addcdiv and addcmul #3953
As the model size grows, the number of scalars transferred to the device also grows, and we end up with a profile that looks like this:

## To Reproduce
Run the following example; this should produce a synchronous transfer of a scalar.
Steps to reproduce the behavior:
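(The original reproduction snippet is not preserved here; the following is a minimal sketch that exercises the same ops, assuming a working torch-xla install, and prints the debug metrics report so the per-step scalar uploads can be observed.)

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()

param = torch.rand(1024, 1024, device=device)
exp_avg = torch.rand(1024, 1024, device=device)
exp_avg_sq = torch.rand(1024, 1024, device=device)
grad = torch.rand(1024, 1024, device=device)
beta2, eps, step_size = 0.999, 1e-8, 1e-3

for step in range(10):
    # Non-special scalar values (not 0 or 1), as in the AdamW update above.
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
    denom = exp_avg_sq.sqrt().add_(eps)
    param.addcdiv_(exp_avg, denom, value=-step_size)
    xm.mark_step()

# With torch-xla 1.13 (before the fix), the report shows host-to-device
# transfers growing with the number of steps for these scalar values.
print(met.metrics_report())
```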
## Expected behavior
We didn't see such a transfer with torch-xla 1.12.
## Environment
## Additional context