Skip to content

Commit 8c01c9b

Browse files
woachk authored and ggerganov committed
kompute : implement op_getrows_f32 (llama/6403)
op_getrows_f32 is required since ggml-org/llama.cpp#6122 for the Vulkan w/ Kompute backend to be functional. As such, implement this op to make this backend functional again.
1 parent d1123d7 commit 8c01c9b

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

ggml-kompute.cpp

+13-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "shaderop_mul_mat_q4_1.h"
2323
#include "shaderop_mul_mat_q6_k.h"
2424
#include "shaderop_mul_mat_mat_f32.h"
25+
#include "shaderop_getrows_f32.h"
2526
#include "shaderop_getrows_f16.h"
2627
#include "shaderop_getrows_q4_0.h"
2728
#include "shaderop_getrows_q4_1.h"
@@ -1146,6 +1147,14 @@ static void ggml_vk_get_rows(
11461147
seq.record<kp::OpAlgoDispatch>(s_algo);
11471148
}
11481149

1150+
// Get-rows entry point for GGML_TYPE_F32 sources.
//
// Thin variadic wrapper: obtains the op_getrows_f32 compute shader via
// getSpirvShader() and hands it, together with the "f32" suffix, an element
// size of sizeof(float) and a literal 0, to the generic ggml_vk_get_rows().
// All caller-supplied arguments are perfectly forwarded after those four.
//
// The shader object is a function-local static, so getSpirvShader() runs only
// on the first call of each instantiation and the result is reused afterwards.
//
// NOTE(review): the meaning of the literal 0 is not visible from this hunk;
// presumably it is a quantization block size that does not apply to
// unquantized f32 data -- confirm against ggml_vk_get_rows' signature.
template <typename... Args>
1151+
static void ggml_vk_get_rows_f32(Args&&... args) {
1152+
const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f32_comp_spv,
1153+
kp::shader_data::op_getrows_f32_comp_spv_len);
1154+
1155+
ggml_vk_get_rows(spirv, "f32", sizeof(float), 0, std::forward<Args>(args)...);
1156+
}
1157+
11491158
template <typename... Args>
11501159
static void ggml_vk_get_rows_f16(Args&&... args) {
11511160
const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv,
@@ -1371,6 +1380,7 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
13711380
return op->ne[3] == 1;
13721381
case GGML_OP_GET_ROWS:
13731382
switch (op->src[0]->type) {
1383+
case GGML_TYPE_F32:
13741384
case GGML_TYPE_F16:
13751385
case GGML_TYPE_Q4_0:
13761386
case GGML_TYPE_Q4_1:
@@ -1661,7 +1671,9 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
16611671
} break;
16621672
case GGML_OP_GET_ROWS:
16631673
{
1664-
if (src0t == GGML_TYPE_F16) {
1674+
if (src0t == GGML_TYPE_F32) {
1675+
ggml_vk_get_rows_f32(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
1676+
} else if (src0t == GGML_TYPE_F16) {
16651677
ggml_vk_get_rows_f16(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
16661678
} else if (src0t == GGML_TYPE_Q4_0) {
16671679
ggml_vk_get_rows_q4_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));

0 commit comments

Comments
 (0)