Skip to content

Commit 22ad629

Browse files
authored
[bug fix] dequantize 4bit (#19793)
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
1 parent 860eb76 commit 22ad629

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc

+2-1
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,9 @@ void Dequantize4BitsKernelReOrder(
4141
T* output_i = output + out_y * out_cols + out_x;
4242
uint32_t quant_value = *(reinterpret_cast<const uint32_t*>(quant_data + element_offset / 2));
4343
const int remain_x = std::min(8, out_cols - out_x);
44+
const int32_t* reorder_idx_with_off = reorder_idx + kb_idx * block_size + ((threadIdx_x * 8) & (block_size - 1));
4445
for (int i = 0; i < remain_x; i++) {
45-
int32_t rid = reorder_idx ? reorder_idx[kb_idx * block_size + i] : kb_idx;
46+
int32_t rid = reorder_idx ? reorder_idx_with_off[i] : kb_idx;
4647
T scale = *(scale_data + n_idx * scales_shape_x + rid);
4748
float zp_f = 8;
4849
if (zero_points) {

onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu

+4-3
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ namespace cuda {
2323

2424
__device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, half scale, half zp, half* output) {
2525
half2 scale_half2 = {scale, scale};
26-
half zp_adjust = -scale * __short2half_rn(zp);
26+
half zp_adjust = -scale * zp;
2727
half2 zp_adjust2 = {zp_adjust, zp_adjust};
2828

2929
alignas(16) half2 results[4];
@@ -83,8 +83,9 @@ __global__ void Dequantize4BitsKernelReOrder(
8383
int element_offset = group_id * block_size + ((threadIdx.x * 8) & (block_size - 1));
8484
T* output_i = output + element_offset;
8585
uint32_t quant_value = *(reinterpret_cast<const uint32_t*>(quant_data + element_offset / 2));
86+
const int32_t* reorder_idx_with_off = reorder_idx + kb_idx * block_size + ((threadIdx.x * 8) & (block_size - 1));
8687
for (int i = 0; i < 8; i++) {
87-
int32_t rid = reorder_idx[kb_idx * block_size + i];
88+
int32_t rid = reorder_idx_with_off[i];
8889
T scale = *(scale_data + n_idx * scales_shape_x + rid);
8990
uint8_t zp = 8;
9091
if (zero_points) {
@@ -157,7 +158,7 @@ Status Dequantize4Bits(
157158
int groups_per_K = k / block_size;
158159
int total_groups = n * groups_per_K; // total elemenets in quant_data
159160
int groups_per_grid = static_cast<int>(CeilDiv(total_groups, groups_per_threadblock));
160-
if (!reorder_idx) {
161+
if (!reorder_idx || std::is_same_v<ZeroT, T>) {
161162
Dequantize4BitsKernel<T, ZeroT><<<groups_per_grid, GridDim::maxThreadsPerBlock, 0, stream>>>(
162163
output,
163164
quant_data,

0 commit comments

Comments
 (0)