Performance regression after migrating to LTC codegen (addcdiv, addcmul) #3942


Closed
ymwangg opened this issue Aug 26, 2022 · 2 comments

@ymwangg
Contributor

ymwangg commented Aug 26, 2022

🐛 Bug

We noticed a ~18% performance drop in the BERT model after #3768. The regression appears to be caused by a new flag in upstream LTC not being enabled by default here. This special-scalar check is important for XLA, which relies on it to constant-fold ops such as torch.addcdiv(a, b, c, value=1.0) and torch.add(a, b, alpha=1.0).
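To make the issue concrete, here is a minimal sketch in plain Python (not the actual LTC or XLA code) of why treating the scalar as an IR constant matters: addcdiv computes a + value * (b / c), and when the compiler sees value as the literal 1.0 it can fold the multiply away entirely, whereas a runtime scalar parameter (as in the HLO dump below) blocks that optimization.

```python
def addcdiv_ref(a, b, c, value=1.0):
    """Reference semantics of torch.addcdiv on flat Python lists:
    out[i] = a[i] + value * (b[i] / c[i])."""
    return [ai + value * (bi / ci) for ai, bi, ci in zip(a, b, c)]

def addcdiv_folded(a, b, c, value=1.0):
    """What a compiler can emit once `value` is a compile-time constant:
    for value == 1.0 the multiply is dropped (constant folding)."""
    if value == 1.0:
        return [ai + bi / ci for ai, bi, ci in zip(a, b, c)]
    return addcdiv_ref(a, b, c, value)
```

Both functions are equivalent; the point is that the folded form (one fewer op per element) is only reachable when the scalar is visible to the compiler as a constant rather than as a graph parameter.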

To Reproduce

import torch
import torch_xla.core.xla_model as xm
import torch_xla
import torch_xla.debug.metrics as met

device = xm.xla_device()

a = torch.rand(10).to(device)
b = torch.rand(10).to(device)
c = torch.rand(10).to(device)
d = torch.addcdiv(a, b, c, value=1.0)
print(torch_xla._XLAC._get_xla_tensors_hlo([d]))

HLO dump:

HloModule IrToHlo.10, entry_computation_layout={(f32[],f32[10]{0},f32[10]{0},f32[10]{0})->(f32[10]{0})}

ENTRY %IrToHlo.10 (p0.1: f32[], p1.2: f32[10], p2.3: f32[10], p3.4: f32[10]) -> (f32[10]) {
  %p3.4 = f32[10]{0} parameter(3)
  %p2.3 = f32[10]{0} parameter(2)
  %p1.2 = f32[10]{0} parameter(1)
  %divide.5 = f32[10]{0} divide(f32[10]{0} %p2.3, f32[10]{0} %p1.2)
  %p0.1 = f32[] parameter(0)
  %broadcast.6 = f32[10]{0} broadcast(f32[] %p0.1), dimensions={}
  %multiply.7 = f32[10]{0} multiply(f32[10]{0} %divide.5, f32[10]{0} %broadcast.6)
  %add.8 = f32[10]{0} add(f32[10]{0} %p3.4, f32[10]{0} %multiply.7)
  ROOT %tuple.9 = (f32[10]{0}) tuple(f32[10]{0} %add.8)
}

Note that setting torch_lazy_handle_special_scalars=True solves the special-scalar problem, but the result is then improperly cast to fp64:

HloModule IrToHlo.12, entry_computation_layout={(f32[10]{0},f32[10]{0},f32[10]{0})->(f64[10]{0})}

ENTRY %IrToHlo.12 (p0.2: f32[10], p1.3: f32[10], p2.4: f32[10]) -> (f64[10]) {
  %p2.4 = f32[10]{0} parameter(2)
  %convert.9 = f64[10]{0} convert(f32[10]{0} %p2.4)
  %p1.3 = f32[10]{0} parameter(1)
  %p0.2 = f32[10]{0} parameter(0)
  %divide.5 = f32[10]{0} divide(f32[10]{0} %p1.3, f32[10]{0} %p0.2)
  %convert.6 = f64[10]{0} convert(f32[10]{0} %divide.5)
  %constant.1 = f64[] constant(1)
  %broadcast.7 = f64[10]{0} broadcast(f64[] %constant.1), dimensions={}
  %multiply.8 = f64[10]{0} multiply(f64[10]{0} %convert.6, f64[10]{0} %broadcast.7)
  %add.10 = f64[10]{0} add(f64[10]{0} %convert.9, f64[10]{0} %multiply.8)
  ROOT %tuple.11 = (f64[10]{0}) tuple(f64[10]{0} %add.10)
}
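The f64 output above contradicts PyTorch's type-promotion semantics, under which a bare Python float is a "weak" scalar: combined with a float32 tensor it adopts the tensor's dtype rather than widening the result to float64. The sketch below is a simplified plain-Python model of that rule (illustrative only, not torch's actual implementation of torch.result_type):

```python
def result_dtype(tensor_dtype: str, scalar) -> str:
    """Simplified model of torch.result_type(tensor, python_scalar)
    for a single tensor combined with a bare Python scalar."""
    if isinstance(scalar, bool):          # bool is weakest: never promotes
        return tensor_dtype
    if isinstance(scalar, int):           # weak int: keep the tensor's dtype
        return tensor_dtype
    if isinstance(scalar, float):
        # A weak float only forces promotion for integer/bool tensors,
        # where it yields PyTorch's default float dtype (float32).
        if tensor_dtype.startswith(("int", "uint", "bool")):
            return "float32"
        return tensor_dtype               # float32 stays float32, not float64
    raise TypeError(f"unsupported scalar: {scalar!r}")
```

By this rule, the expected entry computation layout for the repro is (f32, f32, f32) -> (f32), not -> (f64) as in the dump above.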

Expected behavior

value=1.0 in torch.addcdiv should be treated as a compile-time constant (so XLA can fold it away), and the result should keep the input dtype (float32) rather than being promoted to float64.

Environment

  • Reproducible on XLA backend [CPU/TPU]: GPU
  • torch_xla version: master
@JackCaoG JackCaoG self-assigned this Aug 26, 2022
@JackCaoG
Collaborator

Thanks for reporting. Do you want to take a stab at this one? I should be able to work on it in the second half of next week if you are busy.

@ymwangg
Contributor Author

ymwangg commented Aug 26, 2022

Sure, I'll take a look early next week and may need your help :)
