We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9351f91 commit e11880dCopy full SHA for e11880d
tests/kernels/test_flashmla.py
@@ -124,7 +124,7 @@ def ref_mla():
124
cal_diff(out_flash, out_torch, "out")
125
cal_diff(lse_flash, lse_torch, "lse")
126
127
- t = triton.testing.do_bench(flash_mla, fast_flush=False)
+ t = triton.testing.do_bench(flash_mla)
128
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
129
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d +
130
b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
0 commit comments