Skip to content

Commit 2b4ac77

Browse files
committed
kompute: implement op_getrows_f32
op_getrows_f32 is required since ggml-org#6122 for the Vulkan w/ Kompute backend to be functional. As such, implement this op to make this backend functional again.
1 parent 37e7854 commit 2b4ac77

File tree

3 files changed

+47
-3
lines changed

3 files changed

+47
-3
lines changed

CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,7 @@ if (LLAMA_KOMPUTE)
709709
kompute-shaders/op_mul_mat_q4_0.comp
710710
kompute-shaders/op_mul_mat_q4_1.comp
711711
kompute-shaders/op_mul_mat_q6_k.comp
712+
kompute-shaders/op_getrows_f32.comp
712713
kompute-shaders/op_getrows_f16.comp
713714
kompute-shaders/op_getrows_q4_0.comp
714715
kompute-shaders/op_getrows_q4_1.comp
@@ -741,6 +742,7 @@ if (LLAMA_KOMPUTE)
741742
shaderop_mul_mat_q4_0.h
742743
shaderop_mul_mat_q4_1.h
743744
shaderop_mul_mat_q6_k.h
745+
shaderop_getrows_f32.h
744746
shaderop_getrows_f16.h
745747
shaderop_getrows_q4_0.h
746748
shaderop_getrows_q4_1.h

ggml-kompute.cpp

+14-3
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "shaderop_mul_mat_q4_1.h"
2323
#include "shaderop_mul_mat_q6_k.h"
2424
#include "shaderop_mul_mat_mat_f32.h"
25+
#include "shaderop_getrows_f32.h"
2526
#include "shaderop_getrows_f16.h"
2627
#include "shaderop_getrows_q4_0.h"
2728
#include "shaderop_getrows_q4_1.h"
@@ -136,8 +137,7 @@ static bool ggml_vk_checkPhysicalDeviceFeatures(vk::PhysicalDevice physical_devi
136137

137138
physical_device.getFeatures2(&features2);
138139

139-
if (!availableFeatures11.uniformAndStorageBuffer16BitAccess ||
140-
!availableFeatures11.storageBuffer16BitAccess) {
140+
if (!availableFeatures11.storageBuffer16BitAccess) {
141141
return false;
142142
}
143143

@@ -1146,6 +1146,14 @@ static void ggml_vk_get_rows(
11461146
seq.record<kp::OpAlgoDispatch>(s_algo);
11471147
}
11481148

1149+
template <typename... Args>
1150+
static void ggml_vk_get_rows_f32(Args&&... args) {
1151+
const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f32_comp_spv,
1152+
kp::shader_data::op_getrows_f32_comp_spv_len);
1153+
1154+
ggml_vk_get_rows(spirv, "f32", sizeof(float), 0, std::forward<Args>(args)...);
1155+
}
1156+
11491157
template <typename... Args>
11501158
static void ggml_vk_get_rows_f16(Args&&... args) {
11511159
const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv,
@@ -1371,6 +1379,7 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
13711379
return op->ne[3] == 1;
13721380
case GGML_OP_GET_ROWS:
13731381
switch (op->src[0]->type) {
1382+
case GGML_TYPE_F32:
13741383
case GGML_TYPE_F16:
13751384
case GGML_TYPE_Q4_0:
13761385
case GGML_TYPE_Q4_1:
@@ -1649,7 +1658,9 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
16491658
} break;
16501659
case GGML_OP_GET_ROWS:
16511660
{
1652-
if (src0t == GGML_TYPE_F16) {
1661+
if (src0t == GGML_TYPE_F32) {
1662+
ggml_vk_get_rows_f32(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
1663+
} else if (src0t == GGML_TYPE_F16) {
16531664
ggml_vk_get_rows_f16(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
16541665
} else if (src0t == GGML_TYPE_Q4_0) {
16551666
ggml_vk_get_rows_q4_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));

kompute-shaders/op_getrows_f32.comp

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#version 450
2+
3+
#include "common.comp"
4+
5+
layout(local_size_x = 1) in;
6+
7+
layout (binding = 0) readonly buffer tensorInA { float inA[]; };
8+
layout (binding = 1) readonly buffer tensorInB { int inB[]; };
9+
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
10+
11+
layout (push_constant) uniform parameter {
12+
uint inAOff;
13+
uint inBOff;
14+
uint outOff;
15+
int ne00;
16+
int nb01;
17+
int nb1;
18+
} pcs;
19+
20+
void dequantize_row_f32(uint x /*Based from inA unaligned*/, uint y /*Based from out_*/, int k) {
21+
for (int j = 0; j < k; j++) {
22+
out_[y + j] = inA[x + j];
23+
}
24+
}
25+
26+
void main() {
27+
const uint i = gl_WorkGroupID.x;
28+
const int r = inB[i + pcs.inBOff];
29+
30+
dequantize_row_f32(r*pcs.nb01/4 + pcs.inAOff, i*pcs.nb1/4 + pcs.outOff, pcs.ne00);
31+
}

0 commit comments

Comments
 (0)