@@ -345,8 +345,18 @@ struct lora_merge_ctx {
            gf = ggml_new_graph(ctx0);
            struct ggml_tensor * cur = inp_base;
            for (size_t i = 0; i < adapters.size(); ++i) {
-                struct ggml_tensor * a_T = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32)));
-                struct ggml_tensor * delta = ggml_mul_mat(ctx0, a_T, ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32));
+                struct ggml_tensor * delta;
+                bool is_tok_embd = string_starts_with(name_base, "token_embd");
+                if (is_tok_embd) {
+                    printf("%s : detected token embeddings tensor\n", __func__);
+                    delta = ggml_mul_mat(ctx0,
+                        ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32),
+                        ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32));
+                } else {
+                    delta = ggml_mul_mat(ctx0,
+                        ggml_cont(ctx0, ggml_transpose(ctx0, ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32))),
+                        ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32));
+                }
                // scale
                const float alpha = adapters[i]->alpha;
                const float rank = (float) inp_b[i]->ne[0];
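Context note (not part of the upstream diff): each adapter's contribution built here is the usual low-rank LoRA update, formed by multiplying the cast A and B factors and then scaled by alpha / rank in the lines that follow. A sketch of the intended math, assuming inp_a and inp_b hold the rank-r factors of one adapter; the token_embd branch swaps the operand order because embedding-style tensors are stored in a different orientation than regular matmul weights, so no explicit transpose of A is needed there.

% Rank-r LoRA merge for a base weight W (assumed factor shapes: B A has the same shape as W)
\Delta W = \frac{\alpha}{r}\, B A, \qquad W' = W + \Delta W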