@@ -1721,6 +1721,32 @@ static __global__ void k_get_rows(
1721
1721
dst_row[iybs + iqs + y_offset] = v.y ;
1722
1722
}
1723
1723
1724
// k_get_rows_float: copy whole rows of a float-typed src0 (F16/F32) into dst,
// one source-row index per src1 entry (the GET_ROWS op for non-quantized types).
//
// Expected launch layout (see get_rows_cuda_float):
//   grid.x * block.x covers ne00  (elements within a row; bounds-checked below),
//   grid.y           covers ne10  (row indices),
//   grid.z           covers ne11*ne12 (outer dims, split via / and % ne12).
//
// s1..s3 and s10..s12 are strides in *elements*; nb01..nb03 are src0 strides
// in *bytes* — hence the char* pointer arithmetic when locating src0_row.
template<typename src0_t, typename dst_t>
static __global__ void k_get_rows_float(
            const src0_t * src0, const int32_t * src1, dst_t * dst,
            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
            size_t s10, size_t s11, size_t s12/*, size_t s13*/) {

    // Use 64-bit index arithmetic: ne00/ne12 are int64_t, so computing these in
    // 32-bit `int` can silently overflow for very large tensors, and the
    // `i00 >= ne00` guard would then compare a wrapped value against int64_t.
    const int64_t i00 = (int64_t)blockIdx.x*blockDim.x + threadIdx.x;
    const int64_t i10 = (int64_t)blockDim.y*blockIdx.y + threadIdx.y;
    const int64_t i11 = ((int64_t)blockIdx.z*blockDim.z + threadIdx.z)/ne12;
    const int64_t i12 = ((int64_t)blockIdx.z*blockDim.z + threadIdx.z)%ne12;

    // grid.x is a ceil-div over ne00, so the tail block has out-of-range threads
    if (i00 >= ne00) {
        return;
    }

    // i01 = which src0 row to gather, looked up from the index tensor src1
    const int64_t i01 = src1[i10*s10 + i11*s11 + i12*s12];

    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
    const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);

    dst_row[i00] = src0_row[i00];
}
1749
+
1724
1750
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t >
1725
1751
static __global__ void dequantize_block (const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
1726
1752
const int i = blockDim .x *blockIdx .x + 2 *threadIdx .x ;
@@ -5083,6 +5109,8 @@ static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, gg
5083
5109
const size_t s12 = nb12 / ggml_element_size (src1);
5084
5110
// const size_t s13 = nb13 / ggml_element_size(src1);
5085
5111
5112
+ GGML_ASSERT (ne00 % 2 == 0 );
5113
+
5086
5114
k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0 , stream>>> (
5087
5115
src0_dd, src1_dd, dst_dd,
5088
5116
ne00, /* ne01, ne02, ne03,*/
@@ -5094,6 +5122,38 @@ static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, gg
5094
5122
(void ) dst;
5095
5123
}
5096
5124
5125
// Host-side launcher for k_get_rows_float: gathers rows of a float-typed src0
// (F16/F32) according to the int32 indices in src1, writing float rows to dst.
// Launch shape: one thread per element of a row along x (ceil-divided over the
// block size), grid.y enumerates ne10 indices, grid.z covers ne11*ne12 combined.
template<typename src0_t>
static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
                                const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {

    GGML_TENSOR_BINARY_OP_LOCALS

    // ceil-divide the row length over the block size for the x dimension
    const int blocks_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
    const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
    const dim3 block_nums(blocks_x, ne10, ne11*ne12);

    // dst strides, converted from bytes to elements
    //const size_t s0 = nb0 / ggml_element_size(dst);
    const size_t s1 = nb1 / ggml_element_size(dst);
    const size_t s2 = nb2 / ggml_element_size(dst);
    const size_t s3 = nb3 / ggml_element_size(dst);

    // src1 (index tensor) strides, also in elements
    const size_t s10 = nb10 / ggml_element_size(src1);
    const size_t s11 = nb11 / ggml_element_size(src1);
    const size_t s12 = nb12 / ggml_element_size(src1);
    //const size_t s13 = nb13 / ggml_element_size(src1);

    // nb01..nb03 stay in bytes: the kernel walks src0 via char* arithmetic
    k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
            src0_dd, src1_dd, dst_dd,
            ne00, /*ne01, ne02, ne03,*/
            /*ne10, ne11,*/ ne12, /*ne13,*/
            /*s0,*/ s1, s2, s3,
            /*nb00,*/ nb01, nb02, nb03,
            s10, s11, s12/*, s13*/);

    (void) dst;
}
5156
+
5097
5157
template <float (*bin_op)(const float , const float )>
5098
5158
struct bin_bcast_cuda {
5099
5159
template <typename src0_t , typename src1_t , typename dst_t >
@@ -6491,10 +6551,10 @@ static void ggml_cuda_op_get_rows(
6491
6551
6492
6552
switch (src0->type ) {
6493
6553
case GGML_TYPE_F16:
6494
- get_rows_cuda< 1 , 1 , convert_f16> (src0, src1, dst, src0_d, src1_i32, dst_d, stream);
6554
+ get_rows_cuda_float (src0, src1, dst, ( const half *) src0_d, src1_i32, dst_d, stream);
6495
6555
break ;
6496
6556
case GGML_TYPE_F32:
6497
- get_rows_cuda< 1 , 1 , convert_f32> (src0, src1, dst, src0_d, src1_i32, dst_d, stream);
6557
+ get_rows_cuda_float (src0, src1, dst, src0_d, src1_i32, dst_d, stream);
6498
6558
break ;
6499
6559
case GGML_TYPE_Q4_0:
6500
6560
get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
0 commit comments