template <typename scalar_t, int bit, int GROUPS>
__global__ void moe_wna16_gemm_kernel(
    const scalar_t* __restrict__ input, scalar_t* __restrict__ output,
-
    const uint32_t* __restrict__ qweight, const scalar_t* __restrict__ scales,
    const uint32_t* __restrict__ qzeros,
@@ -54,8 +53,6 @@ __global__ void moe_wna16_gemm_kernel(
    if (token_index / top_k >= size_m) break;

    num_valid_tokens = m + 1;
-   if (blockIdx.z == 0 && offset_n < size_n)
-     output[token_index * size_n + offset_n] = Dtype::int2num(0);

    if (expert_id != -1) {
      int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N);
@@ -284,8 +281,7 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                             int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
                             int64_t BLOCK_SIZE_K, int64_t bit) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
- auto options =
-     torch::TensorOptions().dtype(input.dtype()).device(input.device());
+ output.zero_();

  const int num_experts = b_qweight.size(0);
  const int size_m = input.size(0);
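Taken together, the two hunks above replace the conditional zeroing inside moe_wna16_gemm_kernel (performed only by the blockIdx.z == 0 slice) with a single output.zero_() on the host before the launch. A plausible reading, not stated in this diff, is that zeroing from one z-slice is not ordered against the partial sums contributed by the other z-slices when the K dimension is split, whereas clearing the whole tensor up front leaves the buffer valid for any accumulation order. Below is a minimal standalone CUDA sketch of that host-side-zeroing, split-K accumulation pattern; the kernel, its names, and the atomicAdd accumulation are illustrative assumptions, not the vLLM implementation.

// Standalone sketch (not the vLLM kernel): each z-slice adds its partial sum
// into the same output element, so the buffer must already be zero before
// the first atomicAdd runs.
#include <cuda_runtime.h>

__global__ void split_k_accumulate(const float* partial, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Each z-slice contributes one partial result for element i.
    atomicAdd(&out[i], partial[blockIdx.z * n + i]);
  }
}

void launch(const float* partial, float* out, int n, int split_k) {
  // Zero once on the host side before launch, mirroring output.zero_() above.
  // Zeroing inside the kernel from the blockIdx.z == 0 slice would race with
  // the other slices' atomicAdds.
  cudaMemset(out, 0, n * sizeof(float));
  dim3 grid((n + 255) / 256, 1, split_k);
  split_k_accumulate<<<grid, 256>>>(partial, out, n);
}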
@@ -302,9 +298,9 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
  const uint32_t* b_qzeros_ptr;
  if (b_qzeros.has_value())
    b_qzeros_ptr = (const uint32_t*)b_qzeros.value().data_ptr<uint8_t>();
- const float* topk_weights_ptr;
+ const float* topk_weights_ptr = nullptr;
  if (topk_weights.has_value())
-   topk_weights_ptr = (const float*)topk_weights.value().data_ptr();
+   topk_weights_ptr = (const float*)topk_weights.value().data_ptr<float>();

  int groups_per_block_row = BLOCK_SIZE_K / group_size;
  TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8");
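The last hunk initializes topk_weights_ptr to nullptr so the pointer stays well defined when the optional topk_weights tensor is absent, and switches to the typed data_ptr<float>() accessor, which also lets ATen validate the tensor's dtype. A minimal sketch of that optional-pointer pattern follows, assuming the usual ATen/torch extension types; the helper function name is hypothetical and only for illustration.

// Sketch of the optional-pointer pattern (helper name is hypothetical).
#include <optional>
#include <torch/extension.h>

const float* get_topk_weights_ptr(
    const std::optional<torch::Tensor>& topk_weights) {
  const float* ptr = nullptr;                      // defined even when absent
  if (topk_weights.has_value())
    ptr = topk_weights.value().data_ptr<float>();  // throws if dtype != float
  return ptr;
}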