Commit 7b79dad
rasmith authored and tjtanaa committed
[AMD][Quantization] Add TritonScaledMMLinearKernel since int8 is broken for AMD (vllm-project#12282)
Signed-off-by: Randall Smith <[email protected]>
1 parent d57c673 commit 7b79dad

File tree

3 files changed: +58 -5 lines changed

tests/kernels/test_triton_scaled_mm.py

Lines changed: 17 additions & 0 deletions
@@ -39,6 +39,23 @@ def get_8bit_types():
     return types
 
 
+# This test is to check regressions for int8 support on ROCm.
+@pytest.mark.parametrize("model_path", [
+    "neuralmagic/Llama-3.2-1B-quantized.w8a8",
+])
+@pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("num_logprobs", [10])
+@pytest.mark.skipif(not current_platform.is_rocm(),
+                    reason="Should only run on ROCm")
+def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path,
+                                      max_tokens, num_logprobs):
+    dtype = "bfloat16"
+
+    with vllm_runner(model_path, dtype=dtype) as vllm_model:
+        vllm_model.generate_greedy_logprobs(example_prompts, max_tokens,
+                                            num_logprobs)
+
+
 @pytest.mark.parametrize("M", [1, 33, 64, 512])
 @pytest.mark.parametrize("N", [256, 971, 20486])
 @pytest.mark.parametrize("K", [128, 496, 1024])

vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py

Lines changed: 3 additions & 5 deletions
@@ -5,8 +5,8 @@
     CutlassScaledMMLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
     ScaledMMLinearKernel, ScaledMMLinearLayerConfig)
-# from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
-#     TritonScaledMMLinear)
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
+    TritonScaledMMLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
     XLAScaledMMLinearKernel)
 from vllm.platforms import PlatformEnum, current_platform
@@ -15,9 +15,7 @@
 _POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = {
     PlatformEnum.CPU: [CutlassScaledMMLinearKernel],
     PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
-    # TODO(rob): Create TritonScaledMMLinear kernel. ROCM will
-    # incorrectly attempt to run AZP models if prompted to.
-    PlatformEnum.ROCM: [CutlassScaledMMLinearKernel],
+    PlatformEnum.ROCM: [TritonScaledMMLinearKernel],
     PlatformEnum.TPU: [XLAScaledMMLinearKernel],
 }
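
The registry above is what the kernel-selection logic in this module consults when a quantized linear layer is built: ROCm now resolves to TritonScaledMMLinearKernel instead of the Cutlass kernel. The selection helper itself is not shown in the diff; the pattern is roughly the following illustrative sketch (function name, signature, and error wording are assumptions, not the module's actual code):

# Illustrative sketch only: how a per-platform registry like _POSSIBLE_KERNELS
# is typically walked to pick the first kernel whose can_implement() accepts
# the layer config.
def _choose_scaled_mm_kernel(config, platform, kernels):
    failures = []
    for kernel in kernels[platform]:
        ok, reason = kernel.can_implement(config)
        if ok:
            return kernel
        failures.append(f"{kernel.__name__}: {reason}")
    raise ValueError("No ScaledMM kernel supports this config: "
                     + "; ".join(failures))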

vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+from typing import Optional, Tuple
+
+import torch
+
+from vllm.platforms import current_platform
+
+from .cutlass import CutlassScaledMMLinearKernel
+from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
+
+
+class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 75
+
+    @classmethod
+    def can_implement(
+            cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
+        if current_platform.is_cpu():
+            return (
+                False,
+                "TritonScaledMMLinearKernel requires Triton which is not " +
+                "currently supported on CPU.")
+        if not c.input_symmetric:
+            return (False,
+                    "TritonScaledMMLinearKernel only supports symmetric " +
+                    "quantization.")
+        return True, None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        super().process_weights_after_loading(layer)
+
+    def apply_weights(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        return super().apply_weights(layer, x, bias)
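
Because the new kernel inherits its weight processing and apply path from CutlassScaledMMLinearKernel and only overrides the eligibility checks, the practical effect on ROCm is that symmetric int8 configs are accepted while AZP (asymmetric-input) configs are rejected with an explicit reason. A minimal sketch of that behavior on a non-CPU platform, assuming ScaledMMLinearLayerConfig is a dataclass with is_channelwise, is_static_input_scheme, and input_symmetric fields (field names other than input_symmetric are assumptions):

# Sketch only (not part of the diff): what can_implement() reports.
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (
    ScaledMMLinearLayerConfig)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
    TritonScaledMMLinearKernel)

symmetric_cfg = ScaledMMLinearLayerConfig(is_channelwise=False,
                                          is_static_input_scheme=True,
                                          input_symmetric=True)
azp_cfg = ScaledMMLinearLayerConfig(is_channelwise=False,
                                    is_static_input_scheme=True,
                                    input_symmetric=False)

print(TritonScaledMMLinearKernel.can_implement(symmetric_cfg))  # (True, None)
print(TritonScaledMMLinearKernel.can_implement(azp_cfg))        # (False, "...symmetric quantization.")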
