
fp8 compressed_tensors w8a8 support #3242


Merged: 4 commits, May 28, 2025
Dockerfile_gaudi: 1 addition, 0 deletions

@@ -99,6 +99,7 @@ RUN cd server && \
BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
pip install . --no-cache-dir
RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git@bmax_fix
RUN pip install compressed-tensors==0.9.1

# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
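A quick way to verify the pinned wheel inside the built image is to query the installed distribution metadata; a minimal sketch in Python (the only assumption is the PyPI distribution name `compressed-tensors`):

from importlib.metadata import version

# Should print "0.9.1", matching the pin added to Dockerfile_gaudi above.
print(version("compressed-tensors"))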
backends/gaudi/server/text_generation_server/cli.py: 2 additions, 0 deletions

@@ -19,6 +19,7 @@ class Quantization(str, Enum):
gptq = "gptq"
awq = "awq"
fp8 = "fp8"
compressed_tensors = "compressed-tensors"


class Dtype(str, Enum):
@@ -109,6 +110,7 @@ def serve(
"gptq",
"awq",
"fp8",
"compressed-tensors",
}:
raise RuntimeError(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
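Because `Quantization` subclasses `str`, the CLI's `--quantize` string resolves to a member by value; a minimal standalone sketch (illustrative only, re-declaring the enum from the diff above):

from enum import Enum

class Quantization(str, Enum):
    gptq = "gptq"
    awq = "awq"
    fp8 = "fp8"
    compressed_tensors = "compressed-tensors"

# The CLI string maps back to the enum member by value.
assert Quantization("compressed-tensors") is Quantization.compressed_tensors
assert Quantization.compressed_tensors.value == "compressed-tensors"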
backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py: 3 additions, 0 deletions

@@ -0,0 +1,3 @@
from .loader import CompressedTensorsLoader

__all__ = ["CompressedTensorsLoader"]
backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py: 169 additions, 0 deletions

@@ -0,0 +1,169 @@
from typing import Any, Dict, List, Union

from compressed_tensors import QuantizationConfig, QuantizationStatus
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import (
QuantizationScheme,
QuantizationType,
find_name_or_class_matches,
)
from loguru import logger
from pydantic import ValidationError
from torch import nn

from text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
UnquantizedWeight,
Weights,
WeightsLoader,
)

# compressed-tensors can match modules as quantization targets. However,
# they need to be objects rather than classes or class names. Since we
# need to match `Linear` targets, make an instance that can be re-used.
_EMPTY_LINEAR: nn.Module = nn.Linear(0, 0)
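# For example, a prefix such as "model.layers.0.self_attn.q_proj" (an
# illustrative name, not one from this PR) matches a "Linear" target via
# this instance's class name when passed to `find_name_or_class_matches`.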


class CompressedTensorsLoader(WeightsLoader):
"""Loader for checkpoints stored in the compressed-tensors format."""

def __init__(self, config: Dict[str, Any]):
quantization_config_raw = config.get("quantization_config")
if quantization_config_raw is None:
# `compression_config` was renamed to `quantization_config`; support
# retained for backward compatibility.
quantization_config_raw = config.get("compression_config")
if quantization_config_raw is None:
raise ValueError(
"Checkpoint does not have compressed-tensors configuration"
)

try:
quantization_config = QuantizationConfig.model_validate(
quantization_config_raw
)
except ValidationError as e:
raise ValueError("Cannot parse compressed-tensors configuration") from e

if quantization_config.quantization_status not in (
QuantizationStatus.COMPRESSED,
QuantizationStatus.FROZEN,
):
raise ValueError(
f"Model quantization was not finished, status was: {quantization_config.quantization_status}"
)

self.ignore = (
quantization_config.ignore if quantization_config.ignore is not None else []
)
self.loaders = self._get_target_loaders(quantization_config)

for target, loader in self.loaders.items():
log_once(
logger.info,
f"Using {loader} for compressed-tensors target '{target}'",
)

def get_weights(self, weights: Weights, prefix: str):
loader = self._lookup_loader(prefix)
return loader.get_weights(weights, prefix)

def get_weights_col_packed(
self,
weights: "Weights",
prefix: str,
block_sizes: Union[int, List[int]],
):
loader = self._lookup_loader(prefix)
return loader.get_weights_col_packed(weights, prefix, block_sizes)

def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
loader = self._lookup_loader(prefixes[0])
return loader.get_multi_weights_col(weights, prefixes, dim)

def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int):
loader = self._lookup_loader(prefixes[0])
return loader.get_multi_weights(weights, prefixes, dim)

def get_weights_row(self, weights: Weights, prefix: str):
loader = self._lookup_loader(prefix)
return loader.get_weights_row(weights, prefix)

def _get_target_loaders(
self, quantization_config: QuantizationConfig
) -> Dict[str, WeightsLoader]:
"""
A compressed-tensors checkpoint can use different quantizations
for different targets. This method returns a dictionary with a
loader per target.
"""

loaders: Dict[str, WeightsLoader] = {}

format = quantization_config.format

for group_name, group in quantization_config.config_groups.items():
# The group configuration can be a string, but does that ever
# happen in a serialized quantization config?
assert isinstance(group, QuantizationScheme)

loader = self._create_loader_for_group(format, group_name, group)

            # A quantized parameter group can have multiple targets; add the
            # loader for all of them.
            for target in group.targets:
                if target in loaders:
                    raise ValueError(
                        f"Target '{target}' has multiple configured loaders"
                    )
loaders[target] = loader

return loaders

def _create_loader_for_group(
self, format: str, group_name: str, group: QuantizationScheme
) -> WeightsLoader:
"""
Find and create a loader for the group with the given quantization
scheme.
"""
# NOTE: we ignore group.output_activations because we don't support
# output quantization yet.

input_activations = group.input_activations
weights = group.weights
if (
format
in {
CompressionFormat.float_quantized.value,
CompressionFormat.naive_quantized.value,
}
and weights is not None
and weights.type == QuantizationType.FLOAT
and weights.num_bits == 8
):
# FP W8A8 or W8A16.
return W8ANFpLoader(input_activations=input_activations, weights=weights)
else:
raise ValueError(
f"Group '{group_name}' has unsupported compressed-tensors configurtion"
)

def _lookup_loader(self, prefix: str) -> WeightsLoader:
"""
Look up the loader to use for a given parameter name (prefix).
"""

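        # Prefixes matching the config's `ignore` list (for example
        # `lm_head` in many FP8 checkpoints) are loaded unquantized.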
if len(find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.ignore)) > 0:
return DefaultWeightsLoader(UnquantizedWeight)

# We currently only handle linear layers, so unconditionally pass
# a `Linear` instance.
targets = find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.loaders.keys())
if len(targets) == 0:
raise ValueError(
f"Cannot find compressed-tensors target for prefix: {prefix}"
)
return self.loaders[targets[0]]
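For reference, the loader consumes the model's `config.json` dictionary. A hedged sketch of instantiating `CompressedTensorsLoader` with the kind of FP8 W8A8 quantization config such checkpoints typically carry (the field values below are illustrative assumptions, not taken from this PR):

from text_generation_server.layers.compressed_tensors import CompressedTensorsLoader

# Shape of a typical FP8 W8A8 `quantization_config` block in config.json
# (values are illustrative assumptions).
config = {
    "quantization_config": {
        "quant_method": "compressed-tensors",
        "format": "float-quantized",
        "quantization_status": "compressed",
        "ignore": ["lm_head"],
        "config_groups": {
            "group_0": {
                "targets": ["Linear"],
                "weights": {
                    "num_bits": 8,
                    "type": "float",
                    "strategy": "tensor",
                    "symmetric": True,
                    "dynamic": False,
                },
                "input_activations": {
                    "num_bits": 8,
                    "type": "float",
                    "strategy": "tensor",
                    "symmetric": True,
                    "dynamic": False,
                },
            }
        },
    }
}

loader = CompressedTensorsLoader(config)
# `lm_head` is in `ignore`, so it resolves to the unquantized default loader;
# other Linear modules resolve to W8ANFpLoader (FP8 weights and activations).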