
Commit: d02c568
Message: Merge fixes
Parent: 5d249fe
File tree: 3 files changed (+14, -9 lines)


csrc/quantization/compressed_tensors/int8_quant_kernels.cu
Lines changed: 4 additions & 2 deletions

@@ -150,9 +150,11 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
   }

   // Reduce the max and min values across the block
-  max_val = blockReduceMax(max_val);
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStorage;
+  max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x);
   __syncthreads();  // Make sure min doesn't mess with max shared memory
-  min_val = blockReduceMin(min_val);
+  min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x);

   __shared__ scale_type scale_sh;
   __shared__ azp_type azp_sh;

tests/kernels/test_int8_quant.py
Lines changed: 7 additions & 6 deletions

@@ -39,7 +39,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
     # reference
     ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.int8)
     # kernel
-    ops_out, ops_scales = scaled_int8_quant(x)
+    ops_out, ops_scales, _ = scaled_int8_quant(x)

     torch.testing.assert_close(ops_scales, ref_scales)
     torch.testing.assert_close(
@@ -103,11 +103,10 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,

     out1 = (x / scale).round().clamp(int8_traits.min,
                                      int8_traits.max).to(torch.int8)
-    out2, _ = scaled_int8_quant(x, scale)
+    out2, _, _ = scaled_int8_quant(x, scale)

-    torch.testing.assert_close(
-        out1, out2, atol=1,
-        rtol=0.0)  # big atol to account for rounding errors
+    # big atol to account for rounding errors
+    torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)


 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -135,4 +134,6 @@ def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,

     torch.ops._C.static_scaled_int8_quant(out2, x, scale_argument,
                                           azp_argument)
-    torch.testing.assert_close(out1, out2, atol=1)  # atol for rounding
+
+    # big atol to account for rounding errors
+    torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
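For context on the test updates above, here is a minimal sketch of the call-site change, inferred from this diff alone: scaled_int8_quant now returns a 3-tuple (output, scales, azp) instead of a 2-tuple, with azp being None on the symmetric paths. The tensor shapes, dtypes, and scale value below are illustrative assumptions, not taken from the test file.

    # Minimal sketch of the new call sites, assuming a CUDA build of vLLM.
    # Shapes/dtypes here are illustrative; the 3-tuple unpacking is the only
    # thing this diff actually changes.
    import torch

    from vllm._custom_ops import scaled_int8_quant

    x = torch.randn(16, 64, dtype=torch.float16, device="cuda")

    # Dynamic quantization: the kernel computes the scales itself.
    ops_out, ops_scales, _ = scaled_int8_quant(x)

    # Static quantization: the caller supplies the scale, which is returned
    # unchanged; the third element is None for symmetric quantization.
    scale = torch.tensor([0.02], dtype=torch.float32, device="cuda")
    out2, _, azp = scaled_int8_quant(x, scale)
    assert azp is None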

vllm/_custom_ops.py
Lines changed: 3 additions & 1 deletion

@@ -438,7 +438,9 @@ def scaled_int8_quant(
     output = torch.empty_like(input, dtype=torch.int8)
     if scale is not None:
         # static-per-tensor quantization.
-        assert symmetric == azp is None, "azp must be only be provided for asymmetric quantization."
+        assert symmetric == (
+            azp is
+            None), "azp must only be provided for asymmetric quantization."
         torch.ops._C.static_scaled_int8_quant(output, input, scale, azp)
         return output, scale, None
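The assertion rewrite above is a real bug fix, not just a reflow: symmetric == azp is None is a Python chained comparison and parses as (symmetric == azp) and (azp is None), not as symmetric == (azp is None). A standalone illustration, with variable values chosen only for the example:

    # Chained-comparison pitfall fixed in the diff above.
    symmetric, azp = True, None

    # Parses as (symmetric == azp) and (azp is None): True == None is False,
    # so the whole expression is False even though azp really is None.
    print(symmetric == azp is None)    # False

    # Parenthesized form from the fix: compares symmetric against the
    # boolean "azp is None", which is the intended check.
    print(symmetric == (azp is None))  # True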
