Skip to content

Commit b7c9766

Browse files
SzymonOzog and Isotr0py
authored and
Mu Huai
committed
[Kernel] GGUF MoeVec kernel (vllm-project#16780)
Signed-off-by: SzymonOzog <[email protected]> Signed-off-by: SzymonOzog <[email protected]> Signed-off-by: Isotr0py <[email protected]> Co-authored-by: Isotr0py <[email protected]> Signed-off-by: Mu Huai <[email protected]>
1 parent 293d563 commit b7c9766

File tree

8 files changed

+544
-16
lines changed

8 files changed

+544
-16
lines changed

csrc/ops.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,10 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W,
178178
torch::Tensor num_tokens_post_padded, int64_t type,
179179
int64_t row, int64_t top_k, int64_t tokens);
180180

181+
torch::Tensor ggml_moe_a8_vec(torch::Tensor X, torch::Tensor W,
182+
torch::Tensor topk_ids, int64_t top_k,
183+
int64_t type, int64_t row, int64_t tokens);
184+
181185
int64_t ggml_moe_get_block_size(int64_t type);
182186

183187
#ifndef USE_ROCM

csrc/quantization/gguf/gguf_kernel.cu

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mmvq.cuh"
1414
#include "mmq.cuh"
1515
#include "moe.cuh"
16+
#include "moe_vec.cuh"
1617

1718
// Q8 gemv
1819
template <typename scalar_t>
@@ -377,6 +378,142 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, // input
377378
return Y;
378379
}
379380

381+
torch::Tensor ggml_moe_a8_vec(torch::Tensor X,  // input
                              torch::Tensor W,  // expert weights
                              torch::Tensor topk_ids, int64_t top_k,
                              int64_t type, int64_t row, int64_t tokens) {
  // Vectorized (gemv-style) MoE matmul over GGUF-quantized expert weights.
  //
  // X:        [tokens, col] activations; any floating dtype accepted by
  //           VLLM_DISPATCH_FLOATING_TYPES.
  // W:        packed expert weights in the GGUF quantization selected by
  //           `type` (values follow the ggml type ids handled in the switch).
  // topk_ids: int32 expert indices chosen by the router (top_k per token).
  // Returns:  Y of shape [tokens * top_k, row], one output row per
  //           (token, selected expert) pair, on W's device.
  int col = X.sizes()[1];
  // Round the activation width up to a multiple of 512 — NOTE(review): this
  // presumably matches the chunking of quantize_row_q8_1_cuda; confirm
  // against mmvq.cuh if the pad granularity ever changes.
  const int padded = (col + 512 - 1) / 512 * 512;
  const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
  auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
  at::Tensor Y = torch::zeros({tokens * top_k, row}, options);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
  // Quantized-activation buffer: 9 int32 words per 32 padded input values
  // (the Q8_1 block footprint implied by this allocation).
  at::Tensor quant_X = torch::empty({tokens, padded / 32 * 9}, options);
  VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_vec_a8", [&] {
    // Quantize activations to Q8_1 once; every kernel below consumes it.
    quantize_row_q8_1_cuda<scalar_t>((scalar_t*)X.data_ptr(),
                                     (void*)quant_X.data_ptr(), col, tokens,
                                     stream);
    // All moe_vec_*_cuda launchers share one signature, so bind the common
    // argument list once instead of repeating it in every case.
    auto launch = [&](auto moe_vec_fn) {
      moe_vec_fn((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
                 (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k,
                 tokens, col, row, quant_X.stride(0), stream);
    };
    switch (type) {
      case 2:  // Q4_0
        launch(moe_vec_q4_0_q8_1_cuda<scalar_t>);
        break;
      case 3:  // Q4_1
        launch(moe_vec_q4_1_q8_1_cuda<scalar_t>);
        break;
      case 6:  // Q5_0
        launch(moe_vec_q5_0_q8_1_cuda<scalar_t>);
        break;
      case 7:  // Q5_1
        launch(moe_vec_q5_1_q8_1_cuda<scalar_t>);
        break;
      case 8:  // Q8_0
        launch(moe_vec_q8_0_q8_1_cuda<scalar_t>);
        break;
      case 10:  // Q2_K
        launch(moe_vec_q2_K_q8_1_cuda<scalar_t>);
        break;
      case 11:  // Q3_K
        launch(moe_vec_q3_K_q8_1_cuda<scalar_t>);
        break;
      case 12:  // Q4_K
        launch(moe_vec_q4_K_q8_1_cuda<scalar_t>);
        break;
      case 13:  // Q5_K
        launch(moe_vec_q5_K_q8_1_cuda<scalar_t>);
        break;
      case 14:  // Q6_K
        launch(moe_vec_q6_K_q8_1_cuda<scalar_t>);
        break;
      case 16:  // IQ2_XXS
        launch(moe_vec_iq2_xxs_q8_1_cuda<scalar_t>);
        break;
      case 17:  // IQ2_XS
        launch(moe_vec_iq2_xs_q8_1_cuda<scalar_t>);
        break;
      case 18:  // IQ3_XXS
        launch(moe_vec_iq3_xxs_q8_1_cuda<scalar_t>);
        break;
      case 19:  // IQ1_S
        launch(moe_vec_iq1_s_q8_1_cuda<scalar_t>);
        break;
      case 20:  // IQ4_NL
        launch(moe_vec_iq4_nl_q8_1_cuda<scalar_t>);
        break;
      case 21:  // IQ3_S
        launch(moe_vec_iq3_s_q8_1_cuda<scalar_t>);
        break;
      case 22:  // IQ2_S
        launch(moe_vec_iq2_s_q8_1_cuda<scalar_t>);
        break;
      case 23:  // IQ4_XS
        launch(moe_vec_iq4_xs_q8_1_cuda<scalar_t>);
        break;
      case 29:  // IQ1_M
        launch(moe_vec_iq1_m_q8_1_cuda<scalar_t>);
        break;
      default:
        // Previously an unrecognized type fell through silently and the
        // caller received an all-zeros Y; fail loudly instead.
        TORCH_CHECK(false, "ggml_moe_a8_vec: unsupported GGUF type ", type);
    }
  });
  return Y;
}
516+
380517
int64_t ggml_moe_get_block_size(int64_t type) {
381518
switch (type) {
382519
case 2:

0 commit comments

Comments
 (0)