|
13 | 13 | #include "mmvq.cuh"
|
14 | 14 | #include "mmq.cuh"
|
15 | 15 | #include "moe.cuh"
|
| 16 | +#include "moe_vec.cuh" |
16 | 17 |
|
17 | 18 | // Q8 gemv
|
18 | 19 | template <typename scalar_t>
|
@@ -377,6 +378,142 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, // input
|
377 | 378 | return Y;
|
378 | 379 | }
|
379 | 380 |
|
| 381 | +torch::Tensor ggml_moe_a8_vec(torch::Tensor X, // input |
| 382 | + torch::Tensor W, // expert weights |
| 383 | + torch::Tensor topk_ids, int64_t top_k, |
| 384 | + int64_t type, int64_t row, int64_t tokens) { |
| 385 | + int col = X.sizes()[1]; |
| 386 | + const int padded = (col + 512 - 1) / 512 * 512; |
| 387 | + const at::cuda::OptionalCUDAGuard device_guard(device_of(X)); |
| 388 | + auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device()); |
| 389 | + at::Tensor Y = torch::zeros({tokens * top_k, row}, options); |
| 390 | + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); |
| 391 | + options = torch::TensorOptions().dtype(torch::kInt32).device(W.device()); |
| 392 | + at::Tensor quant_X = torch::empty({tokens, padded / 32 * 9}, options); |
| 393 | + VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_vec_a8", [&] { |
| 394 | + quantize_row_q8_1_cuda<scalar_t>((scalar_t*)X.data_ptr(), |
| 395 | + (void*)quant_X.data_ptr(), col, tokens, |
| 396 | + stream); |
| 397 | + switch (type) { |
| 398 | + case 2: |
| 399 | + moe_vec_q4_0_q8_1_cuda<scalar_t>( |
| 400 | + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), |
| 401 | + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, |
| 402 | + col, row, quant_X.stride(0), stream); |
| 403 | + break; |
| 404 | + case 3: |
| 405 | + moe_vec_q4_1_q8_1_cuda<scalar_t>( |
| 406 | + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), |
| 407 | + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, |
| 408 | + col, row, quant_X.stride(0), stream); |
| 409 | + break; |
| 410 | + case 6: |
| 411 | + moe_vec_q5_0_q8_1_cuda<scalar_t>( |
| 412 | + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), |
| 413 | + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, |
| 414 | + col, row, quant_X.stride(0), stream); |
| 415 | + break; |
| 416 | + case 7: |
| 417 | + moe_vec_q5_1_q8_1_cuda<scalar_t>( |
| 418 | + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), |
| 419 | + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, |
| 420 | + col, row, quant_X.stride(0), stream); |
| 421 | + break; |
| 422 | + case 8: |
| 423 | + moe_vec_q8_0_q8_1_cuda<scalar_t>( |
| 424 | + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), |
| 425 | + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, |
| 426 | + col, row, quant_X.stride(0), stream); |
| 427 | + break; |
| 428 | + case 10: |
| 429 | + moe_vec_q2_K_q8_1_cuda<scalar_t>( |
| 430 | + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), |
| 431 | + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, |
| 432 | + col, row, quant_X.stride(0), stream); |
| 433 | + break; |
| 434 | + case 11: |
| 435 | + moe_vec_q3_K_q8_1_cuda<scalar_t>( |
| 436 | + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), |
| 437 | + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, |
| 438 | + col, row, quant_X.stride(0), stream); |
| 439 | + break; |
| 440 | + case 12: |
| 441 | + moe_vec_q4_K_q8_1_cuda<scalar_t>( |
| 442 | + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), |
| 443 | + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, |
| 444 | + col, row, quant_X.stride(0), stream); |
| 445 | + break; |
| 446 | + case 13: |
| 447 | + moe_vec_q5_K_q8_1_cuda<scalar_t>( |
| 448 | + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), |
| 449 | + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, |
| 450 | + col, row, quant_X.stride(0), stream); |
| 451 | + break; |
| 452 | + case 14: |
| 453 | + moe_vec_q6_K_q8_1_cuda<scalar_t>( |
| 454 | + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), |
| 455 | + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, |
| 456 | + col, row, quant_X.stride(0), stream); |
| 457 | + break; |
| 458 | + case 16: |
| 459 | + moe_vec_iq2_xxs_q8_1_cuda<scalar_t>( |
| 460 | + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), |
| 461 | + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, |
| 462 | + col, row, quant_X.stride(0), stream); |
| 463 | + break; |
| 464 | + case 17: |
| 465 | + moe_vec_iq2_xs_q8_1_cuda<scalar_t>( |
| 466 | + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), |
| 467 | + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, |
| 468 | + col, row, quant_X.stride(0), stream); |
| 469 | + break; |
| 470 | + case 18: |
| 471 | + moe_vec_iq3_xxs_q8_1_cuda<scalar_t>( |
| 472 | + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), |
| 473 | + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, |
| 474 | + col, row, quant_X.stride(0), stream); |
| 475 | + break; |
| 476 | + case 19: |
| 477 | + moe_vec_iq1_s_q8_1_cuda<scalar_t>( |
| 478 | + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), |
| 479 | + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, |
| 480 | + col, row, quant_X.stride(0), stream); |
| 481 | + break; |
| 482 | + case 20: |
| 483 | + moe_vec_iq4_nl_q8_1_cuda<scalar_t>( |
| 484 | + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), |
| 485 | + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, |
| 486 | + col, row, quant_X.stride(0), stream); |
| 487 | + break; |
| 488 | + case 21: |
| 489 | + moe_vec_iq3_s_q8_1_cuda<scalar_t>( |
| 490 | + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), |
| 491 | + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, |
| 492 | + col, row, quant_X.stride(0), stream); |
| 493 | + break; |
| 494 | + case 22: |
| 495 | + moe_vec_iq2_s_q8_1_cuda<scalar_t>( |
| 496 | + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), |
| 497 | + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, |
| 498 | + col, row, quant_X.stride(0), stream); |
| 499 | + break; |
| 500 | + case 23: |
| 501 | + moe_vec_iq4_xs_q8_1_cuda<scalar_t>( |
| 502 | + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), |
| 503 | + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, |
| 504 | + col, row, quant_X.stride(0), stream); |
| 505 | + break; |
| 506 | + case 29: |
| 507 | + moe_vec_iq1_m_q8_1_cuda<scalar_t>( |
| 508 | + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), |
| 509 | + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, |
| 510 | + col, row, quant_X.stride(0), stream); |
| 511 | + break; |
| 512 | + } |
| 513 | + }); |
| 514 | + return Y; |
| 515 | +} |
| 516 | + |
380 | 517 | int64_t ggml_moe_get_block_size(int64_t type) {
|
381 | 518 | switch (type) {
|
382 | 519 | case 2:
|
|
0 commit comments