|
22 | 22 | #include "shaderop_mul_mat_q4_1.h"
|
23 | 23 | #include "shaderop_mul_mat_q6_k.h"
|
24 | 24 | #include "shaderop_mul_mat_mat_f32.h"
|
| 25 | +#include "shaderop_getrows_f32.h" |
25 | 26 | #include "shaderop_getrows_f16.h"
|
26 | 27 | #include "shaderop_getrows_q4_0.h"
|
27 | 28 | #include "shaderop_getrows_q4_1.h"
|
@@ -1146,6 +1147,14 @@ static void ggml_vk_get_rows(
|
1146 | 1147 | seq.record<kp::OpAlgoDispatch>(s_algo);
|
1147 | 1148 | }
|
1148 | 1149 |
|
| 1150 | +template <typename... Args> |
| 1151 | +static void ggml_vk_get_rows_f32(Args&&... args) { |
| 1152 | + const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f32_comp_spv, |
| 1153 | + kp::shader_data::op_getrows_f32_comp_spv_len); |
| 1154 | + |
| 1155 | + ggml_vk_get_rows(spirv, "f32", sizeof(float), 0, std::forward<Args>(args)...); |
| 1156 | +} |
| 1157 | + |
1149 | 1158 | template <typename... Args>
|
1150 | 1159 | static void ggml_vk_get_rows_f16(Args&&... args) {
|
1151 | 1160 | const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv,
|
@@ -1371,6 +1380,7 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
|
1371 | 1380 | return op->ne[3] == 1;
|
1372 | 1381 | case GGML_OP_GET_ROWS:
|
1373 | 1382 | switch (op->src[0]->type) {
|
| 1383 | + case GGML_TYPE_F32: |
1374 | 1384 | case GGML_TYPE_F16:
|
1375 | 1385 | case GGML_TYPE_Q4_0:
|
1376 | 1386 | case GGML_TYPE_Q4_1:
|
@@ -1661,7 +1671,9 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
|
1661 | 1671 | } break;
|
1662 | 1672 | case GGML_OP_GET_ROWS:
|
1663 | 1673 | {
|
1664 |
| - if (src0t == GGML_TYPE_F16) { |
| 1674 | + if (src0t == GGML_TYPE_F32) { |
| 1675 | + ggml_vk_get_rows_f32(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1)); |
| 1676 | + } else if (src0t == GGML_TYPE_F16) { |
1665 | 1677 | ggml_vk_get_rows_f16(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
|
1666 | 1678 | } else if (src0t == GGML_TYPE_Q4_0) {
|
1667 | 1679 | ggml_vk_get_rows_q4_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
|
|
0 commit comments