Commit 1d997bb

Added asymmetric activation quantization support to linear layers
1 parent d02c568 commit 1d997bb

3 files changed: 47 additions & 8 deletions

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 6 additions & 4 deletions
@@ -129,7 +129,7 @@ def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
             or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
         is_tensor = (weight_strategy and input_quant.strategy
                      == QuantizationStrategy.TENSOR.value)
-        is_symmetric = weight_quant.symmetric and input_quant.symmetric
+        is_symmetric = weight_quant.symmetric
         is_static = not weight_quant.dynamic and not input_quant.dynamic
 
         return is_8_bits and is_tensor and is_symmetric and is_static
@@ -142,7 +142,7 @@ def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
             or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
         is_token = (weight_strategy and input_quant.strategy
                     == QuantizationStrategy.TOKEN.value)
-        is_symmetric = weight_quant.symmetric and input_quant.symmetric
+        is_symmetric = weight_quant.symmetric
         is_dynamic = not weight_quant.dynamic and input_quant.dynamic
 
         return is_8_bits and is_token and is_symmetric and is_dynamic
@@ -255,12 +255,14 @@ def _get_scheme_from_parts(
         if self._is_static_tensor_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Int8(
                 strategy=weight_quant.strategy,
-                is_static_input_scheme=True)
+                is_static_input_scheme=True,
+                input_symmetric=input_quant.symmetric)
 
         if self._is_dynamic_token_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Int8(
                 strategy=weight_quant.strategy,
-                is_static_input_scheme=False)
+                is_static_input_scheme=False,
+                input_symmetric=input_quant.symmetric)
 
         raise NotImplementedError(
             "No compressed-tensors compatible scheme was found.")

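In the two checks above, only the weight quantization now has to be symmetric; whether the input activations are symmetric is carried into the scheme through the new input_symmetric argument. A minimal sketch of the resulting routing, using stand-in config objects (the attribute names mirror what the checks read, not the exact compressed-tensors schema):

from types import SimpleNamespace

from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import (
    CompressedTensorsW8A8Int8)

# Stand-ins for weight_quant / input_quant: symmetric int8 weights,
# dynamic asymmetric int8 activations.
weight_quant = SimpleNamespace(num_bits=8, strategy="channel",
                               symmetric=True, dynamic=False)
input_quant = SimpleNamespace(num_bits=8, strategy="token",
                              symmetric=False, dynamic=True)

# With the relaxed check this pair is no longer rejected, and the scheme is
# built exactly as in _get_scheme_from_parts above.
scheme = CompressedTensorsW8A8Int8(strategy=weight_quant.strategy,
                                   is_static_input_scheme=False,
                                   input_symmetric=input_quant.symmetric)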
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py

Lines changed: 24 additions & 2 deletions
@@ -17,9 +17,11 @@
 
 class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
 
-    def __init__(self, strategy: str, is_static_input_scheme: bool):
+    def __init__(self, strategy: str, is_static_input_scheme: bool,
+                 input_symmetric: bool):
         self.strategy = strategy
         self.is_static_input_scheme = is_static_input_scheme
+        self.input_symmetric = input_symmetric
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -48,8 +50,21 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if self.is_static_input_scheme:
             layer.input_scale = Parameter(layer.input_scale.max(),
                                           requires_grad=False)
+            if not self.input_symmetric:
+                layer.input_zero_point = Parameter(layer.input_zero_point,
+                                                   requires_grad=False)
+            else:
+                layer.input_zero_point = None
         else:
             layer.input_scale = None
+            layer.input_zero_point = None
+
+        if not self.input_symmetric:
+            layer.azp_adj = layer.weight.sum(dim=0,
+                                             keepdim=True,
+                                             dtype=torch.int32)
+        else:
+            layer.azp_adj = None
 
     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
@@ -90,11 +105,18 @@ def create_weights(self, layer: torch.nn.Module,
                                    weight_loader=weight_loader)
             layer.register_parameter("input_scale", input_scale)
 
+            if not self.input_symmetric:
+                raise NotImplementedError(
+                    "static input asymmetric quantization not supported yet")
+                input_zero_point = Parameter(torch.zeros(1, dtype=torch.int8))
+                layer.register_parameter("input_zero_point", input_zero_point)
+
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
-
         return apply_int8_linear(input=x,
                                  weight=layer.weight,
                                  weight_scale=layer.weight_scale,
                                  input_scale=layer.input_scale,
+                                 input_zero_point=layer.input_zero_point,
+                                 azp_adj=layer.azp_adj,
                                  bias=bias)
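The key addition above is layer.azp_adj. With an activation zero point zp, the integer product the kernel needs is (x_q - zp) @ W_q, which expands to x_q @ W_q - zp * W_q.sum(dim=0): the column sums of the int8 weight are the only weight-dependent part of the zero-point correction, so they are computed once after loading rather than on every forward pass. A minimal numeric check of that identity (illustrative only: a single scalar zero point, weight laid out [in_features, out_features] to match the dim=0 reduction, and no scales or fused CUTLASS epilogue):

import torch

torch.manual_seed(0)
w_q = torch.randint(-128, 128, (16, 8), dtype=torch.int32)  # int8 weight codes, [in_features, out_features]
x_q = torch.randint(0, 255, (4, 16), dtype=torch.int32)     # asymmetric activation codes, [tokens, in_features]
zp = torch.tensor(87, dtype=torch.int32)                    # activation zero point

azp_adj = w_q.sum(dim=0, keepdim=True)                      # same reduction as layer.azp_adj above

direct = (x_q - zp) @ w_q          # GEMM on zero-point-shifted activations
fused = x_q @ w_q - zp * azp_adj   # GEMM on raw codes plus the precomputed correction
assert torch.equal(direct, fused)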

vllm/model_executor/layers/quantization/utils/w8a8_utils.py

Lines changed: 17 additions & 2 deletions
@@ -210,13 +210,28 @@ def apply_int8_linear(
     weight: torch.Tensor,
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
+    input_zero_point: Optional[torch.Tensor] = None,
+    azp_adj: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
 ):
     # ops.scaled_int8_quant supports both dynamic and static quant.
     # * dynamic, layer.input_scale is None and x_scale computed from x.
     # * static, layer.input_scale is scalar and x_scale is input_scale.
-    x_q, x_scale, _ = ops.scaled_int8_quant(input, input_scale)
-
+    symmetric = azp_adj is None
+    x_q, x_scale, x_zp = ops.scaled_int8_quant(input,
+                                               input_scale,
+                                               input_zero_point,
+                                               symmetric=symmetric)
+
+    if x_zp is not None:
+        return ops.cutlass_scaled_mm_azp(x_q,
+                                         weight,
+                                         scale_a=x_scale,
+                                         scale_b=weight_scale,
+                                         out_dtype=input.dtype,
+                                         azp_adj=azp_adj,
+                                         azp=x_zp,
+                                         bias=bias)
     return ops.cutlass_scaled_mm(x_q,
                                  weight,
                                  scale_a=x_scale,
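apply_int8_linear now keys the dispatch off azp_adj: when it is present, scaled_int8_quant is asked for asymmetric quantization and the resulting zero point is handed, together with the precomputed weight column sums, to cutlass_scaled_mm_azp; when it is None, the original symmetric cutlass_scaled_mm path is untouched. A hedged usage sketch for the dynamic asymmetric case (tensor names and shapes are illustrative, and it assumes a CUDA build where the CUTLASS int8 kernels are available):

# x, w_q, w_scale and azp_adj would normally be prepared by the scheme above.
out = apply_int8_linear(
    input=x,                # fp16/bf16 activations, [num_tokens, in_features]
    weight=w_q,             # int8 weight as laid out for the CUTLASS kernel
    weight_scale=w_scale,   # per-channel (or per-tensor) weight scales
    input_scale=None,       # dynamic quant: scale computed from x inside the call
    input_zero_point=None,  # dynamic quant: zero point computed from x as well
    azp_adj=azp_adj,        # weight column sums; being non-None selects the azp path
    bias=None,
)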
