@@ -4734,7 +4734,8 @@ struct ggml_tensor * ggml_get_rows(
4734
4734
struct ggml_tensor * a,
4735
4735
struct ggml_tensor * b) {
4736
4736
GGML_ASSERT(a->ne[2] == b->ne[1]);
4737
- GGML_ASSERT(ggml_is_matrix(b) && b->type == GGML_TYPE_I32);
4737
+ GGML_ASSERT(b->ne[3] == 1);
4738
+ GGML_ASSERT(b->type == GGML_TYPE_I32);
4738
4739
4739
4740
bool is_node = false;
4740
4741
@@ -4744,7 +4745,7 @@ struct ggml_tensor * ggml_get_rows(
4744
4745
4745
4746
// TODO: implement non F32 return
4746
4747
//struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
4747
- struct ggml_tensor * result = ggml_new_tensor_3d (ctx, GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1]);
4748
+ struct ggml_tensor * result = ggml_new_tensor_4d (ctx, GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1], b->ne[2 ]);
4748
4749
4749
4750
result->op = GGML_OP_GET_ROWS;
4750
4751
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -10414,22 +10415,24 @@ static void ggml_compute_forward_get_rows_f32(
10414
10415
GGML_TENSOR_BINARY_OP_LOCALS
10415
10416
10416
10417
const int64_t nc = ne00;
10417
- const int64_t nr = ggml_nelements(src1);
10418
10418
10419
10419
assert(ne0 == nc);
10420
10420
assert(ne02 == ne11);
10421
10421
assert(nb00 == sizeof(float));
10422
10422
assert(ggml_nrows(dst) == nr);
10423
10423
10424
10424
// TODO: multi-thread
10425
- for (int64_t i = 0; i < nr; ++i) {
10426
- const int64_t r = ((int32_t *) src1->data)[i];
10427
-
10428
- const int64_t i02 = i/ne10;
10429
-
10430
- ggml_vec_cpy_f32(nc,
10431
- (float *) ((char *) dst->data + i*nb1),
10432
- (float *) ((char *) src0->data + i02*nb02 + r*nb01));
10425
+ // TODO: same impl for get_rows_q and get_rows_f16
10426
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
10427
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
10428
+ for (int64_t i10 = 0; i10 < ne10; ++i10) {
10429
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
10430
+
10431
+ ggml_vec_cpy_f32(nc,
10432
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
10433
+ (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
10434
+ }
10435
+ }
10433
10436
}
10434
10437
}
10435
10438
0 commit comments