Skip to content

Commit 4696d56

Browse files
CUDA: fix crash on large batch size for quant. MoE (#13537)
1 parent b7d2672 commit 4696d56

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

ggml/src/ggml-cuda/mmq.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ void ggml_cuda_mul_mat_q(
122122
const int64_t s13 = src1->nb[3] / ts_src1;
123123
quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type,
124124
ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
125+
CUDA_CHECK(cudaGetLastError());
125126
}
126127

127128
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
@@ -205,6 +206,7 @@ void ggml_cuda_mul_mat_q(
205206
const int64_t s13 = src1->nb[2] / ts_src1;
206207
quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
207208
ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
209+
CUDA_CHECK(cudaGetLastError());
208210
}
209211

210212
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));

ggml/src/ggml-cuda/quantize.cu

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,13 @@ static __global__ void quantize_mmq_q8_1(
5656
constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
5757
constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
5858

59-
const int64_t i0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
59+
const int64_t i0 = ((int64_t)blockDim.x*blockIdx.y + threadIdx.x)*4;
6060

6161
if (i0 >= ne0) {
6262
return;
6363
}
6464

65-
const int64_t i1 = blockIdx.y;
65+
const int64_t i1 = blockIdx.x;
6666
const int64_t i2 = blockIdx.z % ne2;
6767
const int64_t i3 = blockIdx.z / ne2;
6868

@@ -75,8 +75,8 @@ static __global__ void quantize_mmq_q8_1(
7575

7676
block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
7777

78-
const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
79-
const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.y; // block index in channel
78+
const int64_t ib0 = blockIdx.z*((int64_t)gridDim.x*gridDim.y*blockDim.x/QK8_1); // first block of channel
79+
const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.x; // block index in channel
8080
const int64_t iqs = i0 % (4*QK8_1); // quant index in block
8181

8282
// Load 4 floats per thread and calculate max. abs. value between them:
@@ -166,8 +166,9 @@ void quantize_mmq_q8_1_cuda(
166166
GGML_ASSERT(ne00 % 4 == 0);
167167
GGML_ASSERT(ne0 % (4*QK8_1) == 0);
168168

169-
const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
170-
const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
169+
// ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid:
170+
const int64_t block_num_y = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
171+
const dim3 num_blocks(ne1, block_num_y, ne2*ne3);
171172
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
172173
switch (mmq_get_q8_1_ds_layout(type_src0)) {
173174
case MMQ_Q8_1_DS_LAYOUT_D4:

0 commit comments

Comments
 (0)