
Commit 32a032f

ggerganov authored and dsx1986 committed
ggml : add asserts for type conversion in fattn kernels (ggml-org#9971)
ggml-ci
1 parent f37cee5 · commit 32a032f

File tree

3 files changed (+8, -4 lines)


common/common.cpp (+2, -2)
@@ -1035,7 +1035,7 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
         return GGML_TYPE_Q5_1;
     }
 
-    throw std::runtime_error("Invalid cache type: " + s);
+    throw std::runtime_error("Unsupported cache type: " + s);
 }
 
 struct llama_context_params common_context_params_to_llama(const common_params & params) {
@@ -1047,7 +1047,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.n_ubatch = params.n_ubatch;
     cparams.n_threads = params.cpuparams.n_threads;
     cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
-        params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
+            params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
     cparams.logits_all = params.logits_all;
     cparams.embeddings = params.embedding;
     cparams.rope_scaling_type = params.rope_scaling_type;
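
For context, the renamed message comes from kv_cache_type_from_str, the static helper in common.cpp that maps a user-supplied KV-cache type string to a ggml type and throws for anything it does not recognize. Below is a minimal standalone sketch of that "map string to type or throw" pattern with the new wording; the cache_type enum and helpers are illustrative stand-ins, not the actual llama.cpp symbols.

// Illustrative sketch only; mirrors the pattern of kv_cache_type_from_str.
#include <iostream>
#include <stdexcept>
#include <string>

enum class cache_type { f16, q8_0, q4_0, q5_1 };

static cache_type cache_type_from_str(const std::string & s) {
    if (s == "f16")  { return cache_type::f16;  }
    if (s == "q8_0") { return cache_type::q8_0; }
    if (s == "q4_0") { return cache_type::q4_0; }
    if (s == "q5_1") { return cache_type::q5_1; }
    // Same wording as the patched helper: the type is unsupported here,
    // not necessarily invalid in general.
    throw std::runtime_error("Unsupported cache type: " + s);
}

int main() {
    try {
        cache_type_from_str("iq4_nl"); // a string this sketch does not handle
    } catch (const std::runtime_error & e) {
        std::cerr << e.what() << '\n'; // prints: Unsupported cache type: iq4_nl
    }
}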

ggml/src/ggml.c (+5, -1)
@@ -324,8 +324,9 @@ struct ggml_logger_state {
 static struct ggml_logger_state g_logger_state = {ggml_log_callback_default, NULL};
 
 static void ggml_log_internal_v(enum ggml_log_level level, const char * format, va_list args) {
-    if (format == NULL)
+    if (format == NULL) {
         return;
+    }
     va_list args_copy;
     va_copy(args_copy, args);
     char buffer[128];
@@ -15723,6 +15724,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot;
     ggml_to_float_t const v_to_float = type_traits[v->type].to_float;
 
+    GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
+    GGML_ASSERT(v_to_float && "fattn: unsupported V-type");
+
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
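
The two asserts guard the function pointers that ggml_compute_forward_flash_attn_ext_f16 looks up in ggml's type_traits table; for a K or V type with no registered converter the slot can be null, and the `pointer && "message"` idiom keeps the explanatory text visible in the assertion failure instead of letting the kernel crash on a null call inside the loop. Below is a minimal standalone sketch of that idiom; the TRAIT_ASSERT macro and converter table are hypothetical stand-ins for GGML_ASSERT and the real type_traits.

// Hypothetical stand-ins for GGML_ASSERT and ggml's type_traits, used only to
// illustrate the `pointer && "message"` fail-fast idiom added by this commit.
#include <cstdio>
#include <cstdlib>

#define TRAIT_ASSERT(x)                                      \
    do {                                                     \
        if (!(x)) {                                          \
            std::fprintf(stderr, "assert failed: %s\n", #x); \
            std::abort();                                    \
        }                                                    \
    } while (0)

enum my_type { TYPE_F16, TYPE_Q8_0, TYPE_EXOTIC, TYPE_COUNT };

typedef void (*to_float_t)(const void * src, float * dst, int n);

static void f16_to_float (const void *, float *, int) { /* ... */ }
static void q8_0_to_float(const void *, float *, int) { /* ... */ }

// TYPE_EXOTIC has no converter registered, so its slot stays null.
static const to_float_t to_float_table[TYPE_COUNT] = { f16_to_float, q8_0_to_float, nullptr };

int main() {
    my_type    v_type     = TYPE_EXOTIC;
    to_float_t v_to_float = to_float_table[v_type];

    // Fail fast with a readable message instead of dereferencing a null
    // function pointer deep inside the attention loop.
    TRAIT_ASSERT(v_to_float && "fattn: unsupported V-type");

    float out[4];
    v_to_float(nullptr, out, 4); // never reached for TYPE_EXOTIC
}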

src/llama.cpp (+1, -1)
@@ -19243,7 +19243,7 @@ struct llama_context * llama_new_context_with_model(
         params.flash_attn = false;
     }
 
-    if (params.type_v != GGML_TYPE_F16 && !params.flash_attn) {
+    if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
         LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
         return nullptr;
     }
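
The llama.cpp change narrows the rejection: the old condition refused any non-F16 V cache when flash attention was off, while ggml_is_quantized() only flags the quantized types, so plain float formats such as F32 still work without flash_attn. A small sketch contrasting the two conditions follows; it assumes ggml.h is on the include path and the program links against ggml, and the rejected_old/rejected_new helpers are illustrative only, not llama.cpp code.

// Contrast of the pre-patch and patched V-cache checks, using the public
// ggml type helpers. The rejected_* helpers are illustrative only.
#include <cstdio>

#include "ggml.h"

static bool rejected_old(enum ggml_type type_v, bool flash_attn) {
    return type_v != GGML_TYPE_F16 && !flash_attn;   // pre-patch condition
}

static bool rejected_new(enum ggml_type type_v, bool flash_attn) {
    return ggml_is_quantized(type_v) && !flash_attn; // patched condition
}

int main() {
    const enum ggml_type types[] = { GGML_TYPE_F16, GGML_TYPE_F32, GGML_TYPE_Q4_0 };

    for (enum ggml_type t : types) {
        std::printf("%-5s old=%d new=%d\n",
                    ggml_type_name(t), rejected_old(t, false), rejected_new(t, false));
    }
    // Expected: F16 accepted by both checks, F32 rejected only by the old one,
    // Q4_0 rejected by both (a quantized V cache still needs flash_attn).
}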
