@@ -125,6 +125,27 @@ static void replace_all(std::string & s, const std::string & search, const std::
125
125
}
126
126
s = std::move (result);
127
127
}
128
+
129
+ static bool is_float_close (float a, float b, float abs_tol) {
130
+ // Check for non-negative tolerance
131
+ if (abs_tol < 0.0 ) {
132
+ throw std::invalid_argument (" Tolerance must be non-negative" );
133
+ }
134
+
135
+ // Exact equality check
136
+ if (a == b) {
137
+ return true ;
138
+ }
139
+
140
+ // Check for infinities
141
+ if (std::isinf (a) || std::isinf (b)) {
142
+ return false ;
143
+ }
144
+
145
+ // Regular comparison using the provided absolute tolerance
146
+ return std::fabs (b - a) <= abs_tol;
147
+ }
148
+
128
149
#ifdef GGML_USE_CPU_HBM
129
150
#include < hbwmalloc.h>
130
151
#endif
@@ -969,7 +990,24 @@ struct llama_hparams {
969
990
float rope_freq_scale_train;
970
991
971
992
bool operator !=(const llama_hparams & other) const {
972
- return static_cast <bool >(memcmp (this , &other, sizeof (llama_hparams))); // NOLINT
993
+ if (this ->vocab_only != other.vocab_only ) return true ;
994
+ if (this ->n_vocab != other.n_vocab ) return true ;
995
+ if (this ->n_ctx_train != other.n_ctx_train ) return true ;
996
+ if (this ->n_embd != other.n_embd ) return true ;
997
+ if (this ->n_head != other.n_head ) return true ;
998
+ if (this ->n_head_kv != other.n_head_kv ) return true ;
999
+ if (this ->n_layer != other.n_layer ) return true ;
1000
+ if (this ->n_rot != other.n_rot ) return true ;
1001
+ if (this ->n_ff != other.n_ff ) return true ;
1002
+
1003
+ const float EPSILON = 1e-9 ;
1004
+
1005
+ if (!is_float_close (this ->f_norm_eps , other.f_norm_eps , EPSILON)) return true ;
1006
+ if (!is_float_close (this ->f_norm_rms_eps , other.f_norm_rms_eps , EPSILON)) return true ;
1007
+ if (!is_float_close (this ->rope_freq_base_train , other.rope_freq_base_train , EPSILON)) return true ;
1008
+ if (!is_float_close (this ->rope_freq_scale_train , other.rope_freq_scale_train , EPSILON)) return true ;
1009
+
1010
+ return false ;
973
1011
}
974
1012
975
1013
uint32_t n_gqa () const {
0 commit comments