
Commit ec4aaad

[Kernel][Triton][AMD] Remove tl.atomic_add from awq_gemm_kernel, 2-5x speedup MI300, minor improvement for MI250 (#8646)
1 parent 4dfdf43 commit ec4aaad

File tree

1 file changed (+7, -6 lines)

1 file changed

+7
-6
lines changed

vllm/model_executor/layers/quantization/awq_triton.py

Lines changed: 7 additions & 6 deletions
@@ -209,12 +209,9 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
     c = accumulator.to(c_ptr.type.element_ty)
     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = c_ptr + N * offs_cm[:, None] + offs_cn[None, :]
+    c_ptrs = c_ptr + pid_z * N * M + N * offs_cm[:, None] + offs_cn[None, :]
     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
-    if SPLIT_K == 1:
-        tl.store(c_ptrs, c, mask=c_mask)
-    else:
-        tl.atomic_add(c_ptrs, c, mask=c_mask)
+    tl.store(c_ptrs, c, mask=c_mask)
 
 
 # qweights - [K , M // 8], int32
@@ -295,7 +292,9 @@ def awq_gemm_triton(input: torch.Tensor,
         split_k_iters,
     )
 
-    result = torch.zeros((M, N), dtype=scales.dtype, device=input.device)
+    result = torch.zeros((split_k_iters, M, N),
+                         dtype=scales.dtype,
+                         device=input.device)
 
     # A = input, B = qweight, C = result
     # A = M x K, B = K x N, C = M x N
@@ -313,4 +312,6 @@ def awq_gemm_triton(input: torch.Tensor,
                           BLOCK_SIZE_K=block_size_k,
                           SPLIT_K=split_k_iters)
 
+    result = result.sum(0)
+
     return result
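
For context, here is a minimal standalone sketch (not part of the commit; the shapes M, N, K and SPLIT_K below are made up for illustration) of why the change is numerically equivalent: instead of every split-K program atomically adding its partial product into one shared (M, N) buffer, each program now stores into its own slab of a (split_k_iters, M, N) buffer (hence the extra pid_z * N * M pointer offset), and the host reduces the slabs once with result.sum(0).

# Hypothetical host-side emulation of the two split-K schemes (PyTorch only,
# no Triton); not code from this commit.
import torch

M, N, K, SPLIT_K = 4, 8, 16, 2              # illustrative sizes only
a = torch.randn(M, K)
b = torch.randn(K, N)

# Old scheme: all split-K slices accumulate into one shared (M, N) buffer.
old = torch.zeros(M, N)
for pid_z in range(SPLIT_K):
    ks = slice(pid_z * (K // SPLIT_K), (pid_z + 1) * (K // SPLIT_K))
    old += a[:, ks] @ b[ks, :]              # stands in for tl.atomic_add

# New scheme: each slice writes its own (M, N) slab, reduced once at the end.
new = torch.zeros(SPLIT_K, M, N)
for pid_z in range(SPLIT_K):
    ks = slice(pid_z * (K // SPLIT_K), (pid_z + 1) * (K // SPLIT_K))
    new[pid_z] = a[:, ks] @ b[ks, :]        # stands in for the plain tl.store
new = new.sum(0)                            # mirrors `result = result.sum(0)`

assert torch.allclose(old, new, atol=1e-5)

The trade-off visible in the diff is a result buffer that is split_k_iters times larger, in exchange for dropping the contended tl.atomic_add writes that the commit title credits for the MI300 speedup.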
