
Commit 3992df7

Nexesenex and ggerganov committed
llama : switch KQ multiplication to use F32 precision by default ggml-org#10015
Co-Authored-By: Georgi Gerganov <[email protected]>
1 parent d580455 commit 3992df7

File tree

1 file changed: +4 -11 lines changed


src/llama.cpp (+4 -11)
@@ -9761,20 +9761,16 @@ static struct ggml_tensor * llm_build_kqv(
         cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);

-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_GEMMA2) {
-            ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
-        }
+        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);

         cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
     } else {
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
         cb(kq, "kq", il);

-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || model.arch == LLM_ARCH_NEMOTRON || model.arch == LLM_ARCH_CHATGLM) {
-            // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
-            // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
-            ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
-        }
+        // note: this op tends to require high floating point range
+        // while for some models F16 is enough, for others it is not, so we default to F32 here
+        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

         if (model.arch == LLM_ARCH_GROK) {
             // need to do the following:
@@ -9783,9 +9779,6 @@
             // kq = 30 * tanh(kq / 30)
             // before the softmax below

-            //try from phi2
-            //ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
-
             kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
             kq = ggml_scale(ctx, kq, 30);
         }
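Why force F32 here: both precision calls above (ggml_flash_attn_ext_set_prec and ggml_mul_mat_set_prec) exist because the KQ logits are dot products over the head dimension, and F16 saturates near 65504, so large activations overflow to inf and then produce NaN once the softmax subtracts the row maximum (inf - inf). Previously this was special-cased per architecture; the commit makes F32 the default. A minimal sketch of the range problem (not from the commit), assuming a compiler with the _Float16 extension (GCC/Clang); the activation magnitudes are made up for illustration:

```cpp
#include <cstdio>
#include <cmath>

int main() {
    const int   n_embd_head = 128;                    // a typical head dimension
    const float kq_scale    = 1.0f/std::sqrt((float) n_embd_head);

    // hypothetical Q and K rows with large activations
    float q[n_embd_head], k[n_embd_head];
    for (int i = 0; i < n_embd_head; ++i) { q[i] = 24.0f; k[i] = 24.0f; }

    float    acc32 = 0.0f;  // F32 accumulator: plenty of range
    _Float16 acc16 = 0.0f;  // F16 accumulator: saturates near 65504
    for (int i = 0; i < n_embd_head; ++i) {
        acc32 += q[i]*k[i];
        acc16 += (_Float16)(q[i]*k[i]);  // 576*128 = 73728 -> overflows to inf
    }

    std::printf("F32 dot = %g -> logit %g\n", acc32, acc32*kq_scale);
    std::printf("F16 dot = %g -> logit %g\n", (float) acc16, (float) acc16*kq_scale);
    // the inf logit turns into NaN inside softmax (inf - inf), which is the
    // failure the per-arch special cases were papering over
    return 0;
}
```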
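Separately, the Grok branch keeps its logit soft-cap, kq = 30 * tanh(kq_scale * kq / 30) with kq_scale = 0.08838834764831845 (1/sqrt(128)), which bounds every logit to (-30, 30) before the softmax no matter how large the raw dot product gets. A quick numeric check of that composition; the sample values are illustrative only:

```cpp
#include <cstdio>
#include <cmath>

int main() {
    const float kq_scale  = 0.08838834764831845f;  // 1/sqrt(128), per the diff
    const float samples[] = {100.0f, 1000.0f, 73728.0f};

    for (float kq : samples) {
        // same composition as the graph ops above: scale, tanh, scale by 30
        float capped = 30.0f*std::tanh(kq*(kq_scale/30.0f));
        std::printf("raw kq = %9.1f -> capped logit = %7.3f\n", kq, capped);
    }
    return 0;
}
```

The cap bounds the logit values but not the precision of the matmul itself, so the commented-out "try from phi2" experiment in that branch becomes redundant once F32 is the default and can simply be deleted.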
