
Commit b402c9d

Added asymmetric quantization integration to linear layers
1 parent e05068c commit b402c9d

3 files changed, +47 -8 lines
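
This commit wires asymmetric (zero-point) activation quantization into the compressed-tensors INT8 linear path. The key observation is that the zero-point contribution to an integer matmul depends only on the weights, so it can be precomputed once as a per-output-column correction term (azp_adj) instead of being applied to the activations. A minimal sketch in plain PyTorch (not vLLM code) verifying that identity:

import torch

torch.manual_seed(0)
x_q = torch.randint(0, 256, (4, 16))      # quantized activations (default int64 for a simple CPU check)
w_q = torch.randint(-128, 128, (16, 8))   # quantized weights
z = 7                                     # activation zero point

lhs = (x_q - z) @ w_q                     # exact asymmetric product
azp_adj = w_q.sum(dim=0, keepdim=True)    # per-output-column weight sums, precomputable
rhs = x_q @ w_q - z * azp_adj             # symmetric product plus a cheap correction

assert torch.equal(lhs, rhs)

The scheme below precomputes exactly this column-sum term as azp_adj and hands it to the GEMM epilogue.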

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 6 additions & 4 deletions

@@ -122,7 +122,7 @@ def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
             or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
         is_tensor = (weight_strategy and input_quant.strategy
                      == QuantizationStrategy.TENSOR.value)
-        is_symmetric = weight_quant.symmetric and input_quant.symmetric
+        is_symmetric = weight_quant.symmetric
         is_static = not weight_quant.dynamic and not input_quant.dynamic
 
         return is_8_bits and is_tensor and is_symmetric and is_static
@@ -135,7 +135,7 @@ def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
             or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
         is_token = (weight_strategy and input_quant.strategy
                     == QuantizationStrategy.TOKEN.value)
-        is_symmetric = weight_quant.symmetric and input_quant.symmetric
+        is_symmetric = weight_quant.symmetric
         is_dynamic = not weight_quant.dynamic and input_quant.dynamic
 
         return is_8_bits and is_token and is_symmetric and is_dynamic
@@ -253,12 +253,14 @@ def _get_scheme_from_parts(
         if self._is_static_tensor_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Int8(
                 strategy=weight_quant.strategy,
-                is_static_input_scheme=True)
+                is_static_input_scheme=True,
+                input_symmetric=input_quant.symmetric)
 
         if self._is_dynamic_token_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Int8(
                 strategy=weight_quant.strategy,
-                is_static_input_scheme=False)
+                is_static_input_scheme=False,
+                input_symmetric=input_quant.symmetric)
 
         raise NotImplementedError(
             "No compressed-tensors compatible scheme was found.")

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py

Lines changed: 24 additions & 2 deletions

@@ -15,9 +15,11 @@
 
 class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
 
-    def __init__(self, strategy: str, is_static_input_scheme: bool):
+    def __init__(self, strategy: str, is_static_input_scheme: bool,
+                 input_symmetric: bool):
         self.strategy = strategy
         self.is_static_input_scheme = is_static_input_scheme
+        self.input_symmetric = input_symmetric
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -44,8 +46,21 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if self.is_static_input_scheme:
             layer.input_scale = Parameter(layer.input_scale.max(),
                                           requires_grad=False)
+            if not self.input_symmetric:
+                layer.input_zero_point = Parameter(layer.input_zero_point,
+                                                   requires_grad=False)
+            else:
+                layer.input_zero_point = None
         else:
             layer.input_scale = None
+            layer.input_zero_point = None
+
+        if not self.input_symmetric:
+            layer.azp_adj = layer.weight.sum(dim=0,
+                                             keepdim=True,
+                                             dtype=torch.int32)
+        else:
+            layer.azp_adj = None
 
     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
@@ -83,11 +98,18 @@ def create_weights(self, layer: torch.nn.Module,
                 output_partition_sizes, **layer_kwargs)
             layer.register_parameter("input_scale", input_scale)
 
+            if not self.input_symmetric:
+                raise NotImplementedError(
+                    "static input asymmetric quantization not supported yet")
+                input_zero_point = Parameter(torch.zeros(1, dtype=torch.int8))
+                layer.register_parameter("input_zero_point", input_zero_point)
+
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
-
         return apply_int8_linear(input=x,
                                  weight=layer.weight,
                                  weight_scale=layer.weight_scale,
                                  input_scale=layer.input_scale,
+                                 input_zero_point=layer.input_zero_point,
+                                 azp_adj=layer.azp_adj,
                                  bias=bias)
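For the dynamic asymmetric case (the only asymmetric path enabled so far, since create_weights still raises NotImplementedError for static asymmetric input), process_weights_after_loading leaves input_scale and input_zero_point as None and precomputes azp_adj from the loaded weights. A simplified sketch of that wiring, assuming the quantized weight tensor at this point is laid out as [in_features, out_features]:

import torch
from torch.nn import Parameter

in_features, out_features = 16, 8
layer = torch.nn.Module()
layer.weight = Parameter(torch.randint(-128, 128, (in_features, out_features),
                                       dtype=torch.int8),
                         requires_grad=False)

# Dynamic activations: scale and zero point are computed per batch at runtime.
layer.input_scale = None
layer.input_zero_point = None

# Asymmetric activations: precompute the per-output-column weight sums that the
# GEMM epilogue will scale by the activation zero point.
layer.azp_adj = layer.weight.sum(dim=0, keepdim=True, dtype=torch.int32)
assert layer.azp_adj.shape == (1, out_features)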

vllm/model_executor/layers/quantization/utils/w8a8_utils.py

Lines changed: 17 additions & 2 deletions

@@ -194,13 +194,28 @@ def apply_int8_linear(
     weight: torch.Tensor,
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
+    input_zero_point: Optional[torch.Tensor] = None,
+    azp_adj: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
 ):
     # ops.scaled_int8_quant supports both dynamic and static quant.
     # * dynamic, layer.input_scale is None and x_scale computed from x.
     # * static, layer.input_scale is scalar and x_scale is input_scale.
-    x_q, x_scale, _ = ops.scaled_int8_quant(input, input_scale)
-
+    symmetric = azp_adj is None
+    x_q, x_scale, x_zp = ops.scaled_int8_quant(input,
+                                               input_scale,
+                                               input_zero_point,
+                                               symmetric=symmetric)
+
+    if x_zp is not None:
+        return ops.cutlass_scaled_mm_azp(x_q,
+                                         weight,
+                                         scale_a=x_scale,
+                                         scale_b=weight_scale,
+                                         out_dtype=input.dtype,
+                                         azp_adj=azp_adj,
+                                         azp=x_zp,
+                                         bias=bias)
     return ops.cutlass_scaled_mm(x_q,
                                  weight,
                                  scale_a=x_scale,
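
The asymmetric branch hands the per-token zero points (x_zp) and the precomputed azp_adj to ops.cutlass_scaled_mm_azp. As a way to read that call, here is a plain-PyTorch reference with the semantics this sketch assumes the fused kernel implements (int64 accumulation stands in for the kernel's integer accumulator; this is not the real op):

from typing import Optional
import torch

def scaled_mm_azp_ref(x_q: torch.Tensor,      # quantized activations [M, K]
                      w_q: torch.Tensor,      # int8 weights [K, N]
                      scale_a: torch.Tensor,  # per-token activation scales [M, 1]
                      scale_b: torch.Tensor,  # per-channel weight scales [1, N] or scalar
                      out_dtype: torch.dtype,
                      azp_adj: torch.Tensor,  # w_q.sum(dim=0, keepdim=True) -> [1, N]
                      azp: torch.Tensor,      # per-token zero points [M, 1]
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    acc = x_q.to(torch.int64) @ w_q.to(torch.int64)            # integer accumulation
    acc = acc - azp.to(torch.int64) * azp_adj.to(torch.int64)  # zero-point correction
    out = scale_a * scale_b * acc.to(torch.float32)
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)

In the symmetric path azp_adj stays None, scaled_int8_quant returns x_zp as None, and the function falls through to the existing ops.cutlass_scaled_mm call.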
