
Commit e28245f

export-lora : fix tok_embd tensor (#11330)
1 parent 6da5bec · commit e28245f

1 file changed: +12 -2 lines changed

examples/export-lora/export-lora.cpp

```diff
@@ -345,8 +345,18 @@ struct lora_merge_ctx {
         gf = ggml_new_graph(ctx0);
         struct ggml_tensor * cur = inp_base;
         for (size_t i = 0; i < adapters.size(); ++i) {
-            struct ggml_tensor * a_T = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32)));
-            struct ggml_tensor * delta = ggml_mul_mat(ctx0, a_T, ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32));
+            struct ggml_tensor * delta;
+            bool is_tok_embd = string_starts_with(name_base, "token_embd");
+            if (is_tok_embd) {
+                printf("%s : detected token embeddings tensor\n", __func__);
+                delta = ggml_mul_mat(ctx0,
+                    ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32),
+                    ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32));
+            } else {
+                delta = ggml_mul_mat(ctx0,
+                    ggml_cont(ctx0, ggml_transpose(ctx0, ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32))),
+                    ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32));
+            }
             // scale
             const float alpha = adapters[i]->alpha;
             const float rank = (float) inp_b[i]->ne[0];
```
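For context: the loop builds the merged weight by computing a per-adapter delta from the adapter's A and B tensors cast to F32, which is then scaled using `alpha` and `rank` (the usual LoRA scale `alpha / rank`). The fix swaps the `ggml_mul_mat` operand order when the base tensor name starts with `token_embd`, so the embedding delta comes out in the correct orientation. Below is a minimal, self-contained sketch of that delta math in plain C++; the `B x A` convention, matrix shapes, and toy values are illustrative assumptions, not the repo's API.

```cpp
// Minimal sketch of the LoRA merge delta, assuming the common convention
//   delta_W = (alpha / rank) * B x A
// with A of shape rank x n_in and B of shape n_out x rank.
// The real export-lora code does the same multiplication with ggml tensors
// and ggml_mul_mat; names and shapes here are assumptions for illustration.
#include <cstdio>
#include <vector>

using mat = std::vector<std::vector<float>>;

// delta[o][i] = (alpha / rank) * sum_r B[o][r] * A[r][i]
static mat lora_delta(const mat & A, const mat & B, float alpha) {
    const size_t rank  = A.size();
    const size_t n_in  = A[0].size();
    const size_t n_out = B.size();
    const float  scale = alpha / (float) rank;

    mat delta(n_out, std::vector<float>(n_in, 0.0f));
    for (size_t o = 0; o < n_out; ++o) {
        for (size_t i = 0; i < n_in; ++i) {
            float acc = 0.0f;
            for (size_t r = 0; r < rank; ++r) {
                acc += B[o][r] * A[r][i];
            }
            delta[o][i] = scale * acc;
        }
    }
    return delta;
}

int main() {
    // toy sizes: rank = 2, n_in = 3, n_out = 2
    const mat A = {{1, 0, 1},
                   {0, 1, 0}};
    const mat B = {{1, 2},
                   {3, 4}};
    const mat delta = lora_delta(A, B, /*alpha=*/2.0f);

    for (const auto & row : delta) {
        for (float v : row) {
            printf("%6.2f", v);
        }
        printf("\n");
    }
    return 0;
}
```

In the diff, the two branches differ only in which adapter tensor is transposed (and made contiguous) before the `ggml_mul_mat` call; the scaled delta is then accumulated onto `inp_base` outside the shown hunk.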
