Skip to content

Commit 16820a5

Browse files
llama : correct hparams comparison (#3446)
* fixed floating point comparison issues * updated implementation for hparam comparison to handle inf and NaN * fixed code review comments * minor simplification * rename is_float_eq -> is_float_close --------- Co-authored-by: Cebtenzzre <[email protected]>
1 parent 04b2f43 commit 16820a5

File tree

1 file changed

+39
-1
lines changed

1 file changed

+39
-1
lines changed

llama.cpp

+39-1
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,27 @@ static void replace_all(std::string & s, const std::string & search, const std::
125125
}
126126
s = std::move(result);
127127
}
128+
129+
static bool is_float_close(float a, float b, float abs_tol) {
130+
// Check for non-negative tolerance
131+
if (abs_tol < 0.0) {
132+
throw std::invalid_argument("Tolerance must be non-negative");
133+
}
134+
135+
// Exact equality check
136+
if (a == b) {
137+
return true;
138+
}
139+
140+
// Check for infinities
141+
if (std::isinf(a) || std::isinf(b)) {
142+
return false;
143+
}
144+
145+
// Regular comparison using the provided absolute tolerance
146+
return std::fabs(b - a) <= abs_tol;
147+
}
148+
128149
#ifdef GGML_USE_CPU_HBM
129150
#include <hbwmalloc.h>
130151
#endif
@@ -969,7 +990,24 @@ struct llama_hparams {
969990
float rope_freq_scale_train;
970991

971992
bool operator!=(const llama_hparams & other) const {
972-
return static_cast<bool>(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT
993+
if (this->vocab_only != other.vocab_only) return true;
994+
if (this->n_vocab != other.n_vocab) return true;
995+
if (this->n_ctx_train != other.n_ctx_train) return true;
996+
if (this->n_embd != other.n_embd) return true;
997+
if (this->n_head != other.n_head) return true;
998+
if (this->n_head_kv != other.n_head_kv) return true;
999+
if (this->n_layer != other.n_layer) return true;
1000+
if (this->n_rot != other.n_rot) return true;
1001+
if (this->n_ff != other.n_ff) return true;
1002+
1003+
const float EPSILON = 1e-9;
1004+
1005+
if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true;
1006+
if (!is_float_close(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return true;
1007+
if (!is_float_close(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return true;
1008+
if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true;
1009+
1010+
return false;
9731011
}
9741012

9751013
uint32_t n_gqa() const {

0 commit comments

Comments
 (0)