Commit 8841ce3

llama : switch KQ multiplication to F32 precision by default (#10015)
ggml-ci
1 parent: cc2983d

1 file changed (+4, -11 lines)

src/llama.cpp

@@ -9618,20 +9618,16 @@ static struct ggml_tensor * llm_build_kqv(
         cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 
-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_GEMMA2) {
-            ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
-        }
+        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
 
         cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
     } else {
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
         cb(kq, "kq", il);
 
-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || model.arch == LLM_ARCH_NEMOTRON || model.arch == LLM_ARCH_CHATGLM) {
-            // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
-            // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
-            ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
-        }
+        // note: this op tends to require high floating point range
+        // while for some models F16 is enough, for others it is not, so we default to F32 here
+        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
 
         if (model.arch == LLM_ARCH_GROK) {
             // need to do the following:
@@ -9640,9 +9636,6 @@ static struct ggml_tensor * llm_build_kqv(
             // kq = 30 * tanh(kq / 30)
             // before the softmax below
 
-            //try from phi2
-            //ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
-
             kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
             kq = ggml_scale(ctx, kq, 30);
         }
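For readers unfamiliar with the precision hints this commit now applies unconditionally, here is a minimal self-contained sketch of the same pattern outside llama.cpp: build a matmul node, mark it for F32 accumulation with ggml_mul_mat_set_prec, and compute it on the CPU. The tensor shapes, fill values, and thread count are arbitrary illustrations, and the sketch assumes the ggml CPU helper ggml_graph_compute_with_ctx as available in the ggml tree of this era; ggml_flash_attn_ext_set_prec follows the same pattern for the fused attention path.

// sketch: per-node F32 accumulation for a ggml matmul
// (shapes and values are arbitrary; error handling omitted)
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 64*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // "K" and "Q" stand-ins: ggml_mul_mat(ctx, a, b) requires a and b to
    // share the inner dimension (128 here) and yields a [64, 32] result
    struct ggml_tensor * k = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 128, 64);
    struct ggml_tensor * q = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 128, 32);
    ggml_set_f32(k, 1.0f);
    ggml_set_f32(q, 0.5f);

    struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);

    // the hint from the commit: accumulate this op in F32 even when the
    // backend would otherwise use F16 intermediates
    ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, kq);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    ggml_free(ctx);
    return 0;
}

Note that the hint is stored per node: only the marked op is asked to accumulate in F32, while the rest of the graph keeps whatever precision the backend chooses.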
