Commit c4e6079

rasmith authored and Isotr0py committed
[Kernel][Triton][AMD] Use block size heuristic for avg 2.8x speedup for int8 models (vllm-project#11698)
Signed-off-by: Randall Smith <[email protected]> Signed-off-by: Isotr0py <[email protected]>
1 parent 6937cdd commit c4e6079

File tree

1 file changed: +16 −1 lines changed


vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py

Lines changed: 16 additions & 1 deletion
```diff
@@ -128,7 +128,8 @@ def triton_scaled_mm(input: torch.Tensor,
                      bias: Optional[torch.Tensor] = None,
                      block_size_m: int = 32,
                      block_size_n: int = 32,
-                     block_size_k: int = 32) -> torch.Tensor:
+                     block_size_k: int = 32,
+                     use_heuristic=True) -> torch.Tensor:
     M, K = input.shape
     N = weight.shape[1]
@@ -152,6 +153,20 @@ def triton_scaled_mm(input: torch.Tensor,
 
     has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1
 
+    if use_heuristic:
+        is_small_N = N < 8192
+        next_power_of_2_M = max(32, triton.next_power_of_2(M))
+        if next_power_of_2_M <= 32:
+            tile_shape = (64, 64, 256) if is_small_N else (64, 128, 256)
+        elif next_power_of_2_M <= 64:
+            tile_shape = (64, 64, 256)
+        elif next_power_of_2_M <= 128:
+            tile_shape = (64, 128, 128)
+        else:
+            tile_shape = (128, 128, 128)
+
+        block_size_m, block_size_n, block_size_k = tile_shape
+
     block_size_sa = 1 if has_scalar(scale_a) else block_size_m
     block_size_sb = 1 if has_scalar(scale_b) else block_size_n
```
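The heuristic added in this commit picks a `(BLOCK_M, BLOCK_N, BLOCK_K)` tile from the matmul shape: it rounds `M` up to a power of two (clamped to at least 32) and uses a wider `BLOCK_N` when `N` is large (≥ 8192). A standalone sketch of that selection logic, with a bit-twiddling stand-in for `triton.next_power_of_2` so it runs without Triton installed (the helper name `pick_tile_shape` is hypothetical, not part of the patch):

```python
from typing import Tuple


def next_power_of_2(x: int) -> int:
    # Smallest power of two >= x; stand-in for triton.next_power_of_2.
    return 1 << (x - 1).bit_length()


def pick_tile_shape(M: int, N: int) -> Tuple[int, int, int]:
    """Mirror of the block-size heuristic in the diff above.

    Returns (block_size_m, block_size_n, block_size_k).
    """
    is_small_N = N < 8192
    next_power_of_2_M = max(32, next_power_of_2(M))
    if next_power_of_2_M <= 32:
        return (64, 64, 256) if is_small_N else (64, 128, 256)
    elif next_power_of_2_M <= 64:
        return (64, 64, 256)
    elif next_power_of_2_M <= 128:
        return (64, 128, 128)
    else:
        return (128, 128, 128)


# Small batches favor deep K tiles; large M shifts work into M/N tiles.
print(pick_tile_shape(16, 4096))    # (64, 64, 256)
print(pick_tile_shape(16, 16384))   # (64, 128, 256)
print(pick_tile_shape(1000, 8192))  # (128, 128, 128)
```

The `use_heuristic=True` default means callers of `triton_scaled_mm` get these shape-dependent tiles automatically instead of the previous fixed 32×32×32 blocks, which is where the reported average 2.8x speedup on int8 models comes from.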