@@ -56,13 +56,13 @@ static __global__ void quantize_mmq_q8_1(
56
56
constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32 ;
57
57
constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32 ;
58
58
59
- const int64_t i0 = ((int64_t )blockDim .x *blockIdx .x + threadIdx .x )*4 ;
59
+ const int64_t i0 = ((int64_t )blockDim .x *blockIdx .y + threadIdx .x )*4 ;
60
60
61
61
if (i0 >= ne0) {
62
62
return ;
63
63
}
64
64
65
- const int64_t i1 = blockIdx .y ;
65
+ const int64_t i1 = blockIdx .x ;
66
66
const int64_t i2 = blockIdx .z % ne2;
67
67
const int64_t i3 = blockIdx .z / ne2;
68
68
@@ -75,8 +75,8 @@ static __global__ void quantize_mmq_q8_1(
75
75
76
76
block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
77
77
78
- const int64_t ib0 = blockIdx .z *((int64_t )gridDim .y *gridDim .x *blockDim .x /QK8_1); // first block of channel
79
- const int64_t ib = ib0 + (i0 / (4 *QK8_1))*ne1 + blockIdx .y ; // block index in channel
78
+ const int64_t ib0 = blockIdx .z *((int64_t )gridDim .x *gridDim .y *blockDim .x /QK8_1); // first block of channel
79
+ const int64_t ib = ib0 + (i0 / (4 *QK8_1))*ne1 + blockIdx .x ; // block index in channel
80
80
const int64_t iqs = i0 % (4 *QK8_1); // quant index in block
81
81
82
82
// Load 4 floats per thread and calculate max. abs. value between them:
@@ -166,8 +166,9 @@ void quantize_mmq_q8_1_cuda(
166
166
GGML_ASSERT (ne00 % 4 == 0 );
167
167
GGML_ASSERT (ne0 % (4 *QK8_1) == 0 );
168
168
169
- const int64_t block_num_x = (ne0 + 4 *CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1 ) / (4 *CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
170
- const dim3 num_blocks (block_num_x, ne1, ne2*ne3);
169
+ // ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid:
170
+ const int64_t block_num_y = (ne0 + 4 *CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1 ) / (4 *CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
171
+ const dim3 num_blocks (ne1, block_num_y, ne2*ne3);
171
172
const dim3 block_size (CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1 , 1 );
172
173
switch (mmq_get_q8_1_ds_layout (type_src0)) {
173
174
case MMQ_Q8_1_DS_LAYOUT_D4:
0 commit comments