@@ -209,12 +209,9 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
209
209
c = accumulator .to (c_ptr .type .element_ty )
210
210
offs_cm = pid_m * BLOCK_SIZE_M + tl .arange (0 , BLOCK_SIZE_M )
211
211
offs_cn = pid_n * BLOCK_SIZE_N + tl .arange (0 , BLOCK_SIZE_N )
212
- c_ptrs = c_ptr + N * offs_cm [:, None ] + offs_cn [None , :]
212
+ c_ptrs = c_ptr + pid_z * N * M + N * offs_cm [:, None ] + offs_cn [None , :]
213
213
c_mask = (offs_cm [:, None ] < M ) & (offs_cn [None , :] < N )
214
- if SPLIT_K == 1 :
215
- tl .store (c_ptrs , c , mask = c_mask )
216
- else :
217
- tl .atomic_add (c_ptrs , c , mask = c_mask )
214
+ tl .store (c_ptrs , c , mask = c_mask )
218
215
219
216
220
217
# qweights - [K , M // 8], int32
@@ -295,7 +292,9 @@ def awq_gemm_triton(input: torch.Tensor,
295
292
split_k_iters ,
296
293
)
297
294
298
- result = torch .zeros ((M , N ), dtype = scales .dtype , device = input .device )
295
+ result = torch .zeros ((split_k_iters , M , N ),
296
+ dtype = scales .dtype ,
297
+ device = input .device )
299
298
300
299
# A = input, B = qweight, C = result
301
300
# A = M x K, B = K x N, C = M x N
@@ -313,4 +312,6 @@ def awq_gemm_triton(input: torch.Tensor,
313
312
BLOCK_SIZE_K = block_size_k ,
314
313
SPLIT_K = split_k_iters )
315
314
315
+ result = result .sum (0 )
316
+
316
317
return result
0 commit comments