[Kernel] Add more dtype support for GGUF dequantization #15879

Merged: 2 commits, Apr 2, 2025
3 changes: 2 additions & 1 deletion csrc/ops.h
@@ -145,7 +145,8 @@ torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
#endif

torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
int64_t n);
int64_t n,
std::optional<at::ScalarType> const& dtype);

torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
int64_t type, int64_t row);
65 changes: 34 additions & 31 deletions csrc/quantization/gguf/dequantize.cuh
@@ -94,8 +94,8 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
dfloat2 v;
dequantize_kernel(vx, ib, iqs, v);

y[iybs + iqs + 0] = v.x;
y[iybs + iqs + y_offset] = v.y;
y[iybs + iqs + 0] = convert_from_half<dst_t>(v.x);
y[iybs + iqs + y_offset] = convert_from_half<dst_t>(v.y);
}

template<typename dst_t>
@@ -114,10 +114,10 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t

half dall = __low2half(x[i].dm);
half dmin = __high2half(x[i].dm);
y[l+ 0] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+0] & 0xF) * ((q >> 0) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+0] >> 4)));
y[l+32] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+2] & 0xF) * ((q >> 2) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+2] >> 4)));
y[l+64] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+4] & 0xF) * ((q >> 4) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+4] >> 4)));
y[l+96] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+6] & 0xF) * ((q >> 6) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+6] >> 4)));
y[l+ 0] = convert_from_half<dst_t>(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+0] & 0xF) * ((q >> 0) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+0] >> 4))));
y[l+32] = convert_from_half<dst_t>(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+2] & 0xF) * ((q >> 2) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+2] >> 4))));
y[l+64] = convert_from_half<dst_t>(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+4] & 0xF) * ((q >> 4) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+4] >> 4))));
y[l+96] = convert_from_half<dst_t>(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+6] & 0xF) * ((q >> 6) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+6] >> 4))));
}

template<typename dst_t>
@@ -148,7 +148,9 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t
const uint8_t * q = x[i].qs + 32*n;
const uint8_t * hm = x[i].hmask;

for (int l = l0; l < l0+4; ++l) y[l] = __hmul(dl, __int2half_rn((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)));
for (int l = l0; l < l0+4; ++l) {
y[l] = convert_from_half<dst_t>(__hmul(dl, __int2half_rn((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4))));
}
}

static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
@@ -188,8 +190,8 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t
const half d2 = __hmul(dall, __int2half_rn(sc));
const half m2 = __hmul(dmin, __int2half_rn(m));
for (int l = 0; l < n; ++l) {
y[l + 0] = __hsub(__hmul(d1, __int2half_rn(q[l] & 0xF)), m1);
y[l +32] = __hsub(__hmul(d2, __int2half_rn(q[l] >> 4)), m2);
y[l + 0] = convert_from_half<dst_t>(__hsub(__hmul(d1, __int2half_rn(q[l] & 0xF)), m1));
y[l +32] = convert_from_half<dst_t>(__hsub(__hmul(d2, __int2half_rn(q[l] >> 4)), m2));
}
}

@@ -220,11 +222,11 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t
const half d2 = __hmul(dall, __int2half_rn(sc)); const half m2 = __hmul(dmin, __int2half_rn(m));

uint8_t hm = 1 << (2*il);
y[ 0] = __hsub(__hmul(d1, __int2half_rn((ql[0] & 0xF) + (qh[0] & hm ? 16 : 0))), m1);
y[ 1] = __hsub(__hmul(d1, __int2half_rn((ql[1] & 0xF) + (qh[1] & hm ? 16 : 0))), m1);
y[ 0] = convert_from_half<dst_t>(__hsub(__hmul(d1, __int2half_rn((ql[0] & 0xF) + (qh[0] & hm ? 16 : 0))), m1));
y[ 1] = convert_from_half<dst_t>(__hsub(__hmul(d1, __int2half_rn((ql[1] & 0xF) + (qh[1] & hm ? 16 : 0))), m1));
hm <<= 1;
y[32] = __hsub(__hmul(d2, __int2half_rn((ql[0] >> 4) + (qh[0] & hm ? 16 : 0))), m2);
y[33] = __hsub(__hmul(d2, __int2half_rn((ql[1] >> 4) + (qh[1] & hm ? 16 : 0))), m2);
y[32] = convert_from_half<dst_t>(__hsub(__hmul(d2, __int2half_rn((ql[0] >> 4) + (qh[0] & hm ? 16 : 0))), m2));
y[33] = convert_from_half<dst_t>(__hsub(__hmul(d2, __int2half_rn((ql[1] >> 4) + (qh[1] & hm ? 16 : 0))), m2));
}

template<typename dst_t>
@@ -247,10 +249,10 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
const uint8_t qh = x[i].qh[32*ip + il];
const int8_t * sc = x[i].scales + is;

y[ 0] = __hmul(d, __int2half_rn(sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32)));
y[32] = __hmul(d, __int2half_rn(sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)));
y[64] = __hmul(d, __int2half_rn(sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32)));
y[96] = __hmul(d, __int2half_rn(sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32)));
y[ 0] = convert_from_half<dst_t>(__hmul(d, __int2half_rn(sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32))));
y[32] = convert_from_half<dst_t>(__hmul(d, __int2half_rn(sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32))));
y[64] = convert_from_half<dst_t>(__hmul(d, __int2half_rn(sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32))));
y[96] = convert_from_half<dst_t>(__hmul(d, __int2half_rn(sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32))));
}

template<typename dst_t>
@@ -269,7 +271,7 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
const uint32_t aux32 = q2[2] | (q2[3] << 16);
const float d = __half2float(x[i].d) * (0.5f + (aux32 >> 28)) * 0.25f;
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f));
for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}

template<typename dst_t>
@@ -286,7 +288,7 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f));
for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);

}

@@ -303,7 +305,7 @@ static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_
const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f));
for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}

template<typename dst_t>
@@ -324,8 +326,8 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
const float d = __half2float(x[i].d) * (0.5f + (aux32 >> 28)) * 0.5f;
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
for (int j = 0; j < 4; ++j) {
y[j+0] = __float2half(d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f));
y[j+4] = __float2half(d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f));
y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
}
}

@@ -345,8 +347,8 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)) * 0.5f;
const uint8_t signs = x[i].signs[4*ib + il];
for (int j = 0; j < 4; ++j) {
y[j+0] = __float2half(d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f));
y[j+4] = __float2half(d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f));
y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
}
}

@@ -367,7 +369,7 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
grid32[0] &= 0x0f0f0f0f;
for (int j = 0; j < 8; ++j) {
y[j] = __float2half(d * (q[j] + delta));
y[j] = d * (q[j] + delta);
}
}

@@ -392,7 +394,7 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_
grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
grid32[0] &= 0x0f0f0f0f;
for (int j = 0; j < 8; ++j) {
y[j] = __float2half(d * (q[j] + delta));
y[j] = d * (q[j] + delta);
}
}

@@ -409,8 +411,8 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst
const uint8_t * q4 = x[ib].qs + 4*il;
const float d = __half2float(x[ib].d);
for (int j = 0; j < 4; ++j) {
y[j+ 0] = __float2half(d * kvalues_iq4nl[q4[j] & 0xf]);
y[j+16] = __float2half(d * kvalues_iq4nl[q4[j] >> 4]);
y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
}

}
@@ -427,8 +429,8 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
const float d = __half2float(x[i].d) * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
for (int j = 0; j < 4; ++j) {
y[j+ 0] = __float2half(d * kvalues_iq4nl[q4[j] & 0xf]);
y[j+16] = __float2half(d * kvalues_iq4nl[q4[j] >> 4]);
y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
}
}

@@ -522,7 +524,8 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k,
dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
}

static to_fp16_cuda_t ggml_get_to_fp16_cuda(int64_t type) {
template<typename dst_t>
static to_cuda_ggml_t<dst_t> ggml_get_to_cuda(int64_t type) {
switch (type) {
case 2:
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
17 changes: 16 additions & 1 deletion csrc/quantization/gguf/ggml-common.h
@@ -1063,7 +1063,8 @@ static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -
typedef half dfloat; // dequantize float
typedef half2 dfloat2;
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
typedef void (*to_fp16_cuda_t)(const void * __restrict__ x, dfloat * __restrict__ y, int k, cudaStream_t stream);
template<typename dst_t>
using to_cuda_ggml_t = void (*)(const void * __restrict__ x, dst_t * __restrict__ y, int k, cudaStream_t stream);
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc);
typedef void (*load_tiles_cuda_t)(
@@ -1075,6 +1076,20 @@ typedef float (*vec_dot_q_mul_mat_cuda_t)(

// Utility function

template<typename dst_t>
static __device__ __forceinline__ dst_t convert_from_half(half val) {
return val;
}

template<>
__device__ __forceinline__ c10::BFloat16 convert_from_half<c10::BFloat16>(half val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __float2bfloat16(__half2float(val));
#else
return __half2float(val);
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
}

#if defined(USE_ROCM)

#ifndef __has_builtin
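The new convert_from_half helper is what lets the shared dequantize kernels write either half or bfloat16 outputs: the primary template is an identity for half, and the c10::BFloat16 specialization converts through float (__float2bfloat16(__half2float(val)) on SM80+, an implicit float-to-bfloat16 conversion otherwise). A rough host-side sketch of the same rounding behaviour, using PyTorch purely as an illustration (not part of the PR):

# Illustrative only: mimic convert_from_half<c10::BFloat16> on the host.
# half -> float is exact; float -> bfloat16 then rounds away mantissa bits,
# which is the precision trade-off accepted when requesting bfloat16 output.
import torch

h = torch.tensor([0.1, 1.5, -3.14159], dtype=torch.float16)
bf = h.float().to(torch.bfloat16)   # same two-step path as the CUDA specialization
print(h, bf)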
15 changes: 10 additions & 5 deletions csrc/quantization/gguf/gguf_kernel.cu
@@ -71,14 +71,19 @@ static void quantize_row_q8_1_cuda(const scalar_t* x, void* vy, const int kx,
}

torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight
int64_t type, int64_t m, int64_t n) {
int64_t type, int64_t m, int64_t n,
std::optional<at::ScalarType> const& dtype) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(W));
auto options =
torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
auto dtype_ = dtype.value_or(torch::kFloat16);
auto options = torch::TensorOptions().dtype(dtype_).device(W.device());
at::Tensor DW = torch::empty({m, n}, options);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(type);
to_fp16_cuda((void*)W.data_ptr(), (half*)DW.data_ptr(), m * n, stream);

VLLM_DISPATCH_FLOATING_TYPES(DW.scalar_type(), "ggml_dequantize", [&] {
auto to_cuda = ggml_get_to_cuda<scalar_t>(type);
to_cuda((void*)W.data_ptr(), (scalar_t*)DW.data_ptr(), m * n, stream);
});

return DW;
}

4 changes: 3 additions & 1 deletion csrc/torch_bindings.cpp
@@ -295,7 +295,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
#endif

// Dequantization for GGML.
ops.def("ggml_dequantize(Tensor W, int type, SymInt m, SymInt n) -> Tensor");
ops.def(
"ggml_dequantize(Tensor W, int type, SymInt m, SymInt n, ScalarType? "
"dtype) -> Tensor");
ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);

// mmvq kernel for GGML.
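With the schema extended to take an optional ScalarType, callers can request the output dtype directly; passing None keeps the previous float16 behaviour. A minimal sketch of invoking the op through the Python wrapper (the Q4_0 type id, shapes, and random data below are illustrative only):

import torch
from vllm import _custom_ops as ops

# Hypothetical Q4_0 weight blob: type id 2, block_size 32, type_size 18,
# so a (256, 32) logical weight is stored as (256, 18) uint8 bytes.
qweight = torch.randint(0, 256, (256, 18), dtype=torch.uint8, device="cuda")
quant_type, m, n = 2, 256, 32

w_fp16 = ops.ggml_dequantize(qweight, quant_type, m, n, None)            # default: float16
w_bf16 = ops.ggml_dequantize(qweight, quant_type, m, n, torch.bfloat16)  # new: bfloat16 output
assert w_fp16.dtype == torch.float16 and w_bf16.dtype == torch.bfloat16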
3 changes: 2 additions & 1 deletion tests/kernels/test_ggml.py
@@ -15,7 +15,8 @@ def test_ggml_opcheck(quant_type):
qweight = torch.randint(0, 100, shape, device='cuda', dtype=torch.uint8)
m = qweight.shape[0]
n = qweight.shape[1] // type_size * block_size
opcheck(torch.ops._C.ggml_dequantize, (qweight, quant_type, m, n))
opcheck(torch.ops._C.ggml_dequantize,
(qweight, quant_type, m, n, torch.float16))

x = torch.rand((m, 512), device='cuda', dtype=torch.float16)
opcheck(torch.ops._C.ggml_mul_mat_a8,
4 changes: 2 additions & 2 deletions tests/kernels/test_gguf.py
@@ -65,7 +65,7 @@ def get_gguf_MoE_tensors(


@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", [torch.half])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
@torch.inference_mode()
def test_dequantize(hidden_size: int, dtype: torch.dtype,
@@ -78,7 +78,7 @@ def test_dequantize(hidden_size: int, dtype: torch.dtype,
ref_output = torch.tensor(dequantize(tensor.data, quant_type),
device="cuda").to(dtype)
output = ops.ggml_dequantize(torch.tensor(tensor.data, device="cuda"),
quant_type, *list(shape)).to(dtype)
quant_type, *list(shape), dtype)

torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=4e-2)

15 changes: 9 additions & 6 deletions vllm/_custom_ops.py
@@ -436,9 +436,12 @@ def _allspark_w8a16_gemm_fake(a: torch.Tensor, b_qweight: torch.Tensor,
if hasattr(torch.ops._C, "ggml_dequantize"):

@register_fake("_C::ggml_dequantize")
def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int,
m: torch.SymInt,
n: torch.SymInt) -> torch.Tensor:
def _ggml_dequantize_fake(
W: torch.Tensor,
quant_type: int,
m: torch.SymInt,
n: torch.SymInt,
dtype: Optional[torch.dtype] = None) -> torch.Tensor:
return torch.empty((m, n), dtype=torch.float16, device=W.device)

@register_fake("_C::ggml_mul_mat_vec_a8")
@@ -1097,9 +1100,9 @@ def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,


# gguf
def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int,
n: int) -> torch.Tensor:
return torch.ops._C.ggml_dequantize(W, quant_type, m, n)
def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int,
dtype: Optional[torch.dtype]) -> torch.Tensor:
return torch.ops._C.ggml_dequantize(W, quant_type, m, n, dtype)


def ggml_mul_mat_vec_a8(
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/quantization/gguf.py
@@ -117,7 +117,7 @@ def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
elif qweight_type in DEQUANT_TYPES:
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
y = x @ weight.T
else:
# Raise an error if the quantization type is not supported.
@@ -377,7 +377,7 @@ def embedding(self, layer: torch.nn.Module,
x_flat = x.flatten()
quant = torch.index_select(qweight, dim=0, index=x_flat)
dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
x_flat.shape[0]).to(self.params_dtype)
x_flat.shape[0], self.params_dtype)
return dequant.view(*x.shape, hidden_size)


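At the layer level the effect is that _fuse_mul_mat and the GGUF embedding path now dequantize straight into the activation or parameter dtype instead of always materializing float16 first. A rough sketch of the fused matmul path for a bfloat16 model (shapes, quant type, and random data chosen only for illustration):

import gguf
import torch
from vllm import _custom_ops as ops

qweight_type = 2  # Q4_0, illustrative
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]

# Fake GGUF-packed weight for a 2304 -> 4096 projection; real data comes from a checkpoint.
qweight = torch.randint(0, 256, (4096, 2304 // block_size * type_size),
                        dtype=torch.uint8, device="cuda")
x = torch.randn(8, 2304, dtype=torch.bfloat16, device="cuda")

shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)  # bf16 output, no extra cast
y = x @ weight.T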