Skip to content

Commit 5aea5e3

Browse files
ggerganovarthw
authored andcommitted
llama : switch KQ multiplication to F32 precision by default (ggml-org#10015)
ggml-ci
1 parent fc74110 commit 5aea5e3

File tree

1 file changed

+4
-11
lines changed

1 file changed

+4
-11
lines changed

src/llama.cpp

+4-11
Original file line numberDiff line numberDiff line change
@@ -9624,20 +9624,16 @@ static struct ggml_tensor * llm_build_kqv(
96249624
cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
96259625
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
96269626

9627-
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_GEMMA2) {
9628-
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
9629-
}
9627+
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
96309628

96319629
cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
96329630
} else {
96339631
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
96349632
cb(kq, "kq", il);
96359633

9636-
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || model.arch == LLM_ARCH_NEMOTRON || model.arch == LLM_ARCH_CHATGLM) {
9637-
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
9638-
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
9639-
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
9640-
}
9634+
// note: this op tends to require high floating point range
9635+
// while for some models F16 is enough, for others it is not, so we default to F32 here
9636+
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
96419637

96429638
if (model.arch == LLM_ARCH_GROK) {
96439639
// need to do the following:
@@ -9646,9 +9642,6 @@ static struct ggml_tensor * llm_build_kqv(
96469642
// kq = 30 * tanh(kq / 30)
96479643
// before the softmax below
96489644

9649-
//try from phi2
9650-
//ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
9651-
96529645
kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
96539646
kq = ggml_scale(ctx, kq, 30);
96549647
}

0 commit comments

Comments
 (0)