Commit b5b57e3

[AMD][FP8] Using MI300 FP8 format on ROCm for block_quant (#12134)
Signed-off-by: Gregory Shtrasberg <[email protected]>
1 parent 54cacf0 commit b5b57e3

2 files changed: +44 -3 lines changed

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 33 additions & 0 deletions
@@ -247,6 +247,15 @@ def create_weights(
     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            if current_platform.is_rocm():
+                weight, weight_scale, _ = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        weight=layer.weight,
+                        weight_scale=layer.weight_scale_inv,
+                        input_scale=layer.input_scale)
+                layer.weight = Parameter(weight, requires_grad=False)
+                layer.weight_scale_inv = Parameter(weight_scale,
+                                                   requires_grad=False)
             return
         layer.weight = torch.nn.Parameter(layer.weight.data,
                                           requires_grad=False)
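
The `normalize_e4m3fn_to_e4m3fnuz` helper is only called here, not defined in this diff; it lives elsewhere in vLLM's quantization utilities. As a minimal sketch of what such a conversion involves (assuming the usual relation between the two encodings; the real helper may differ in details): OCP `float8_e4m3fn` and the MI300-native `float8_e4m3fnuz` are both 8-bit formats, but fnuz uses an exponent bias of 8 instead of 7 and reserves the 0x80 pattern for NaN rather than -0.0, so the bytes can be reinterpreted in place as long as that pattern is remapped and the scales are doubled.

# Illustrative sketch only, not the vLLM implementation.
from typing import Optional, Tuple

import torch


def normalize_e4m3fn_to_e4m3fnuz_sketch(
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        input_scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    assert weight.dtype == torch.float8_e4m3fn
    # 0x80 is -0.0 in e4m3fn but NaN in e4m3fnuz; remap it to 0 before
    # reinterpreting the bytes.
    weight_as_int8 = weight.view(torch.int8)
    weight_as_int8[weight_as_int8 == -128] = 0
    weight = weight_as_int8.view(torch.float8_e4m3fnuz)
    # The same bits decode to half the value under e4m3fnuz, so double the
    # scales to keep dequantized values unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale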
@@ -495,6 +504,30 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            if current_platform.is_rocm():
+                w13_weight, w13_weight_scale_inv, w13_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w13_weight, layer.w13_weight_scale_inv,
+                        layer.w13_input_scale)
+                w2_weight, w2_weight_scale_inv, w2_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w2_weight, layer.w2_weight_scale_inv,
+                        layer.w2_input_scale)
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(w13_weight,
+                                                      requires_grad=False)
+                layer.w13_weight_scale_inv = torch.nn.Parameter(
+                    w13_weight_scale_inv, requires_grad=False)
+                if w13_input_scale is not None:
+                    layer.w13_input_scale = torch.nn.Parameter(
+                        w13_input_scale, requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(w2_weight,
+                                                     requires_grad=False)
+                layer.w2_weight_scale_inv = torch.nn.Parameter(
+                    w2_weight_scale_inv, requires_grad=False)
+                if w2_input_scale is not None:
+                    layer.w2_input_scale = torch.nn.Parameter(
+                        w2_input_scale, requires_grad=False)
             return
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
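
The scale doubling in the sketch above is a direct consequence of the two encodings: reinterpreting e4m3fn bytes as e4m3fnuz halves every normal value because the exponent bias grows by one. A quick self-contained check, assuming a PyTorch build that exposes both FP8 dtypes:

import torch

x = torch.tensor([1.0, 2.0, 4.0]).to(torch.float8_e4m3fn)
y = x.view(torch.int8).view(torch.float8_e4m3fnuz)  # same bits, new meaning
print(x.float())  # tensor([1., 2., 4.])
print(y.float())  # tensor([0.5000, 1.0000, 2.0000])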

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 11 additions & 3 deletions
@@ -5,6 +5,8 @@
 import triton
 import triton.language as tl
 
+from vllm.platforms import current_platform
+
 
 def apply_w8a8_block_fp8_linear(
     input: torch.Tensor,
@@ -33,11 +35,14 @@ def apply_w8a8_block_fp8_linear(
 
 
 def input_to_float8(
-    x: torch.Tensor,
-    dtype: torch.dtype = torch.float8_e4m3fn
+    x: torch.Tensor,
+    dtype: Optional[torch.dtype] = None
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """This function quantizes input values to float8 values "
     "with tensor-wise quantization."""
+    if dtype is None:
+        dtype = (torch.float8_e4m3fnuz
+                 if current_platform.is_rocm() else torch.float8_e4m3fn)
     finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
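
The hunk above only shows the start of `input_to_float8`; the rest of the function performs standard tensor-wise FP8 quantization against the chosen dtype's dynamic range. A minimal, self-contained sketch of that pattern (the dtype default is hard-coded here so the snippet runs off ROCm; it is an illustration, not a verbatim copy of the remaining function body):

from typing import Optional, Tuple

import torch


def input_to_float8_sketch(
        x: torch.Tensor,
        dtype: Optional[torch.dtype] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    # The diff selects float8_e4m3fnuz on ROCm; use the OCP format here so
    # the sketch stays portable.
    if dtype is None:
        dtype = torch.float8_e4m3fn
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    # Scale so the largest magnitude maps to the format's max value, saturate,
    # cast, and return the dequantization scale (the reciprocal).
    scale = finfo.max / amax
    x_q = (x * scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
    return x_q, scale.float().reciprocal()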
@@ -125,7 +130,7 @@ def per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype = torch.float8_e4m3fn,
+    dtype: Optional[torch.dtype] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the
@@ -140,6 +145,9 @@ def per_token_group_quant_fp8(
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
         scaling factor for quantization.
     """
+    if dtype is None:
+        dtype = (torch.float8_e4m3fnuz
+                 if current_platform.is_rocm() else torch.float8_e4m3fn)
     assert (x.shape[-1] % group_size == 0), (
         f"the last dimension of `x` {x.shape[-1]} must be divisible "
         f"by `group_size` {group_size}")
