Skip to content

[AMD][FP8] Using MI300 FP8 format on ROCm for block_quant #12134

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,15 @@ def create_weights(
def process_weights_after_loading(self, layer: Module) -> None:
# Block quant doesn't need to process weights after loading
if self.block_quant:
if current_platform.is_rocm():
weight, weight_scale, _ = \
normalize_e4m3fn_to_e4m3fnuz(
weight=layer.weight,
weight_scale=layer.weight_scale_inv,
input_scale=layer.input_scale)
layer.weight = Parameter(weight, requires_grad=False)
layer.weight_scale_inv = Parameter(weight_scale,
requires_grad=False)
return
layer.weight = torch.nn.Parameter(layer.weight.data,
requires_grad=False)
Expand Down Expand Up @@ -495,6 +504,30 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
def process_weights_after_loading(self, layer: Module) -> None:
# Block quant doesn't need to process weights after loading
if self.block_quant:
if current_platform.is_rocm():
w13_weight, w13_weight_scale_inv, w13_input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
layer.w13_weight, layer.w13_weight_scale_inv,
layer.w13_input_scale)
w2_weight, w2_weight_scale_inv, w2_input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
layer.w2_weight, layer.w2_weight_scale_inv,
layer.w2_input_scale)
# Reset the parameter
layer.w13_weight = torch.nn.Parameter(w13_weight,
requires_grad=False)
layer.w13_weight_scale_inv = torch.nn.Parameter(
w13_weight_scale_inv, requires_grad=False)
if w13_input_scale is not None:
layer.w13_input_scale = torch.nn.Parameter(
w13_input_scale, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight,
requires_grad=False)
layer.w2_weight_scale_inv = torch.nn.Parameter(
w2_weight_scale_inv, requires_grad=False)
if w2_input_scale is not None:
layer.w2_input_scale = torch.nn.Parameter(
w2_input_scale, requires_grad=False)
return
# If checkpoint is fp16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
Expand Down
14 changes: 11 additions & 3 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import triton
import triton.language as tl

from vllm.platforms import current_platform


def apply_w8a8_block_fp8_linear(
input: torch.Tensor,
Expand Down Expand Up @@ -33,11 +35,14 @@ def apply_w8a8_block_fp8_linear(


def input_to_float8(
x: torch.Tensor,
dtype: torch.dtype = torch.float8_e4m3fn
x: torch.Tensor,
dtype: Optional[torch.dtype] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""This function quantizes input values to float8 values "
"with tensor-wise quantization."""
if dtype is None:
dtype = (torch.float8_e4m3fnuz
if current_platform.is_rocm() else torch.float8_e4m3fn)
finfo = torch.finfo(dtype)
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
Expand Down Expand Up @@ -125,7 +130,7 @@ def per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: torch.dtype = torch.float8_e4m3fn,
dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
Expand All @@ -140,6 +145,9 @@ def per_token_group_quant_fp8(
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
if dtype is None:
dtype = (torch.float8_e4m3fnuz
if current_platform.is_rocm() else torch.float8_e4m3fn)
assert (x.shape[-1] % group_size == 0), (
f"the last dimension of `x` {x.shape[-1]} must be divisible "
f"by `group_size` {group_size}")
Expand Down
Loading