From 4ed3469c684a6f807c828019f10ed8b987807f53 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 22 Aug 2023 18:35:28 +0300 Subject: [PATCH 01/36] llama : refactor GGUF constants into static maps --- llama.cpp | 249 ++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 185 insertions(+), 64 deletions(-) diff --git a/llama.cpp b/llama.cpp index 8b151dc84c90c..c538e74a38396 100644 --- a/llama.cpp +++ b/llama.cpp @@ -80,20 +80,6 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif -// tensor names -#define TN_TOKEN_EMBD "token_embd.weight" -#define TN_OUTPUT_NORM "output_norm.weight" -#define TN_OUTPUT "output.weight" -#define TN_ATTN_NORM "blk.%d.attn_norm.weight" -#define TN_ATTN_Q "blk.%d.attn_q.weight" -#define TN_ATTN_K "blk.%d.attn_k.weight" -#define TN_ATTN_V "blk.%d.attn_v.weight" -#define TN_ATTN_OUTPUT "blk.%d.attn_output.weight" -#define TN_FFN_NORM "blk.%d.ffn_norm.weight" -#define TN_FFN_GATE "blk.%d.ffn_gate.weight" -#define TN_FFN_DOWN "blk.%d.ffn_down.weight" -#define TN_FFN_UP "blk.%d.ffn_up.weight" - #ifdef __GNUC__ #ifdef __MINGW32__ #define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) @@ -107,6 +93,7 @@ // // logging // + LLAMA_ATTRIBUTE_FORMAT(2, 3) static void llama_log_internal (llama_log_level level, const char* format, ...); static void llama_log_callback_default(llama_log_level level, const char * text, void * user_data); @@ -142,6 +129,141 @@ static std::string format(const char * fmt, ...) { return std::string(buf.data(), size); } +// +// gguf constants (sync with gguf.py) +// + +enum llm_arch { + LLM_ARCH_LLAMA, + LLM_ARCH_FALCON, + LLM_ARCH_GPT2, + LLM_ARCH_GPTJ, + LLM_ARCH_GPTNEOX, + LLM_ARCH_MPT, + LLM_ARCH_UNKNOWN, +}; + +enum llm_tensor { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_POS_EMBD, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_NORM_2, + LLM_TENSOR_ATTN_ROT_EMBD, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_NORM, +}; + +static std::map LLM_ARCH_NAMES = { + { LLM_ARCH_LLAMA, "llama" }, + { LLM_ARCH_FALCON, "falcon" }, + { LLM_ARCH_GPT2, "gpt2" }, + { LLM_ARCH_GPTJ, "gptj" }, + { LLM_ARCH_GPTNEOX, "gptneox" }, + { LLM_ARCH_MPT, "mpt" }, +}; + +static std::map> LLM_TENSOR_NAMES_BASE = { + { + LLM_ARCH_LLAMA, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_FALCON, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, +}; + 
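The map entries above store printf-style templates such as "blk.%d.attn_norm"; the final tensor name is produced by formatting in the block index and then appending a ".weight" or ".bias" suffix. A minimal self-contained sketch of that lookup, using made-up demo_* names rather than the actual llama.cpp symbols, could look like this:

// sketch only: resolve a tensor name from an enum key and a block index,
// mirroring the idea of the per-architecture name map in the patch above
#include <cstdio>
#include <map>
#include <string>

enum demo_tensor { DEMO_TENSOR_ATTN_NORM, DEMO_TENSOR_FFN_UP };

// hypothetical stand-in for the real per-architecture map
static const std::map<demo_tensor, std::string> DEMO_TENSOR_NAMES = {
    { DEMO_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
    { DEMO_TENSOR_FFN_UP,    "blk.%d.ffn_up"    },
};

static std::string demo_tensor_name(demo_tensor t, const std::string & suffix, int bid) {
    char buf[128];
    // substitute the block index into the template, then append the suffix
    snprintf(buf, sizeof(buf), DEMO_TENSOR_NAMES.at(t).c_str(), bid);
    return std::string(buf) + "." + suffix;
}

int main() {
    // prints "blk.3.attn_norm.weight"
    printf("%s\n", demo_tensor_name(DEMO_TENSOR_ATTN_NORM, "weight", 3).c_str());
    return 0;
}

The benefit over the removed TN_* macros is that the key space is closed (an enum) and the mapping can differ per architecture without new macros.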
+static llm_arch llama_arch_from_string(const std::string & name) { + for (const auto & kv : LLM_ARCH_NAMES) { // NOLINT + if (kv.second == name) { + return kv.first; + } + } + + return LLM_ARCH_UNKNOWN; +} + +// helper to handle gguf constants +// usage: +// +// const auto llm = LLM(); +// +// std::string name = LLM(LLM_TENSOR_OUTPUT); -> "output" +// std::string name = LLM(LLM_TENSOR_TOKEN_EMBD, "bias"); -> "token_embd.bias" +// std::string name = LLM(LLM_TENSOR_ATTN_NORM, "weight", 3); -> "blk.3.attn_norm.weight" +// +template +struct LLM { + std::string operator()(llm_tensor tensor) const { + return LLM_TENSOR_NAMES_BASE[T].at(tensor); + } + + std::string operator()(llm_tensor tensor, const std::string & suffix) const { + return LLM_TENSOR_NAMES_BASE[T].at(tensor) + "." + suffix; + } + + std::string operator()(llm_tensor tensor, int bid) const { + return format(LLM_TENSOR_NAMES_BASE[T].at(tensor).c_str(), bid); + } + + std::string operator()(llm_tensor tensor, const std::string & suffix, int bid) const { + return format(LLM_TENSOR_NAMES_BASE[T].at(tensor).c_str(), bid) + "." + suffix; + } +}; + +// +// gguf helpers +// + +#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \ +{ \ + const int kid = gguf_find_key(ctx, key); \ + if (kid >= 0) { \ + enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \ + if (ktype != (type)) { \ + throw std::runtime_error(format("key %s has wrong type: %s", key, gguf_type_name(ktype))); \ + } \ + (dst) = func(ctx, kid); \ + } else if (req) { \ + throw std::runtime_error(format("key not found in model: %s", key)); \ + } \ +} + // // ggml helpers // @@ -594,7 +716,7 @@ enum e_model { }; static const size_t kB = 1024; -static const size_t MB = 1024*1024; +static const size_t MB = kB*kB; // default hparams (LLaMA 7B) struct llama_hparams { @@ -1270,22 +1392,8 @@ static void llama_model_load_internal( { struct gguf_context * ctx = ml->ctx_gguf; -#define GGUF_GET(dst, func, type, req, key) \ - { \ - const int kid = gguf_find_key(ctx, key); \ - if (kid >= 0) { \ - enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \ - if (ktype != (type)) { \ - throw std::runtime_error(format("key %s has wrong type: %s", key, gguf_type_name(ktype))); \ - } \ - (dst) = func(ctx, kid); \ - } else if (req) { \ - throw std::runtime_error(format("key not found in model: %s", key)); \ - } \ - } - std::string tokenizer_name; - GGUF_GET(tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, "tokenizer.ggml.model"); + GGUF_GET_KEY(ctx, tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, "tokenizer.ggml.model"); if (tokenizer_name == "llama") { vocab.type = LLAMA_VOCAB_TYPE_SPM; @@ -1298,39 +1406,37 @@ static void llama_model_load_internal( } // get hparams kv - GGUF_GET(hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, "tokenizer.ggml.tokens"); - GGUF_GET(hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.context_length"); - GGUF_GET(hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.embedding_length"); - GGUF_GET(hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.feed_forward_length"); - GGUF_GET(hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.attention.head_count"); - GGUF_GET(hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.block_count"); - GGUF_GET(hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.rope.dimension_count"); - GGUF_GET(hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, "llama.attention.layer_norm_rms_epsilon"); + GGUF_GET_KEY(ctx, hparams.n_vocab, 
gguf_get_arr_n, GGUF_TYPE_ARRAY, true, "tokenizer.ggml.tokens"); + GGUF_GET_KEY(ctx, hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.context_length"); + GGUF_GET_KEY(ctx, hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.embedding_length"); + GGUF_GET_KEY(ctx, hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.feed_forward_length"); + GGUF_GET_KEY(ctx, hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.attention.head_count"); + GGUF_GET_KEY(ctx, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.block_count"); + GGUF_GET_KEY(ctx, hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.rope.dimension_count"); + GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, "llama.attention.layer_norm_rms_epsilon"); // n_head_kv is optional, default to n_head hparams.n_head_kv = hparams.n_head; - GGUF_GET(hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "llama.attention.head_count_kv"); + GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "llama.attention.head_count_kv"); // TODO: manually setting rope scale should override this // rope_freq_scale (inverse of the kv) is optional float ropescale = 1.0f; - GGUF_GET(ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, "llama.rope.scale_linear"); + GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, "llama.rope.scale_linear"); if (ropescale != 1.0f) { rope_freq_scale = 1.0f/ropescale; } // get general kv - GGUF_GET(general_name, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.name"); - GGUF_GET(general_arch, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.architecture"); + GGUF_GET_KEY(ctx, general_name, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.name"); + GGUF_GET_KEY(ctx, general_arch, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.architecture"); // special tokens - GGUF_GET(vocab.special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.bos_token_id"); - GGUF_GET(vocab.special_eos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.eos_token_id"); - GGUF_GET(vocab.special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.unknown_token_id"); - GGUF_GET(vocab.special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.separator_token_id"); - GGUF_GET(vocab.special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.padding_token_id"); - -#undef GGUF_GET + GGUF_GET_KEY(ctx, vocab.special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.bos_token_id"); + GGUF_GET_KEY(ctx, vocab.special_eos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.eos_token_id"); + GGUF_GET_KEY(ctx, vocab.special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.unknown_token_id"); + GGUF_GET_KEY(ctx, vocab.special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.separator_token_id"); + GGUF_GET_KEY(ctx, vocab.special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.padding_token_id"); switch (hparams.n_layer) { case 26: model.type = e_model::MODEL_3B; break; @@ -1500,7 +1606,9 @@ static void llama_model_load_internal( const uint32_t n_layer = hparams.n_layer; const uint32_t n_vocab = hparams.n_vocab; - model.tok_embeddings = ml->create_tensor(ctx, TN_TOKEN_EMBD, {n_embd, n_vocab}, GGML_BACKEND_CPU); + const auto llm = LLM(); + + model.tok_embeddings = ml->create_tensor(ctx, llm(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, 
n_vocab}, GGML_BACKEND_CPU); // "output" tensor { @@ -1521,8 +1629,8 @@ static void llama_model_load_internal( backend_output = GGML_BACKEND_CPU; } - model.norm = ml->create_tensor(ctx, TN_OUTPUT_NORM, {n_embd}, backend_norm); - model.output = ml->create_tensor(ctx, TN_OUTPUT, {n_embd, n_vocab}, backend_output); + model.norm = ml->create_tensor(ctx, llm(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output = ml->create_tensor(ctx, llm(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); if (backend_norm == GGML_BACKEND_GPU) { vram_weights += ggml_nbytes(model.norm); } @@ -1541,18 +1649,18 @@ static void llama_model_load_internal( const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT auto & layer = model.layers[i]; - layer.attention_norm = ml->create_tensor(ctx, format(TN_ATTN_NORM, i), {n_embd}, backend); + layer.attention_norm = ml->create_tensor(ctx, llm(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); - layer.wq = ml->create_tensor(ctx, format(TN_ATTN_Q, i), {n_embd, n_embd}, backend_split); - layer.wk = ml->create_tensor(ctx, format(TN_ATTN_K, i), {n_embd, n_embd_gqa}, backend_split); - layer.wv = ml->create_tensor(ctx, format(TN_ATTN_V, i), {n_embd, n_embd_gqa}, backend_split); - layer.wo = ml->create_tensor(ctx, format(TN_ATTN_OUTPUT, i), {n_embd, n_embd}, backend_split); + layer.wq = ml->create_tensor(ctx, llm(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, backend_split); + layer.wk = ml->create_tensor(ctx, llm(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, backend_split); + layer.wv = ml->create_tensor(ctx, llm(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split); + layer.wo = ml->create_tensor(ctx, llm(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); - layer.ffn_norm = ml->create_tensor(ctx, format(TN_FFN_NORM, i), {n_embd}, backend); + layer.ffn_norm = ml->create_tensor(ctx, llm(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); - layer.w1 = ml->create_tensor(ctx, format(TN_FFN_GATE, i), {n_embd, n_ff}, backend_split); - layer.w2 = ml->create_tensor(ctx, format(TN_FFN_DOWN, i), { n_ff, n_embd}, backend_split); - layer.w3 = ml->create_tensor(ctx, format(TN_FFN_UP, i), {n_embd, n_ff}, backend_split); + layer.w1 = ml->create_tensor(ctx, llm(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); + layer.w2 = ml->create_tensor(ctx, llm(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.w3 = ml->create_tensor(ctx, llm(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); if (backend == GGML_BACKEND_GPU) { vram_weights += @@ -1671,6 +1779,16 @@ static bool llama_model_load( llama_progress_callback progress_callback, void *progress_callback_user_data) { try { + std::unique_ptr ml(new llama_model_loader(fname, use_mmap)); + + std::string arch_name; + GGUF_GET_KEY(ml->ctx_gguf, arch_name, gguf_get_val_str, GGUF_TYPE_STRING, true, "general.architecture"); + + const llm_arch arch = llama_arch_from_string(arch_name); + if (arch == LLM_ARCH_UNKNOWN) { + throw std::runtime_error("unknown architecture: " + arch_name); + } + llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gpu_layers, main_gpu, tensor_split, mul_mat_q, rope_freq_base, rope_freq_scale, low_vram, memory_type, use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data); @@ -3540,7 +3658,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s new_type = 
quantized_type; #ifdef GGML_USE_K_QUANTS // TODO: avoid hardcoded tensor names - use the TN_* constants - if (name == TN_OUTPUT) { + const auto llm = LLM(); + + if (name == llm(LLM_TENSOR_OUTPUT, "weight")) { int nx = tensor->ne[0]; int ny = tensor->ne[1]; if (nx % QK_K == 0 && ny % QK_K == 0) { @@ -3576,10 +3696,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } } if (convert_incompatible_tensor) { - if (name == TN_OUTPUT) { + if (name == llm(LLM_TENSOR_OUTPUT, "weight")) { new_type = GGML_TYPE_F16; //fall back to F16 instead of just failing. LLAMA_LOG_WARN("F16 will be used for this tensor instead.\n"); - } else if (name == TN_TOKEN_EMBD) { + } else if (name == llm(LLM_TENSOR_TOKEN_EMBD, "weight")) { new_type = GGML_TYPE_Q4_0; //fall back to Q4_0 instead of just failing. LLAMA_LOG_WARN("Q4_0 will be used for this tensor instead.\n"); } else { @@ -3891,6 +4011,7 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const // load from base model if (gguf_find_tensor(ctx_gguf, base_name.c_str()) < 0) { + // TODO: throw LLAMA_LOG_ERROR("%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str()); return 1; } From 8bd7f06b5891c9b0d1fb73604e04dcff4c1fe175 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 22 Aug 2023 19:03:08 +0300 Subject: [PATCH 02/36] llama : check if model architecture is known --- llama.cpp | 56 +++++++++++++++++++++++++++---------------------------- 1 file changed, 27 insertions(+), 29 deletions(-) diff --git a/llama.cpp b/llama.cpp index c538e74a38396..618d1773dec58 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1359,7 +1359,7 @@ static const char * llama_model_type_name(e_model type) { } static void llama_model_load_internal( - const std::string & fname, + llama_model_loader & ml, llama_model & model, llama_vocab & vocab, int n_ctx, @@ -1379,8 +1379,6 @@ static void llama_model_load_internal( void * progress_callback_user_data) { model.t_start_us = ggml_time_us(); - std::unique_ptr ml(new llama_model_loader(fname, use_mmap)); - model.n_gpu_layers = n_gpu_layers; auto & hparams = model.hparams; @@ -1390,7 +1388,7 @@ static void llama_model_load_internal( // read hparams { - struct gguf_context * ctx = ml->ctx_gguf; + struct gguf_context * ctx = ml.ctx_gguf; std::string tokenizer_name; GGUF_GET_KEY(ctx, tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, "tokenizer.ggml.model"); @@ -1452,7 +1450,7 @@ static void llama_model_load_internal( } break; } - model.ftype = ml->ftype; + model.ftype = ml.ftype; hparams.n_ctx = n_ctx; @@ -1473,7 +1471,7 @@ static void llama_model_load_internal( // read vocab { - struct gguf_context * ctx = ml->ctx_gguf; + struct gguf_context * ctx = ml.ctx_gguf; vocab.id_to_token.resize(hparams.n_vocab); @@ -1515,7 +1513,7 @@ static void llama_model_load_internal( { // hparams - LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml->fver)); + LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver)); LLAMA_LOG_INFO("%s: arch = %s\n", __func__, general_arch.c_str()); LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? 
"SPM" : "BPE"); // TODO: fix LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab); @@ -1533,7 +1531,7 @@ static void llama_model_load_internal( LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale); LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type)); LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype)); - LLAMA_LOG_INFO("%s: model size = %.2f B\n", __func__, ml->n_elements*1e-9); + LLAMA_LOG_INFO("%s: model size = %.2f B\n", __func__, ml.n_elements*1e-9); // general kv LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, general_name.c_str()); @@ -1557,7 +1555,7 @@ static void llama_model_load_internal( size_t ctx_size; size_t mmapped_size; - ml->calc_sizes(ctx_size, mmapped_size); + ml.calc_sizes(ctx_size, mmapped_size); LLAMA_LOG_INFO("%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/1024.0/1024.0); @@ -1572,7 +1570,7 @@ static void llama_model_load_internal( struct ggml_init_params params = { /*.mem_size =*/ model.buf.size, /*.mem_buffer =*/ model.buf.data, - /*.no_alloc =*/ ml->use_mmap, + /*.no_alloc =*/ ml.use_mmap, }; model.ctx = ggml_init(params); @@ -1608,7 +1606,7 @@ static void llama_model_load_internal( const auto llm = LLM(); - model.tok_embeddings = ml->create_tensor(ctx, llm(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + model.tok_embeddings = ml.create_tensor(ctx, llm(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); // "output" tensor { @@ -1629,8 +1627,8 @@ static void llama_model_load_internal( backend_output = GGML_BACKEND_CPU; } - model.norm = ml->create_tensor(ctx, llm(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); - model.output = ml->create_tensor(ctx, llm(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + model.norm = ml.create_tensor(ctx, llm(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, llm(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); if (backend_norm == GGML_BACKEND_GPU) { vram_weights += ggml_nbytes(model.norm); } @@ -1649,18 +1647,18 @@ static void llama_model_load_internal( const ggml_backend backend_split = int(i) < i_gpu_start ? 
GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT auto & layer = model.layers[i]; - layer.attention_norm = ml->create_tensor(ctx, llm(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + layer.attention_norm = ml.create_tensor(ctx, llm(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); - layer.wq = ml->create_tensor(ctx, llm(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, backend_split); - layer.wk = ml->create_tensor(ctx, llm(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, backend_split); - layer.wv = ml->create_tensor(ctx, llm(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split); - layer.wo = ml->create_tensor(ctx, llm(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + layer.wq = ml.create_tensor(ctx, llm(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, backend_split); + layer.wk = ml.create_tensor(ctx, llm(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, backend_split); + layer.wv = ml.create_tensor(ctx, llm(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split); + layer.wo = ml.create_tensor(ctx, llm(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); - layer.ffn_norm = ml->create_tensor(ctx, llm(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + layer.ffn_norm = ml.create_tensor(ctx, llm(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); - layer.w1 = ml->create_tensor(ctx, llm(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); - layer.w2 = ml->create_tensor(ctx, llm(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); - layer.w3 = ml->create_tensor(ctx, llm(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.w1 = ml.create_tensor(ctx, llm(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); + layer.w2 = ml.create_tensor(ctx, llm(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.w3 = ml.create_tensor(ctx, llm(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); if (backend == GGML_BACKEND_GPU) { vram_weights += @@ -1671,7 +1669,7 @@ static void llama_model_load_internal( } } - ml->done_getting_tensors(); + ml.done_getting_tensors(); // print memory requirements { @@ -1734,8 +1732,8 @@ static void llama_model_load_internal( } // populate `tensors_by_name` - for (int i = 0; i < ml->n_tensors; ++i) { - struct ggml_tensor * cur = ggml_get_tensor(ctx, ml->get_tensor_name(i)); + for (int i = 0; i < ml.n_tensors; ++i) { + struct ggml_tensor * cur = ggml_get_tensor(ctx, ml.get_tensor_name(i)); model.tensors_by_name.emplace_back(ggml_get_name(cur), cur); } @@ -1746,13 +1744,13 @@ static void llama_model_load_internal( } #endif - ml->load_all_data(ctx, progress_callback, progress_callback_user_data, use_mlock ? &model.mlock_mmap : NULL); + ml.load_all_data(ctx, progress_callback, progress_callback_user_data, use_mlock ? 
&model.mlock_mmap : NULL); if (progress_callback) { progress_callback(1.0f, progress_callback_user_data); } - model.mapping = std::move(ml->mapping); + model.mapping = std::move(ml.mapping); // loading time will be recalculate after the first eval, so // we take page faults deferred by mmap() into consideration @@ -1786,10 +1784,10 @@ static bool llama_model_load( const llm_arch arch = llama_arch_from_string(arch_name); if (arch == LLM_ARCH_UNKNOWN) { - throw std::runtime_error("unknown architecture: " + arch_name); + throw std::runtime_error("unknown model architecture: '" + arch_name + "'"); } - llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gpu_layers, + llama_model_load_internal(*ml, model, vocab, n_ctx, n_batch, n_gpu_layers, main_gpu, tensor_split, mul_mat_q, rope_freq_base, rope_freq_scale, low_vram, memory_type, use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data); return true; From 3057d6a687f3cdd5f5fcdd7c13ad0dcbce79969e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 22 Aug 2023 19:30:02 +0300 Subject: [PATCH 03/36] llama : refactor llama_model_load_internal() --- llama.cpp | 344 ++++++++++++++++++++++++++++-------------------------- 1 file changed, 181 insertions(+), 163 deletions(-) diff --git a/llama.cpp b/llama.cpp index 618d1773dec58..ef711fa117f99 100644 --- a/llama.cpp +++ b/llama.cpp @@ -208,7 +208,7 @@ static std::map> LLM_TENSOR_NAMES_BA }, }; -static llm_arch llama_arch_from_string(const std::string & name) { +static llm_arch llm_arch_from_string(const std::string & name) { for (const auto & kv : LLM_ARCH_NAMES) { // NOLINT if (kv.second == name) { return kv.first; @@ -836,6 +836,9 @@ struct llama_model { e_model type = MODEL_UNKNOWN; llama_ftype ftype = LLAMA_FTYPE_ALL_F32; + std::string name = "n/a"; + std::string arch = "n/a"; + llama_hparams hparams; llama_vocab vocab; @@ -1358,38 +1361,34 @@ static const char * llama_model_type_name(e_model type) { } } -static void llama_model_load_internal( +static void llm_load_vocab( llama_model_loader & ml, - llama_model & model, - llama_vocab & vocab, - int n_ctx, - int n_batch, - int n_gpu_layers, - int main_gpu, - const float * tensor_split, - const bool mul_mat_q, - float rope_freq_base, - float rope_freq_scale, - bool low_vram, - ggml_type memory_type, - bool use_mmap, - bool use_mlock, - bool vocab_only, - llama_progress_callback progress_callback, - void * progress_callback_user_data) { - model.t_start_us = ggml_time_us(); + llama_model & model) { + auto & vocab = model.vocab; - model.n_gpu_layers = n_gpu_layers; + struct gguf_context * ctx = ml.ctx_gguf; - auto & hparams = model.hparams; + const int token_idx = gguf_find_key(ctx, "tokenizer.ggml.tokens"); + if (token_idx == -1) { + throw std::runtime_error("cannot find tokenizer vocab in model file\n"); + } - std::string general_name = "n/a"; - std::string general_arch = "n/a"; + const int score_idx = gguf_find_key(ctx, "tokenizer.ggml.scores"); + if (score_idx == -1) { + throw std::runtime_error("cannot find tokenizer scores in model file\n"); + } - // read hparams - { - struct gguf_context * ctx = ml.ctx_gguf; + const float * scores = (const float * ) gguf_get_arr_data(ctx, score_idx); + + const int toktype_idx = gguf_find_key(ctx, "tokenizer.ggml.token_type"); + if (toktype_idx == -1) { + throw std::runtime_error("cannot find token type list in GGUF file\n"); + } + + const int * toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx); + // determine vocab type + { std::string tokenizer_name; 
GGUF_GET_KEY(ctx, tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, "tokenizer.ggml.model"); @@ -1402,155 +1401,161 @@ static void llama_model_load_internal( LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__); vocab.type = LLAMA_VOCAB_TYPE_SPM; } + } - // get hparams kv - GGUF_GET_KEY(ctx, hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, "tokenizer.ggml.tokens"); - GGUF_GET_KEY(ctx, hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.context_length"); - GGUF_GET_KEY(ctx, hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.embedding_length"); - GGUF_GET_KEY(ctx, hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.feed_forward_length"); - GGUF_GET_KEY(ctx, hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.attention.head_count"); - GGUF_GET_KEY(ctx, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.block_count"); - GGUF_GET_KEY(ctx, hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.rope.dimension_count"); - GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, "llama.attention.layer_norm_rms_epsilon"); - - // n_head_kv is optional, default to n_head - hparams.n_head_kv = hparams.n_head; - GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "llama.attention.head_count_kv"); - - // TODO: manually setting rope scale should override this - // rope_freq_scale (inverse of the kv) is optional - float ropescale = 1.0f; - GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, "llama.rope.scale_linear"); - if (ropescale != 1.0f) { - rope_freq_scale = 1.0f/ropescale; - } - - // get general kv - GGUF_GET_KEY(ctx, general_name, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.name"); - GGUF_GET_KEY(ctx, general_arch, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.architecture"); - - // special tokens - GGUF_GET_KEY(ctx, vocab.special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.bos_token_id"); - GGUF_GET_KEY(ctx, vocab.special_eos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.eos_token_id"); - GGUF_GET_KEY(ctx, vocab.special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.unknown_token_id"); - GGUF_GET_KEY(ctx, vocab.special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.separator_token_id"); - GGUF_GET_KEY(ctx, vocab.special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.padding_token_id"); - - switch (hparams.n_layer) { - case 26: model.type = e_model::MODEL_3B; break; - case 32: model.type = e_model::MODEL_7B; break; - case 40: model.type = e_model::MODEL_13B; break; - case 60: model.type = e_model::MODEL_30B; break; - case 80: model.type = e_model::MODEL_65B; break; - default: - { - if (hparams.n_layer < 32) { - model.type = e_model::MODEL_7B; - } - } break; - } + const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx); - model.ftype = ml.ftype; + vocab.id_to_token.resize(n_vocab); - hparams.n_ctx = n_ctx; + for (uint32_t i = 0; i < n_vocab; i++) { + std::string word = gguf_get_arr_str(ctx, token_idx, i); - // LLaMAv2 - // TODO: probably not needed - { - const auto n_gqa = hparams.n_gqa(); + vocab.token_to_id[word] = i; - if (model.type == e_model::MODEL_65B && n_gqa == 8) { - LLAMA_LOG_WARN("%s: assuming 70B model based on GQA == %d\n", __func__, n_gqa); - model.type = e_model::MODEL_70B; - } - } + auto & token_data = vocab.id_to_token[i]; + token_data.text = std::move(word); + token_data.score = scores[i]; 
+ token_data.type = (llama_token_type) toktypes[i]; - hparams.rope_freq_base = rope_freq_base; - hparams.rope_freq_scale = rope_freq_scale; + // determine the newline token: 0x0A == 10 == '\n' + if (token_data.text == "<0x0A>") { + vocab.linefeed_id = i; + } } - // read vocab - { - struct gguf_context * ctx = ml.ctx_gguf; - - vocab.id_to_token.resize(hparams.n_vocab); + // special tokens + GGUF_GET_KEY(ctx, vocab.special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.bos_token_id"); + GGUF_GET_KEY(ctx, vocab.special_eos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.eos_token_id"); + GGUF_GET_KEY(ctx, vocab.special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.unknown_token_id"); + GGUF_GET_KEY(ctx, vocab.special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.separator_token_id"); + GGUF_GET_KEY(ctx, vocab.special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.padding_token_id"); +} - const int token_idx = gguf_find_key(ctx, "tokenizer.ggml.tokens"); - if (token_idx == -1) { - throw std::runtime_error("cannot find tokenizer vocab in model file\n"); - } +static void llm_load_hparams( + llama_model_loader & ml, + llama_model & model, + int n_ctx, + float rope_freq_base, + float rope_freq_scale) { + auto & hparams = model.hparams; - const int score_idx = gguf_find_key(ctx, "tokenizer.ggml.scores"); - if (score_idx == -1) { - throw std::runtime_error("cannot find tokenizer scores in model file\n"); - } + struct gguf_context * ctx = ml.ctx_gguf; + + // get hparams kv + GGUF_GET_KEY(ctx, hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, "tokenizer.ggml.tokens"); + GGUF_GET_KEY(ctx, hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.context_length"); + GGUF_GET_KEY(ctx, hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.embedding_length"); + GGUF_GET_KEY(ctx, hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.feed_forward_length"); + GGUF_GET_KEY(ctx, hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.attention.head_count"); + GGUF_GET_KEY(ctx, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.block_count"); + GGUF_GET_KEY(ctx, hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.rope.dimension_count"); + GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, "llama.attention.layer_norm_rms_epsilon"); + + // n_head_kv is optional, default to n_head + hparams.n_head_kv = hparams.n_head; + GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "llama.attention.head_count_kv"); + + // TODO: manually setting rope scale should override this + // rope_freq_scale (inverse of the kv) is optional + float ropescale = 1.0f; + GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, "llama.rope.scale_linear"); + if (ropescale != 1.0f) { + rope_freq_scale = 1.0f/ropescale; + } + + // get general kv + GGUF_GET_KEY(ctx, model.name, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.name"); + GGUF_GET_KEY(ctx, model.arch, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.architecture"); + + switch (hparams.n_layer) { + case 26: model.type = e_model::MODEL_3B; break; + case 32: model.type = e_model::MODEL_7B; break; + case 40: model.type = e_model::MODEL_13B; break; + case 60: model.type = e_model::MODEL_30B; break; + case 80: model.type = e_model::MODEL_65B; break; + default: + { + if (hparams.n_layer < 32) { + model.type = e_model::MODEL_7B; + } + } 
break; + } - const float * scores = (const float * ) gguf_get_arr_data(ctx, score_idx); + model.ftype = ml.ftype; - const int toktype_idx = gguf_find_key(ctx, "tokenizer.ggml.token_type"); - if (toktype_idx == -1) { - throw std::runtime_error("cannot find token type list in GGUF file\n"); - } + hparams.n_ctx = n_ctx; - const int * toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx); + // LLaMAv2 + // TODO: probably not needed + { + const auto n_gqa = hparams.n_gqa(); - for (uint32_t i = 0; i < hparams.n_vocab; i++) { - std::string word = gguf_get_arr_str(ctx, token_idx, i); + if (model.type == e_model::MODEL_65B && n_gqa == 8) { + LLAMA_LOG_WARN("%s: assuming 70B model based on GQA == %d\n", __func__, n_gqa); + model.type = e_model::MODEL_70B; + } + } - vocab.token_to_id[word] = i; + hparams.rope_freq_base = rope_freq_base; + hparams.rope_freq_scale = rope_freq_scale; +} - auto & token_data = vocab.id_to_token[i]; - token_data.text = std::move(word); - token_data.score = scores[i]; - token_data.type = (llama_token_type) toktypes[i]; +static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { + const auto & hparams = model.hparams; + const auto & vocab = model.vocab; + + // hparams + LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver)); + LLAMA_LOG_INFO("%s: arch = %s\n", __func__, model.arch.c_str()); + LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix + LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab); + LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); + LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, hparams.n_ctx); + LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); + LLAMA_LOG_INFO("%s: n_head = %u\n", __func__, hparams.n_head); + LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv); + LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); + LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. 
n_embd_head, n_head_dim + LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa()); + LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); + LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff); + LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base); + LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale); + LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type)); + LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype)); + LLAMA_LOG_INFO("%s: model size = %.2f B\n", __func__, ml.n_elements*1e-9); + + // general kv + LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, model.name.c_str()); + + // special tokens + if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); } + if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); } + if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); } + if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); } + if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); } + if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); } +} - // determine the newline token: 0x0A == 10 == '\n' - if (token_data.text == "<0x0A>") { - vocab.linefeed_id = i; - } - } - } +static void llama_model_load_internal( + llama_model_loader & ml, + llama_model & model, + int n_batch, + int n_gpu_layers, + int main_gpu, + const float * tensor_split, + const bool mul_mat_q, + bool low_vram, + ggml_type memory_type, + bool use_mlock, + llama_progress_callback progress_callback, + void * progress_callback_user_data) { + model.t_start_us = ggml_time_us(); - { - // hparams - LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver)); - LLAMA_LOG_INFO("%s: arch = %s\n", __func__, general_arch.c_str()); - LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix - LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab); - LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); - LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, hparams.n_ctx); - LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); - LLAMA_LOG_INFO("%s: n_head = %u\n", __func__, hparams.n_head); - LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv); - LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); - LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. 
n_embd_head, n_head_dim - LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa()); - LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); - LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff); - LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base); - LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale); - LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type)); - LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype)); - LLAMA_LOG_INFO("%s: model size = %.2f B\n", __func__, ml.n_elements*1e-9); - - // general kv - LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, general_name.c_str()); - - // special tokens - if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); } - if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); } - if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); } - if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); } - if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); } - if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); } - } - - if (vocab_only) { - LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__); - return; - } + auto & ctx = model.ctx; + auto & hparams = model.hparams; - auto & ctx = model.ctx; + model.n_gpu_layers = n_gpu_layers; size_t ctx_size; size_t mmapped_size; @@ -1760,7 +1765,6 @@ static void llama_model_load_internal( static bool llama_model_load( const std::string & fname, llama_model & model, - llama_vocab & vocab, int n_ctx, int n_batch, int n_gpu_layers, @@ -1782,14 +1786,28 @@ static bool llama_model_load( std::string arch_name; GGUF_GET_KEY(ml->ctx_gguf, arch_name, gguf_get_val_str, GGUF_TYPE_STRING, true, "general.architecture"); - const llm_arch arch = llama_arch_from_string(arch_name); + const llm_arch arch = llm_arch_from_string(arch_name); if (arch == LLM_ARCH_UNKNOWN) { throw std::runtime_error("unknown model architecture: '" + arch_name + "'"); } - llama_model_load_internal(*ml, model, vocab, n_ctx, n_batch, n_gpu_layers, - main_gpu, tensor_split, mul_mat_q, rope_freq_base, rope_freq_scale, low_vram, memory_type, - use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data); + llm_load_hparams(*ml, model, n_ctx, rope_freq_base, rope_freq_scale); + llm_load_vocab(*ml, model); + + llm_load_print_meta(*ml, model); + + if (model.hparams.n_vocab != model.vocab.id_to_token.size()) { + throw std::runtime_error("vocab size mismatch"); + } + + if (vocab_only) { + LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__); + return true; + } + + llama_model_load_internal(*ml, model, n_batch, n_gpu_layers, + main_gpu, tensor_split, mul_mat_q, low_vram, memory_type, + use_mlock, progress_callback, progress_callback_user_data); return true; } catch (const std::exception & err) { LLAMA_LOG_ERROR("error loading model: 
%s\n", err.what()); @@ -4191,7 +4209,7 @@ struct llama_model * llama_load_model_from_file( ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; - if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gpu_layers, + if (!llama_model_load(path_model, *model, params.n_ctx, params.n_batch, params.n_gpu_layers, params.main_gpu, params.tensor_split, params.mul_mat_q, params.rope_freq_base, params.rope_freq_scale, params.low_vram, memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback, params.progress_callback_user_data)) { From 3c025a6d07dbd2799bb949a97956cd33048a02cc Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 22 Aug 2023 20:06:15 +0300 Subject: [PATCH 04/36] gguf : add KV constant maps --- gguf.py | 26 ++-- llama.cpp | 392 ++++++++++++++++++++++++++++++++++++------------------ 2 files changed, 272 insertions(+), 146 deletions(-) diff --git a/gguf.py b/gguf.py index 9776649c76119..b37f619668c7e 100644 --- a/gguf.py +++ b/gguf.py @@ -28,12 +28,12 @@ KEY_GENERAL_SOURCE_HF_REPO = "general.source.hugginface.repository" # LLM -KEY_LLM_CONTEXT_LENGTH = "{arch}.context_length" -KEY_LLM_EMBEDDING_LENGTH = "{arch}.embedding_length" -KEY_LLM_BLOCK_COUNT = "{arch}.block_count" -KEY_LLM_FEED_FORWARD_LENGTH = "{arch}.feed_forward_length" -KEY_LLM_USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual" -KEY_LLM_TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout" +KEY_CONTEXT_LENGTH = "{arch}.context_length" +KEY_EMBEDDING_LENGTH = "{arch}.embedding_length" +KEY_BLOCK_COUNT = "{arch}.block_count" +KEY_FEED_FORWARD_LENGTH = "{arch}.feed_forward_length" +KEY_USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual" +KEY_TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout" # attention KEY_ATTENTION_HEAD_COUNT = "{arch}.attention.head_count" @@ -581,7 +581,7 @@ def add_author(self, author: str): self.add_string(KEY_GENERAL_AUTHOR, author) def add_tensor_data_layout(self, layout: str): - self.add_string(KEY_LLM_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout) + self.add_string(KEY_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout) def add_url(self, url: str): self.add_string(KEY_GENERAL_URL, url) @@ -608,27 +608,27 @@ def add_custom_alignment(self, alignment: int): def add_context_length(self, length: int): self.add_uint32( - KEY_LLM_CONTEXT_LENGTH.format(arch=self.arch), length) + KEY_CONTEXT_LENGTH.format(arch=self.arch), length) def add_embedding_length(self, length: int): self.add_uint32( - KEY_LLM_EMBEDDING_LENGTH.format(arch=self.arch), length) + KEY_EMBEDDING_LENGTH.format(arch=self.arch), length) def add_block_count(self, length: int): self.add_uint32( - KEY_LLM_BLOCK_COUNT.format(arch=self.arch), length) + KEY_BLOCK_COUNT.format(arch=self.arch), length) def add_feed_forward_length(self, length: int): self.add_uint32( - KEY_LLM_FEED_FORWARD_LENGTH.format(arch=self.arch), length) + KEY_FEED_FORWARD_LENGTH.format(arch=self.arch), length) def add_parallel_residual(self, use: bool): self.add_bool( - KEY_LLM_USE_PARALLEL_RESIDUAL.format(arch=self.arch), use) + KEY_USE_PARALLEL_RESIDUAL.format(arch=self.arch), use) def add_tensor_data_layout(self, layout: str): self.add_string( - KEY_LLM_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout) + KEY_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout) def add_head_count(self, count: int): self.add_uint32( diff --git a/llama.cpp b/llama.cpp index ef711fa117f99..2632052c87fa0 100644 --- a/llama.cpp +++ b/llama.cpp @@ -143,6 +143,111 @@ enum llm_arch { LLM_ARCH_UNKNOWN, }; 
+static std::map LLM_ARCH_NAMES = { + { LLM_ARCH_LLAMA, "llama" }, + { LLM_ARCH_FALCON, "falcon" }, + { LLM_ARCH_GPT2, "gpt2" }, + { LLM_ARCH_GPTJ, "gptj" }, + { LLM_ARCH_GPTNEOX, "gptneox" }, + { LLM_ARCH_MPT, "mpt" }, +}; + +enum llm_kv { + LLM_KV_GENERAL_ARCHITECTURE, + LLM_KV_GENERAL_QUANTIZATION_VERSION, + LLM_KV_GENERAL_ALIGNMENT, + LLM_KV_GENERAL_NAME, + LLM_KV_GENERAL_AUTHOR, + LLM_KV_GENERAL_URL, + LLM_KV_GENERAL_DESCRIPTION, + LLM_KV_GENERAL_LICENSE, + LLM_KV_GENERAL_SOURCE_URL, + LLM_KV_GENERAL_SOURCE_HF_REPO, + + LLM_KV_CONTEXT_LENGTH, + LLM_KV_EMBEDDING_LENGTH, + LLM_KV_BLOCK_COUNT, + LLM_KV_FEED_FORWARD_LENGTH, + LLM_KV_USE_PARALLEL_RESIDUAL, + LLM_KV_TENSOR_DATA_LAYOUT, + + LLM_KV_ATTENTION_HEAD_COUNT, + LLM_KV_ATTENTION_HEAD_COUNT_KV, + LLM_KV_ATTENTION_MAX_ALIBI_BIAS, + LLM_KV_ATTENTION_CLAMP_KQV, + LLM_KV_ATTENTION_LAYERNORM_EPS, + LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, + + LLM_KV_ROPE_DIMENSION_COUNT, + LLM_KV_ROPE_SCALE_LINEAR, + + LLM_KV_TOKENIZER_MODEL, + LLM_KV_TOKENIZER_LIST, + LLM_KV_TOKENIZER_TOKEN_TYPE, + LLM_KV_TOKENIZER_SCORES, + LLM_KV_TOKENIZER_MERGES, + LLM_KV_TOKENIZER_BOS_ID, + LLM_KV_TOKENIZER_EOS_ID, + LLM_KV_TOKENIZER_UNK_ID, + LLM_KV_TOKENIZER_SEP_ID, + LLM_KV_TOKENIZER_PAD_ID, + LLM_KV_TOKENIZER_HF_JSON, + LLM_KV_TOKENIZER_RWKV, +}; + +static std::map LLM_KV_NAMES = { + { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, + { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, + { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, + { LLM_KV_GENERAL_NAME, "general.name" }, + { LLM_KV_GENERAL_AUTHOR, "general.author" }, + { LLM_KV_GENERAL_URL, "general.url" }, + { LLM_KV_GENERAL_DESCRIPTION, "general.description" }, + { LLM_KV_GENERAL_LICENSE, "general.license" }, + { LLM_KV_GENERAL_SOURCE_URL, "general.source_url" }, + { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source_hf_repo" }, + + { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, + { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, + { LLM_KV_BLOCK_COUNT, "%s.block_count" }, + { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, + { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, + { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, + + { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, + { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, + { LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" }, + { LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" }, + { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, + { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, + + { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, + { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, + + { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, + { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, + { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" }, + { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" }, + { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" }, + { LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" }, + { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, + { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, + { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, + { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, + { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, +}; + +struct LLM_KV { + LLM_KV(llm_arch arch) : arch(arch) {} + + llm_arch arch; + + 
std::string operator()(llm_kv kv) const { + return ::format(LLM_KV_NAMES[kv].c_str(), LLM_ARCH_NAMES[arch].c_str()); + } +}; + enum llm_tensor { LLM_TENSOR_TOKEN_EMBD, LLM_TENSOR_POS_EMBD, @@ -163,16 +268,7 @@ enum llm_tensor { LLM_TENSOR_FFN_NORM, }; -static std::map LLM_ARCH_NAMES = { - { LLM_ARCH_LLAMA, "llama" }, - { LLM_ARCH_FALCON, "falcon" }, - { LLM_ARCH_GPT2, "gpt2" }, - { LLM_ARCH_GPTJ, "gptj" }, - { LLM_ARCH_GPTNEOX, "gptneox" }, - { LLM_ARCH_MPT, "mpt" }, -}; - -static std::map> LLM_TENSOR_NAMES_BASE = { +static std::map> LLM_TENSOR_NAMES = { { LLM_ARCH_LLAMA, { @@ -221,46 +317,49 @@ static llm_arch llm_arch_from_string(const std::string & name) { // helper to handle gguf constants // usage: // -// const auto llm = LLM(); +// const auto tn = LLM_TN(LLM_ARCH_LLAMA); // -// std::string name = LLM(LLM_TENSOR_OUTPUT); -> "output" -// std::string name = LLM(LLM_TENSOR_TOKEN_EMBD, "bias"); -> "token_embd.bias" -// std::string name = LLM(LLM_TENSOR_ATTN_NORM, "weight", 3); -> "blk.3.attn_norm.weight" +// std::string name = tn(LLM_TENSOR_OUTPUT); -> "output" +// std::string name = tn(LLM_TENSOR_TOKEN_EMBD, "bias"); -> "token_embd.bias" +// std::string name = tn(LLM_TENSOR_ATTN_NORM, "weight", 3); -> "blk.3.attn_norm.weight" // -template -struct LLM { +struct LLM_TN { + LLM_TN(llm_arch arch) : arch(arch) {} + + llm_arch arch; + std::string operator()(llm_tensor tensor) const { - return LLM_TENSOR_NAMES_BASE[T].at(tensor); + return LLM_TENSOR_NAMES[arch].at(tensor); } std::string operator()(llm_tensor tensor, const std::string & suffix) const { - return LLM_TENSOR_NAMES_BASE[T].at(tensor) + "." + suffix; + return LLM_TENSOR_NAMES[arch].at(tensor) + "." + suffix; } std::string operator()(llm_tensor tensor, int bid) const { - return format(LLM_TENSOR_NAMES_BASE[T].at(tensor).c_str(), bid); + return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid); } std::string operator()(llm_tensor tensor, const std::string & suffix, int bid) const { - return format(LLM_TENSOR_NAMES_BASE[T].at(tensor).c_str(), bid) + "." + suffix; + return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid) + "." 
+ suffix; } }; - // // gguf helpers // #define GGUF_GET_KEY(ctx, dst, func, type, req, key) \ { \ - const int kid = gguf_find_key(ctx, key); \ + const std::string skey(key); \ + const int kid = gguf_find_key(ctx, skey.c_str()); \ if (kid >= 0) { \ enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \ if (ktype != (type)) { \ - throw std::runtime_error(format("key %s has wrong type: %s", key, gguf_type_name(ktype))); \ + throw std::runtime_error(format("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype))); \ } \ (dst) = func(ctx, kid); \ } else if (req) { \ - throw std::runtime_error(format("key not found in model: %s", key)); \ + throw std::runtime_error(format("key not found in model: %s", skey.c_str())); \ } \ } @@ -1361,26 +1460,117 @@ static const char * llama_model_type_name(e_model type) { } } +static llm_arch llm_load_arch(llama_model_loader & ml) { + struct gguf_context * ctx = ml.ctx_gguf; + + const auto kv = LLM_KV(LLM_ARCH_UNKNOWN); + + std::string arch_name; + GGUF_GET_KEY(ctx, arch_name, gguf_get_val_str, GGUF_TYPE_STRING, true, kv(LLM_KV_GENERAL_ARCHITECTURE)); + + const llm_arch arch = llm_arch_from_string(arch_name); + if (arch == LLM_ARCH_UNKNOWN) { + throw std::runtime_error("unknown model architecture: '" + arch_name + "'"); + } + + return arch; +} + +static void llm_load_hparams( + llm_arch arch, + llama_model_loader & ml, + llama_model & model, + int n_ctx, + float rope_freq_base, + float rope_freq_scale) { + auto & hparams = model.hparams; + + struct gguf_context * ctx = ml.ctx_gguf; + + const auto kv = LLM_KV(arch); + + // get general kv + GGUF_GET_KEY(ctx, model.name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_NAME)); + GGUF_GET_KEY(ctx, model.arch, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_ARCHITECTURE)); + + // get hparams kv + GGUF_GET_KEY(ctx, hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, kv(LLM_KV_TOKENIZER_LIST)); + GGUF_GET_KEY(ctx, hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_CONTEXT_LENGTH)); + GGUF_GET_KEY(ctx, hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH)); + GGUF_GET_KEY(ctx, hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH)); + GGUF_GET_KEY(ctx, hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT)); + GGUF_GET_KEY(ctx, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT)); + GGUF_GET_KEY(ctx, hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ROPE_DIMENSION_COUNT)); + GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS)); + + // n_head_kv is optional, default to n_head + hparams.n_head_kv = hparams.n_head; + GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV)); + + // TODO: manually setting rope scale should override this + // rope_freq_scale (inverse of the kv) is optional + float ropescale = 1.0f; + GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR)); + if (ropescale != 1.0f) { + rope_freq_scale = 1.0f/ropescale; + } + + // TODO: generalize to non-LLaMA models + switch (hparams.n_layer) { + case 26: model.type = e_model::MODEL_3B; break; + case 32: model.type = e_model::MODEL_7B; break; + case 40: model.type = e_model::MODEL_13B; break; + case 60: model.type = e_model::MODEL_30B; break; + case 80: model.type = 
e_model::MODEL_65B; break; + default: + { + if (hparams.n_layer < 32) { + model.type = e_model::MODEL_7B; + } + } break; + } + + // LLaMAv2 + // TODO: probably not needed + { + const auto n_gqa = hparams.n_gqa(); + + if (model.type == e_model::MODEL_65B && n_gqa == 8) { + LLAMA_LOG_WARN("%s: assuming 70B model based on GQA == %d\n", __func__, n_gqa); + model.type = e_model::MODEL_70B; + } + } + + model.ftype = ml.ftype; + + hparams.n_ctx = n_ctx; + hparams.rope_freq_base = rope_freq_base; + hparams.rope_freq_scale = rope_freq_scale; +} + static void llm_load_vocab( + llm_arch arch, llama_model_loader & ml, llama_model & model) { auto & vocab = model.vocab; struct gguf_context * ctx = ml.ctx_gguf; - const int token_idx = gguf_find_key(ctx, "tokenizer.ggml.tokens"); + const auto kv = LLM_KV(arch); + + const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str()); if (token_idx == -1) { throw std::runtime_error("cannot find tokenizer vocab in model file\n"); } - const int score_idx = gguf_find_key(ctx, "tokenizer.ggml.scores"); + const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str()); if (score_idx == -1) { throw std::runtime_error("cannot find tokenizer scores in model file\n"); } const float * scores = (const float * ) gguf_get_arr_data(ctx, score_idx); - const int toktype_idx = gguf_find_key(ctx, "tokenizer.ggml.token_type"); + const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str()); if (toktype_idx == -1) { throw std::runtime_error("cannot find token type list in GGUF file\n"); } @@ -1390,7 +1580,7 @@ static void llm_load_vocab( // determine vocab type { std::string tokenizer_name; - GGUF_GET_KEY(ctx, tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, "tokenizer.ggml.model"); + GGUF_GET_KEY(ctx, tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, kv(LLM_KV_TOKENIZER_MODEL)); if (tokenizer_name == "llama") { vocab.type = LLAMA_VOCAB_TYPE_SPM; @@ -1424,80 +1614,11 @@ static void llm_load_vocab( } // special tokens - GGUF_GET_KEY(ctx, vocab.special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.bos_token_id"); - GGUF_GET_KEY(ctx, vocab.special_eos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.eos_token_id"); - GGUF_GET_KEY(ctx, vocab.special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.unknown_token_id"); - GGUF_GET_KEY(ctx, vocab.special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.separator_token_id"); - GGUF_GET_KEY(ctx, vocab.special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.padding_token_id"); -} - -static void llm_load_hparams( - llama_model_loader & ml, - llama_model & model, - int n_ctx, - float rope_freq_base, - float rope_freq_scale) { - auto & hparams = model.hparams; - - struct gguf_context * ctx = ml.ctx_gguf; - - // get hparams kv - GGUF_GET_KEY(ctx, hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, "tokenizer.ggml.tokens"); - GGUF_GET_KEY(ctx, hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.context_length"); - GGUF_GET_KEY(ctx, hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.embedding_length"); - GGUF_GET_KEY(ctx, hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.feed_forward_length"); - GGUF_GET_KEY(ctx, hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.attention.head_count"); - GGUF_GET_KEY(ctx, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.block_count"); - GGUF_GET_KEY(ctx, hparams.n_rot, 
gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.rope.dimension_count"); - GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, "llama.attention.layer_norm_rms_epsilon"); - - // n_head_kv is optional, default to n_head - hparams.n_head_kv = hparams.n_head; - GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "llama.attention.head_count_kv"); - - // TODO: manually setting rope scale should override this - // rope_freq_scale (inverse of the kv) is optional - float ropescale = 1.0f; - GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, "llama.rope.scale_linear"); - if (ropescale != 1.0f) { - rope_freq_scale = 1.0f/ropescale; - } - - // get general kv - GGUF_GET_KEY(ctx, model.name, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.name"); - GGUF_GET_KEY(ctx, model.arch, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.architecture"); - - switch (hparams.n_layer) { - case 26: model.type = e_model::MODEL_3B; break; - case 32: model.type = e_model::MODEL_7B; break; - case 40: model.type = e_model::MODEL_13B; break; - case 60: model.type = e_model::MODEL_30B; break; - case 80: model.type = e_model::MODEL_65B; break; - default: - { - if (hparams.n_layer < 32) { - model.type = e_model::MODEL_7B; - } - } break; - } - - model.ftype = ml.ftype; - - hparams.n_ctx = n_ctx; - - // LLaMAv2 - // TODO: probably not needed - { - const auto n_gqa = hparams.n_gqa(); - - if (model.type == e_model::MODEL_65B && n_gqa == 8) { - LLAMA_LOG_WARN("%s: assuming 70B model based on GQA == %d\n", __func__, n_gqa); - model.type = e_model::MODEL_70B; - } - } - - hparams.rope_freq_base = rope_freq_base; - hparams.rope_freq_scale = rope_freq_scale; + GGUF_GET_KEY(ctx, vocab.special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_BOS_ID)); + GGUF_GET_KEY(ctx, vocab.special_eos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_EOS_ID)); + GGUF_GET_KEY(ctx, vocab.special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_UNK_ID)); + GGUF_GET_KEY(ctx, vocab.special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_SEP_ID)); + GGUF_GET_KEY(ctx, vocab.special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_PAD_ID)); } static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { @@ -1537,7 +1658,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); } } -static void llama_model_load_internal( +static void llm_load_llama( llama_model_loader & ml, llama_model & model, int n_batch, @@ -1609,9 +1730,9 @@ static void llama_model_load_internal( const uint32_t n_layer = hparams.n_layer; const uint32_t n_vocab = hparams.n_vocab; - const auto llm = LLM(); + const auto tn = LLM_TN(LLM_ARCH_LLAMA); - model.tok_embeddings = ml.create_tensor(ctx, llm(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); // "output" tensor { @@ -1632,8 +1753,8 @@ static void llama_model_load_internal( backend_output = GGML_BACKEND_CPU; } - model.norm = ml.create_tensor(ctx, llm(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); - model.output = ml.create_tensor(ctx, llm(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 
backend_output); + model.norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); if (backend_norm == GGML_BACKEND_GPU) { vram_weights += ggml_nbytes(model.norm); } @@ -1652,18 +1773,18 @@ static void llama_model_load_internal( const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT auto & layer = model.layers[i]; - layer.attention_norm = ml.create_tensor(ctx, llm(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + layer.attention_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); - layer.wq = ml.create_tensor(ctx, llm(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, backend_split); - layer.wk = ml.create_tensor(ctx, llm(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, backend_split); - layer.wv = ml.create_tensor(ctx, llm(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split); - layer.wo = ml.create_tensor(ctx, llm(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + layer.wq = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, backend_split); + layer.wk = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, backend_split); + layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split); + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); - layer.ffn_norm = ml.create_tensor(ctx, llm(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); - layer.w1 = ml.create_tensor(ctx, llm(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); - layer.w2 = ml.create_tensor(ctx, llm(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); - layer.w3 = ml.create_tensor(ctx, llm(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.w1 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); + layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); if (backend == GGML_BACKEND_GPU) { vram_weights += @@ -1783,16 +1904,10 @@ static bool llama_model_load( try { std::unique_ptr ml(new llama_model_loader(fname, use_mmap)); - std::string arch_name; - GGUF_GET_KEY(ml->ctx_gguf, arch_name, gguf_get_val_str, GGUF_TYPE_STRING, true, "general.architecture"); + const llm_arch arch = llm_load_arch(*ml); - const llm_arch arch = llm_arch_from_string(arch_name); - if (arch == LLM_ARCH_UNKNOWN) { - throw std::runtime_error("unknown model architecture: '" + arch_name + "'"); - } - - llm_load_hparams(*ml, model, n_ctx, rope_freq_base, rope_freq_scale); - llm_load_vocab(*ml, model); + llm_load_hparams(arch, *ml, model, n_ctx, rope_freq_base, rope_freq_scale); + llm_load_vocab (arch, *ml, model); llm_load_print_meta(*ml, model); @@ -1805,14 +1920,25 @@ static bool llama_model_load( return true; } - llama_model_load_internal(*ml, model, n_batch, n_gpu_layers, - main_gpu, tensor_split, mul_mat_q, low_vram, memory_type, - use_mlock, progress_callback, progress_callback_user_data); - return true; + switch (arch) { + case LLM_ARCH_LLAMA: + { + llm_load_llama(*ml, model, n_batch, 
n_gpu_layers, + main_gpu, tensor_split, mul_mat_q, low_vram, memory_type, + use_mlock, progress_callback, progress_callback_user_data); + } break; + case LLM_ARCH_FALCON: + { + } + default: + throw std::runtime_error("unsupported architecture"); + }; } catch (const std::exception & err) { LLAMA_LOG_ERROR("error loading model: %s\n", err.what()); return false; } + + return true; } static struct ggml_cgraph * llama_build_graph( @@ -3674,9 +3800,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s new_type = quantized_type; #ifdef GGML_USE_K_QUANTS // TODO: avoid hardcoded tensor names - use the TN_* constants - const auto llm = LLM(); + const auto tn = LLM_TN(LLM_ARCH_LLAMA); - if (name == llm(LLM_TENSOR_OUTPUT, "weight")) { + if (name == tn(LLM_TENSOR_OUTPUT, "weight")) { int nx = tensor->ne[0]; int ny = tensor->ne[1]; if (nx % QK_K == 0 && ny % QK_K == 0) { @@ -3712,10 +3838,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } } if (convert_incompatible_tensor) { - if (name == llm(LLM_TENSOR_OUTPUT, "weight")) { + if (name == tn(LLM_TENSOR_OUTPUT, "weight")) { new_type = GGML_TYPE_F16; //fall back to F16 instead of just failing. LLAMA_LOG_WARN("F16 will be used for this tensor instead.\n"); - } else if (name == llm(LLM_TENSOR_TOKEN_EMBD, "weight")) { + } else if (name == tn(LLM_TENSOR_TOKEN_EMBD, "weight")) { new_type = GGML_TYPE_Q4_0; //fall back to Q4_0 instead of just failing. LLAMA_LOG_WARN("Q4_0 will be used for this tensor instead.\n"); } else { From 9f28f73785ecf62fe2625809b09348f8c2dd7625 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 22 Aug 2023 20:34:17 +0300 Subject: [PATCH 05/36] llm : read arch-specific KVs --- llama.cpp | 41 +++++++++++++++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/llama.cpp b/llama.cpp index b2c9b339602b0..5c0bf619053ae 100644 --- a/llama.cpp +++ b/llama.cpp @@ -344,6 +344,7 @@ struct LLM_TN { return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid) + "." 
+ suffix; } }; + // // gguf helpers // @@ -1497,12 +1498,12 @@ static void llm_load_hparams( int n_ctx, float rope_freq_base, float rope_freq_scale) { - auto & hparams = model.hparams; - struct gguf_context * ctx = ml.ctx_gguf; const auto kv = LLM_KV(arch); + auto & hparams = model.hparams; + // get general kv GGUF_GET_KEY(ctx, model.name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_NAME)); GGUF_GET_KEY(ctx, model.arch, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_ARCHITECTURE)); @@ -1514,8 +1515,6 @@ static void llm_load_hparams( GGUF_GET_KEY(ctx, hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH)); GGUF_GET_KEY(ctx, hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT)); GGUF_GET_KEY(ctx, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT)); - GGUF_GET_KEY(ctx, hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ROPE_DIMENSION_COUNT)); - GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS)); // n_head_kv is optional, default to n_head hparams.n_head_kv = hparams.n_head; @@ -1523,12 +1522,37 @@ static void llm_load_hparams( // TODO: manually setting rope scale should override this // rope_freq_scale (inverse of the kv) is optional - float ropescale = 1.0f; - GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR)); - if (ropescale != 1.0f) { - rope_freq_scale = 1.0f/ropescale; + { + float ropescale = 1.0f; + GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR)); + if (ropescale != 1.0f) { + rope_freq_scale = 1.0f/ropescale; + } + } + + // sanity check for n_rot (optional) + { + hparams.n_rot = hparams.n_embd / hparams.n_head; + + GGUF_GET_KEY(ctx, hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT)); + + if (hparams.n_rot != hparams.n_embd / hparams.n_head) { + throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head)); + } } + // arch-specific KVs + switch (arch) { + case LLM_ARCH_LLAMA: + { + GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS)); + } break; + case LLM_ARCH_FALCON: + { + } break; + default: (void)0; + }; + // TODO: generalize to non-LLaMA models switch (hparams.n_layer) { case 26: model.type = e_model::MODEL_3B; break; @@ -1594,6 +1618,7 @@ static void llm_load_vocab( // determine vocab type { std::string tokenizer_name; + GGUF_GET_KEY(ctx, tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, kv(LLM_KV_TOKENIZER_MODEL)); if (tokenizer_name == "llama") { From d1b3b95dc4d21138f6941350bdd2a859ac157646 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 22 Aug 2023 20:55:05 +0300 Subject: [PATCH 06/36] convert : add dummy scores + types --- convert-falcon-hf-to-gguf.py | 6 ++++++ llama.cpp | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/convert-falcon-hf-to-gguf.py b/convert-falcon-hf-to-gguf.py index b3e190a0fd83c..32420796b3a21 100644 --- a/convert-falcon-hf-to-gguf.py +++ b/convert-falcon-hf-to-gguf.py @@ -109,6 +109,8 @@ def count_model_parts(dir_model: str) -> int: print("gguf: get tokenizer metadata") tokens: List[str] = [] +scores: List[float] = [] +toktypes: List[int] = [] merges: List[str] = [] @@ -152,8 +154,12 @@ def count_model_parts(dir_model: str) -> int: 
text = bytearray(pad_token) tokens.append(text) + scores.append(0.0) # dymmy + toktypes.append(gguf.TokenType.NORMAL) # dummy gguf_writer.add_token_list(tokens) + gguf_writer.add_token_scores(scores) + gguf_writer.add_token_types(toktypes) if "added_tokens" in tokenizer_json and Path(dir_model + "/tokenizer_config.json").is_file(): print("gguf: get special token ids") diff --git a/llama.cpp b/llama.cpp index 5c0bf619053ae..61e989810bef8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1968,7 +1968,7 @@ static bool llama_model_load( } break; case LLM_ARCH_FALCON: { - } + } break; default: throw std::runtime_error("unsupported architecture"); }; From 2f3c80a8456521a8857c0f9a05ba6ba3437dddf0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 22 Aug 2023 21:42:12 +0300 Subject: [PATCH 07/36] falcon : load tensor data (CPU only) --- convert-falcon-hf-to-gguf.py | 5 +- llama.cpp | 185 +++++++++++++++++++++++------------ 2 files changed, 123 insertions(+), 67 deletions(-) diff --git a/convert-falcon-hf-to-gguf.py b/convert-falcon-hf-to-gguf.py index 32420796b3a21..1a57b1b1ee319 100644 --- a/convert-falcon-hf-to-gguf.py +++ b/convert-falcon-hf-to-gguf.py @@ -200,8 +200,9 @@ def count_model_parts(dir_model: str) -> int: tensor_map = gguf.get_tensor_name_map(ARCH,block_count) # params for qkv transform -n_head = hparams["n_head"] -n_head_kv = hparams["n_head_kv"] if "n_head_kv" in hparams else 1 +n_head = hparams["n_head"] +n_head_kv = hparams["n_head_kv"] if "n_head_kv" in hparams else n_head + head_dim = hparams["hidden_size"] // n_head # tensor info diff --git a/llama.cpp b/llama.cpp index 61e989810bef8..c1b6576da9112 100644 --- a/llama.cpp +++ b/llama.cpp @@ -863,21 +863,25 @@ struct llama_hparams { struct llama_layer { // normalization - struct ggml_tensor * attention_norm; + struct ggml_tensor * attn_norm; + struct ggml_tensor * attn_norm_b; + struct ggml_tensor * attn_norm_2; + struct ggml_tensor * attn_norm_2_b; // attention struct ggml_tensor * wq; struct ggml_tensor * wk; struct ggml_tensor * wv; struct ggml_tensor * wo; + struct ggml_tensor * wqkv; // normalization struct ggml_tensor * ffn_norm; // ff - struct ggml_tensor * w1; - struct ggml_tensor * w2; - struct ggml_tensor * w3; + struct ggml_tensor * w1; // ffn_gate + struct ggml_tensor * w2; // ffn_down + struct ggml_tensor * w3; // ffn_up }; struct llama_kv_cache { @@ -944,10 +948,12 @@ struct llama_model { struct ggml_tensor * tok_embeddings; - struct ggml_tensor * norm; + struct ggml_tensor * output_norm; + struct ggml_tensor * output_norm_b; struct ggml_tensor * output; std::vector layers; + int n_gpu_layers; // context @@ -1117,11 +1123,11 @@ static const char * llama_file_version_name(llama_file_version version) { return "unknown"; } -static std::string llama_format_tensor_shape(const std::vector & ne) { +static std::string llama_format_tensor_shape(const std::vector & ne) { char buf[256]; - snprintf(buf, sizeof(buf), "%5u", ne.at(0)); + snprintf(buf, sizeof(buf), "%5" PRId64, ne.at(0)); for (size_t i = 1; i < ne.size(); i++) { - snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5u", ne.at(i)); + snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, ne.at(i)); } return buf; } @@ -1301,7 +1307,7 @@ struct llama_model_loader { return tensor; } - struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector & ne, ggml_backend backend) { + struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector & ne, ggml_backend 
backend) { struct ggml_tensor * cur = ggml_get_tensor(ctx_meta, name.c_str()); if (cur == NULL) { @@ -1698,6 +1704,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { } static void llm_load_llama( + llm_arch arch, llama_model_loader & ml, llama_model & model, int n_batch, @@ -1764,74 +1771,117 @@ static void llm_load_llama( // prepare memory for the weights size_t vram_weights = 0; { - const uint32_t n_embd = hparams.n_embd; - const uint32_t n_embd_gqa = hparams.n_embd_gqa(); - const uint32_t n_layer = hparams.n_layer; - const uint32_t n_vocab = hparams.n_vocab; + const int64_t n_embd = hparams.n_embd; + const int64_t n_embd_gqa = hparams.n_embd_gqa(); + const int64_t n_layer = hparams.n_layer; + const int64_t n_vocab = hparams.n_vocab; - const auto tn = LLM_TN(LLM_ARCH_LLAMA); + const auto tn = LLM_TN(arch); - model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + switch (arch) { + case LLM_ARCH_LLAMA: + { + model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); - // "output" tensor - { - ggml_backend backend_norm; - ggml_backend backend_output; - if (n_gpu_layers > int(n_layer)) { // NOLINT - // norm is not performance relevant on its own but keeping it in VRAM reduces data copying - // on Windows however this is detrimental unless everything is on the GPU + // output + { + ggml_backend backend_norm; + ggml_backend backend_output; + + if (n_gpu_layers > int(n_layer)) { + // norm is not performance relevant on its own but keeping it in VRAM reduces data copying + // on Windows however this is detrimental unless everything is on the GPU #ifndef _WIN32 - backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; #else - backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; #endif // _WIN32 - backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; - } else { - backend_norm = GGML_BACKEND_CPU; - backend_output = GGML_BACKEND_CPU; - } + backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; + } else { + backend_norm = GGML_BACKEND_CPU; + backend_output = GGML_BACKEND_CPU; + } - model.norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); - model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); - if (backend_norm == GGML_BACKEND_GPU) { - vram_weights += ggml_nbytes(model.norm); - } - if (backend_output == GGML_BACKEND_GPU_SPLIT) { - vram_weights += ggml_nbytes(model.output); - } - } + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); - const uint32_t n_ff = hparams.n_ff; + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.output_norm); + } + if (backend_output == GGML_BACKEND_GPU_SPLIT) { + vram_weights += ggml_nbytes(model.output); + } + } - const int i_gpu_start = n_layer - n_gpu_layers; + const uint32_t n_ff = hparams.n_ff; - model.layers.resize(n_layer); - for (uint32_t i = 0; i < n_layer; ++i) { - const ggml_backend backend = int(i) < i_gpu_start ? 
GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT - const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT + const int i_gpu_start = n_layer - n_gpu_layers; - auto & layer = model.layers[i]; - layer.attention_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + model.layers.resize(n_layer); - layer.wq = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, backend_split); - layer.wk = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, backend_split); - layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split); - layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + for (uint32_t i = 0; i < n_layer; ++i) { + const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT + const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT - layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + auto & layer = model.layers[i]; - layer.w1 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); - layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); - layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); - if (backend == GGML_BACKEND_GPU) { - vram_weights += - ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + - ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + - ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3); - } - } + layer.wq = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, backend_split); + layer.wk = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, backend_split); + layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split); + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + + layer.w1 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); + layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + + if (backend == GGML_BACKEND_GPU) { + vram_weights += + ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + + ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + + ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3); + } + } + } break; + case LLM_ARCH_FALCON: + { + // TODO: CPU-only for now + + model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + + // output + { + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, GGML_BACKEND_CPU); + model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, GGML_BACKEND_CPU); + model.output = ml.create_tensor(ctx, 
tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + } + + const uint32_t n_ff = hparams.n_ff; + + model.layers.resize(n_layer); + + for (uint32_t i = 0; i < n_layer; ++i) { + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, GGML_BACKEND_CPU); + layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, GGML_BACKEND_CPU); + layer.attn_norm_2 = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, GGML_BACKEND_CPU); + layer.attn_norm_2_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, GGML_BACKEND_CPU); + + layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, GGML_BACKEND_CPU); + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, GGML_BACKEND_CPU); + + layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, GGML_BACKEND_CPU); + layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, GGML_BACKEND_CPU); + } + } break; + default: + throw std::runtime_error("unknown architecture"); + }; } ml.done_getting_tensors(); @@ -1962,12 +2012,17 @@ static bool llama_model_load( switch (arch) { case LLM_ARCH_LLAMA: { - llm_load_llama(*ml, model, n_batch, n_gpu_layers, + llm_load_llama( + arch, *ml, model, n_batch, n_gpu_layers, main_gpu, tensor_split, mul_mat_q, low_vram, memory_type, use_mlock, progress_callback, progress_callback_user_data); } break; case LLM_ARCH_FALCON: { + llm_load_llama( + arch, *ml, model, n_batch, n_gpu_layers, + main_gpu, tensor_split, mul_mat_q, low_vram, memory_type, + use_mlock, progress_callback, progress_callback_user_data); } break; default: throw std::runtime_error("unsupported architecture"); @@ -2105,8 +2160,8 @@ static struct ggml_cgraph * llama_build_graph( offload_func(cur); ggml_set_name(cur, "rms_norm_0"); - // cur = cur*attention_norm(broadcasted) - cur = ggml_mul(ctx0, cur, model.layers[il].attention_norm); + // cur = cur*attn_norm(broadcasted) + cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm); offload_func(cur); ggml_set_name(cur, "attention_norm_0"); } @@ -2297,7 +2352,7 @@ static struct ggml_cgraph * llama_build_graph( ggml_set_name(cur, "rms_norm_2"); // cur = cur*norm(broadcasted) - cur = ggml_mul(ctx0, cur, model.norm); + cur = ggml_mul(ctx0, cur, model.output_norm); // offload_func_nr(cur); // TODO CPU + GPU mirrored backend ggml_set_name(cur, "result_norm"); } From 5c5413dc145f945d7c7e88b0c5ec71625f42e4bd Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 22 Aug 2023 21:53:36 +0300 Subject: [PATCH 08/36] llama : fix loading progress bar --- llama.cpp | 59 ++++++++++++++++++++++--------------------------------- 1 file changed, 23 insertions(+), 36 deletions(-) diff --git a/llama.cpp b/llama.cpp index c1b6576da9112..d4d5984f8cd1a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1703,7 +1703,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); } } -static void llm_load_llama( +static void llm_load_tensors( llm_arch arch, llama_model_loader & ml, llama_model & model, @@ -2009,24 +2009,10 @@ static bool llama_model_load( return true; } - switch (arch) { - case LLM_ARCH_LLAMA: - { - llm_load_llama( - arch, *ml, model, n_batch, n_gpu_layers, - main_gpu, 
tensor_split, mul_mat_q, low_vram, memory_type, - use_mlock, progress_callback, progress_callback_user_data); - } break; - case LLM_ARCH_FALCON: - { - llm_load_llama( - arch, *ml, model, n_batch, n_gpu_layers, - main_gpu, tensor_split, mul_mat_q, low_vram, memory_type, - use_mlock, progress_callback, progress_callback_user_data); - } break; - default: - throw std::runtime_error("unsupported architecture"); - }; + llm_load_tensors( + arch, *ml, model, n_batch, n_gpu_layers, + main_gpu, tensor_split, mul_mat_q, low_vram, memory_type, + use_mlock, progress_callback, progress_callback_user_data); } catch (const std::exception & err) { LLAMA_LOG_ERROR("error loading model: %s\n", err.what()); return false; @@ -4446,6 +4432,22 @@ struct llama_model * llama_load_model_from_file( ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; + unsigned cur_percentage = 0; + if (params.progress_callback == NULL) { + params.progress_callback_user_data = &cur_percentage; + params.progress_callback = [](float progress, void * ctx) { + unsigned * cur_percentage_p = (unsigned *) ctx; + unsigned percentage = (unsigned) (100 * progress); + while (percentage > *cur_percentage_p) { + *cur_percentage_p = percentage; + LLAMA_LOG_INFO("."); + if (percentage >= 100) { + LLAMA_LOG_INFO("\n"); + } + } + }; + } + if (!llama_model_load(path_model, *model, params.n_ctx, params.n_batch, params.n_gpu_layers, params.main_gpu, params.tensor_split, params.mul_mat_q, params.rope_freq_base, params.rope_freq_scale, params.low_vram, memory_type, params.use_mmap, params.use_mlock, params.vocab_only, @@ -4476,22 +4478,6 @@ struct llama_context * llama_new_context_with_model( params.seed = time(NULL); } - unsigned cur_percentage = 0; - if (params.progress_callback == NULL) { - params.progress_callback_user_data = &cur_percentage; - params.progress_callback = [](float progress, void * ctx) { - unsigned * cur_percentage_p = (unsigned *) ctx; - unsigned percentage = (unsigned) (100 * progress); - while (percentage > *cur_percentage_p) { - *cur_percentage_p = percentage; - LLAMA_LOG_INFO("."); - if (percentage >= 100) { - LLAMA_LOG_INFO("\n"); - } - } - }; - } - ctx->rng = std::mt19937(params.seed); ctx->logits_all = params.logits_all; @@ -4629,13 +4615,14 @@ struct llama_context * llama_new_context_with_model( struct llama_context * llama_init_from_file( const char * path_model, struct llama_context_params params) { - struct llama_model * model = llama_load_model_from_file(path_model, params); if (!model) { return nullptr; } + struct llama_context * ctx = llama_new_context_with_model(model, params); ctx->model_owner = true; + return ctx; } From 085228e1f5c8886d6a09d78235f849a3c5ef8de4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 22 Aug 2023 22:09:56 +0300 Subject: [PATCH 09/36] llama : add arch member to llama_model --- llama.cpp | 35 ++++++++++++++--------------------- 1 file changed, 14 insertions(+), 21 deletions(-) diff --git a/llama.cpp b/llama.cpp index d4d5984f8cd1a..e19f46a88817f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -938,10 +938,10 @@ struct llama_vocab { struct llama_model { e_model type = MODEL_UNKNOWN; + llm_arch arch = LLM_ARCH_UNKNOWN; llama_ftype ftype = LLAMA_FTYPE_ALL_F32; std::string name = "n/a"; - std::string arch = "n/a"; llama_hparams hparams; llama_vocab vocab; @@ -1481,7 +1481,7 @@ static const char * llama_model_type_name(e_model type) { } } -static llm_arch llm_load_arch(llama_model_loader & ml) { +static void llm_load_arch(llama_model_loader & ml, llama_model & model) { struct 
gguf_context * ctx = ml.ctx_gguf; const auto kv = LLM_KV(LLM_ARCH_UNKNOWN); @@ -1489,16 +1489,13 @@ static llm_arch llm_load_arch(llama_model_loader & ml) { std::string arch_name; GGUF_GET_KEY(ctx, arch_name, gguf_get_val_str, GGUF_TYPE_STRING, true, kv(LLM_KV_GENERAL_ARCHITECTURE)); - const llm_arch arch = llm_arch_from_string(arch_name); - if (arch == LLM_ARCH_UNKNOWN) { + model.arch = llm_arch_from_string(arch_name); + if (model.arch == LLM_ARCH_UNKNOWN) { throw std::runtime_error("unknown model architecture: '" + arch_name + "'"); } - - return arch; } static void llm_load_hparams( - llm_arch arch, llama_model_loader & ml, llama_model & model, int n_ctx, @@ -1506,13 +1503,12 @@ static void llm_load_hparams( float rope_freq_scale) { struct gguf_context * ctx = ml.ctx_gguf; - const auto kv = LLM_KV(arch); + const auto kv = LLM_KV(model.arch); auto & hparams = model.hparams; // get general kv GGUF_GET_KEY(ctx, model.name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_NAME)); - GGUF_GET_KEY(ctx, model.arch, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_ARCHITECTURE)); // get hparams kv GGUF_GET_KEY(ctx, hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, kv(LLM_KV_TOKENIZER_LIST)); @@ -1548,7 +1544,7 @@ static void llm_load_hparams( } // arch-specific KVs - switch (arch) { + switch (model.arch) { case LLM_ARCH_LLAMA: { GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS)); @@ -1593,14 +1589,13 @@ static void llm_load_hparams( } static void llm_load_vocab( - llm_arch arch, llama_model_loader & ml, llama_model & model) { auto & vocab = model.vocab; struct gguf_context * ctx = ml.ctx_gguf; - const auto kv = LLM_KV(arch); + const auto kv = LLM_KV(model.arch); const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str()); if (token_idx == -1) { @@ -1672,7 +1667,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { // hparams LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver)); - LLAMA_LOG_INFO("%s: arch = %s\n", __func__, model.arch.c_str()); + LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str()); LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? 
"SPM" : "BPE"); // TODO: fix LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab); LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); @@ -1704,7 +1699,6 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { } static void llm_load_tensors( - llm_arch arch, llama_model_loader & ml, llama_model & model, int n_batch, @@ -1776,9 +1770,9 @@ static void llm_load_tensors( const int64_t n_layer = hparams.n_layer; const int64_t n_vocab = hparams.n_vocab; - const auto tn = LLM_TN(arch); + const auto tn = LLM_TN(model.arch); - switch (arch) { + switch (model.arch) { case LLM_ARCH_LLAMA: { model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); @@ -1993,10 +1987,9 @@ static bool llama_model_load( try { std::unique_ptr ml(new llama_model_loader(fname, use_mmap)); - const llm_arch arch = llm_load_arch(*ml); - - llm_load_hparams(arch, *ml, model, n_ctx, rope_freq_base, rope_freq_scale); - llm_load_vocab (arch, *ml, model); + llm_load_arch (*ml, model); + llm_load_hparams(*ml, model, n_ctx, rope_freq_base, rope_freq_scale); + llm_load_vocab (*ml, model); llm_load_print_meta(*ml, model); @@ -2010,7 +2003,7 @@ static bool llama_model_load( } llm_load_tensors( - arch, *ml, model, n_batch, n_gpu_layers, + *ml, model, n_batch, n_gpu_layers, main_gpu, tensor_split, mul_mat_q, low_vram, memory_type, use_mlock, progress_callback, progress_callback_user_data); } catch (const std::exception & err) { From 3c7c325b9867a30637529b0328fbc73b4e527004 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 22 Aug 2023 22:31:39 +0300 Subject: [PATCH 10/36] falcon : CPU inference working --- llama.cpp | 299 +++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 286 insertions(+), 13 deletions(-) diff --git a/llama.cpp b/llama.cpp index e19f46a88817f..c942b0727e407 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1031,8 +1031,6 @@ struct llama_context { // key + value cache for the self attention struct llama_kv_cache kv_self; - size_t mem_per_token = 0; - // decode output (2-dimensional array: [n_tokens][n_vocab]) std::vector logits; bool logits_all = false; @@ -2014,7 +2012,7 @@ static bool llama_model_load( return true; } -static struct ggml_cgraph * llama_build_graph( +static struct ggml_cgraph * llm_build_llama( llama_context & lctx, const llama_token * tokens, const float * embd, @@ -2048,8 +2046,7 @@ static struct ggml_cgraph * llama_build_graph( const int n_gpu_layers = model.n_gpu_layers; - auto & mem_per_token = lctx.mem_per_token; - auto & buf_compute = lctx.buf_compute; + auto & buf_compute = lctx.buf_compute; struct ggml_init_params params = { /*.mem_size =*/ buf_compute.size, @@ -2340,20 +2337,296 @@ static struct ggml_cgraph * llama_build_graph( cur = ggml_mul_mat(ctx0, model.output, cur); ggml_set_name(cur, "result_output"); - // logits -> probs - //cur = ggml_soft_max_inplace(ctx0, cur); - ggml_build_forward_expand(gf, cur); - if (mem_per_token == 0) { - mem_per_token = ggml_used_mem(ctx0)/N; + ggml_free(ctx0); + + return gf; +} + +static struct ggml_cgraph * llm_build_falcon( + llama_context & lctx, + const llama_token * tokens, + const float * embd, + int n_tokens, + int n_past) { + + GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT + + const int N = n_tokens; + + const auto & model = lctx.model; + const auto & hparams = model.hparams; + + const auto & kv_self = lctx.kv_self; + + GGML_ASSERT(!!kv_self.ctx); + + const int64_t n_embd = hparams.n_embd; + const 
int64_t n_layer = hparams.n_layer; + const int64_t n_ctx = hparams.n_ctx; + const int64_t n_head = hparams.n_head; + const int64_t n_head_kv = hparams.n_head_kv; + const int64_t n_embd_head = hparams.n_embd_head(); + //const int64_t n_embd_gqa = hparams.n_embd_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_rot); + + auto & buf_compute = lctx.buf_compute; + + struct ggml_init_params params = { + /*.mem_size =*/ buf_compute.size, + /*.mem_buffer =*/ buf_compute.data, + /*.no_alloc =*/ false, + }; + + params.no_alloc = true; + + struct ggml_context * ctx0 = ggml_init(params); + + ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + if (tokens) { + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + + ggml_allocr_alloc(lctx.alloc, inp_tokens); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens)); + } + ggml_set_name(inp_tokens, "inp_tokens"); + + inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); + } else { +#ifdef GGML_USE_MPI + GGML_ASSERT(false && "not implemented"); +#endif + + inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N); + + ggml_allocr_alloc(lctx.alloc, inpL); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL)); + } + } + + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_allocr_alloc(lctx.alloc, KQ_scale); + if (!ggml_allocr_is_measure(lctx.alloc)) { + ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); + } + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * cur; + struct ggml_tensor * layernorm_output; + + // self-attention + { + layernorm_output = ggml_norm(ctx0, inpL); + + layernorm_output = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, model.layers[il].attn_norm, layernorm_output), + layernorm_output), + ggml_repeat(ctx0, model.layers[il].attn_norm_b, layernorm_output)); + + if ( hparams.n_head_kv == 8 ) { // Falcon-40B + cur = ggml_norm(ctx0, inpL); + + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, model.layers[il].attn_norm_2, cur), + cur), + ggml_repeat(ctx0, model.layers[il].attn_norm_2_b, cur)); + } + else { // Falcon 7B + cur = layernorm_output; + } + + // compute QKV + + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + + // Note that the strides for Kcur, Vcur are set up so that the + // resulting views are misaligned with the tensor's storage + // (by applying the K/V offset we shift the tensor's original + // view to stick out behind the viewed QKV tensor's allocated + // memory, so to say). This is ok because no actual accesses + // happen to that out-of-range memory, but it can require some + // trickery when trying to accurately dump these views for + // debugging. 
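+ // note (assumed clarification, not in the original commit): cur holds the fused QKV projection;
+ // per token it packs n_head query heads, then n_head_kv key heads, then n_head_kv value heads,
+ // each n_embd_head elements wide, and the Qcur/Kcur/Vcur views below use matching byte offsets and row strides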
+ + struct ggml_tensor * Qcur = ggml_view_3d( + ctx0, cur, n_embd_head, n_head, N, + n_embd_head * ggml_type_size(GGML_TYPE_F32), + n_embd_head * (n_head + 2 * n_head_kv) * ggml_type_size(GGML_TYPE_F32), + 0); + + struct ggml_tensor * Kcur = ggml_view_3d( + ctx0, cur, n_embd_head, n_head_kv, N, + n_embd_head * ggml_type_size(GGML_TYPE_F32), + n_embd_head * (n_head + 2 * n_head_kv) * ggml_type_size(GGML_TYPE_F32), + n_embd_head * n_head * ggml_type_size(GGML_TYPE_F32)); + + struct ggml_tensor * Vcur = ggml_view_3d( + ctx0, cur, n_embd_head, n_head_kv, N, + n_embd_head * ggml_type_size(GGML_TYPE_F32), + n_embd_head * (n_head + 2 * n_head_kv) * ggml_type_size(GGML_TYPE_F32), + n_embd_head * (n_head + n_head_kv) * ggml_type_size(GGML_TYPE_F32)); + + // using mode = 2 for neox mode + Qcur = ggml_rope_inplace(ctx0, Qcur, n_past, n_embd_head, 2, 0); + Kcur = ggml_rope_inplace(ctx0, Kcur, n_past, n_embd_head, 2, 0); + + // store key and value to memory + { + struct ggml_tensor* k = ggml_view_1d( + ctx0, kv_self.k, N * n_head_kv * n_embd_head, + (ggml_element_size(kv_self.k) * n_head_kv * n_embd_head) * + (il * n_ctx + n_past)); + struct ggml_tensor* v = ggml_view_1d( + ctx0, kv_self.v, N * n_head_kv * n_embd_head, + (ggml_element_size(kv_self.v) * n_head_kv * n_embd_head) * + (il * n_ctx + n_past)); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); + } + + struct ggml_tensor * K = ggml_permute( + ctx0, + ggml_reshape_3d( + ctx0, + ggml_view_1d(ctx0, kv_self.k, (n_past + N) * n_head_kv * n_embd_head, + il * n_ctx * + ggml_element_size(kv_self.k) * + n_head_kv * + n_embd_head), + n_embd_head, n_head_kv, n_past + N), + 0, 2, 1, 3); + + // K * Q + +// K = ggml_cont(ctx0, ggml_repeat2(ctx0, K, repeat_dummy)); + + struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + // KQ_scaled = KQ / sqrt(n_embd/n_head) + struct ggml_tensor * KQ_scaled = + ggml_scale_inplace(ctx0, + KQ, + ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd_head))) + ); + + // KQ_masked = mask_past(KQ_scaled) + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + + // KQ = soft_max(KQ_masked) + struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + + // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() + struct ggml_tensor* V = ggml_permute( + ctx0, + ggml_reshape_3d( + ctx0, + ggml_view_1d(ctx0, kv_self.v, (n_past + N) * n_head_kv * n_embd_head, + il * n_ctx * + ggml_element_size(kv_self.v) * + n_head_kv * + n_embd_head), + n_embd_head, n_head_kv, n_past + N), + 0, 2, 1, 3); + +// V = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_repeat2(ctx0, V, repeat_dummy))); + V = ggml_cont(ctx0, ggml_transpose(ctx0, V)); + + // KQV = transpose(V) * KQ_soft_max + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + // cur = KQV_merged.contiguous().view(n_embd, N) + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + + // projection + { + cur = ggml_mul_mat(ctx0, + model.layers[il].wo, + cur); + } + } + + struct ggml_tensor* inpFF = layernorm_output; + struct ggml_tensor* attn_out = ggml_cpy( + ctx0, cur, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + + { + cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF); + cur = ggml_gelu(ctx0, cur); + cur = 
ggml_mul_mat(ctx0, model.layers[il].w2, cur); + } + + cur = ggml_add(ctx0, cur, attn_out); + cur = ggml_add(ctx0, cur, inpL); + // input for next layer + inpL = cur; + } + + // norm + { + cur = ggml_norm(ctx0, inpL); + + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, model.output_norm, cur), + cur), + ggml_repeat(ctx0, model.output_norm_b, cur)); + ggml_set_name(cur, "result_norm"); } + cur = ggml_mul_mat(ctx0, model.output, cur); + ggml_set_name(cur, "result_output"); + + ggml_build_forward_expand(gf, cur); + ggml_free(ctx0); return gf; } +static struct ggml_cgraph * llama_build_graph( + llama_context & lctx, + const llama_token * tokens, + const float * embd, + int n_tokens, + int n_past) { + const auto & model = lctx.model; + + struct ggml_cgraph * result = NULL; + + switch (model.arch) { + case LLM_ARCH_LLAMA: + { + result = llm_build_llama(lctx, tokens, embd, n_tokens, n_past); + } break; + case LLM_ARCH_FALCON: + { + result = llm_build_falcon(lctx, tokens, embd, n_tokens, n_past); + } break; + default: + GGML_ASSERT(false); + }; + + return result; +} + // evaluate the transformer // // - lctx: llama context @@ -2427,11 +2700,11 @@ static bool llama_eval_internal( // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads; - struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; + struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; - GGML_ASSERT(strcmp(res->name, "result_output") == 0); - GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); + GGML_ASSERT(strcmp(res->name, "result_output") == 0); + GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); #if GGML_USE_MPI const int64_t n_layer = hparams.n_layer; From 2d58444dae1545b96de9929366c7dffe09605c5e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 22 Aug 2023 22:52:14 +0300 Subject: [PATCH 11/36] falcon : support non-40B models --- convert-falcon-hf-to-gguf.py | 7 +++++-- llama.cpp | 40 +++++++++++++++++++----------------- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/convert-falcon-hf-to-gguf.py b/convert-falcon-hf-to-gguf.py index 1a57b1b1ee319..8d879c499d18d 100644 --- a/convert-falcon-hf-to-gguf.py +++ b/convert-falcon-hf-to-gguf.py @@ -101,7 +101,10 @@ def count_model_parts(dir_model: str) -> int: gguf_writer.add_feed_forward_length(4 * hparams["hidden_size"]) gguf_writer.add_block_count(block_count) gguf_writer.add_head_count(hparams["n_head"]) -if "n_head_kv" in hparams: gguf_writer.add_head_count_kv(hparams["n_head_kv"]) +if "n_head_kv" in hparams: + gguf_writer.add_head_count_kv(hparams["n_head_kv"]) +else: + gguf_writer.add_head_count_kv(1) gguf_writer.add_layer_norm_eps(hparams["layer_norm_epsilon"]) # TOKENIZATION @@ -201,7 +204,7 @@ def count_model_parts(dir_model: str) -> int: # params for qkv transform n_head = hparams["n_head"] -n_head_kv = hparams["n_head_kv"] if "n_head_kv" in hparams else n_head +n_head_kv = hparams["n_head_kv"] if "n_head_kv" in hparams else 1 head_dim = hparams["hidden_size"] // n_head diff --git a/llama.cpp b/llama.cpp index c942b0727e407..d1a6e1be2bd19 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1859,10 +1859,13 @@ static void llm_load_tensors( for (uint32_t i = 0; i < n_layer; ++i) { auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, GGML_BACKEND_CPU); - layer.attn_norm_b = 
ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, GGML_BACKEND_CPU); - layer.attn_norm_2 = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, GGML_BACKEND_CPU); - layer.attn_norm_2_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, GGML_BACKEND_CPU); + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, GGML_BACKEND_CPU); + layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, GGML_BACKEND_CPU); + + if (gguf_find_tensor(ml.ctx_gguf, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i).c_str()) >= 0) { + layer.attn_norm_2 = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, GGML_BACKEND_CPU); + layer.attn_norm_2_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, GGML_BACKEND_CPU); + } layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, GGML_BACKEND_CPU); layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, GGML_BACKEND_CPU); @@ -2421,19 +2424,19 @@ static struct ggml_cgraph * llm_build_falcon( for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * cur; - struct ggml_tensor * layernorm_output; + struct ggml_tensor * attn_norm; // self-attention { - layernorm_output = ggml_norm(ctx0, inpL); + attn_norm = ggml_norm(ctx0, inpL); - layernorm_output = ggml_add(ctx0, + attn_norm = ggml_add(ctx0, ggml_mul(ctx0, - ggml_repeat(ctx0, model.layers[il].attn_norm, layernorm_output), - layernorm_output), - ggml_repeat(ctx0, model.layers[il].attn_norm_b, layernorm_output)); + ggml_repeat(ctx0, model.layers[il].attn_norm, attn_norm), + attn_norm), + ggml_repeat(ctx0, model.layers[il].attn_norm_b, attn_norm)); - if ( hparams.n_head_kv == 8 ) { // Falcon-40B + if (hparams.n_head_kv == 8) { // Falcon-40B cur = ggml_norm(ctx0, inpL); cur = ggml_add(ctx0, @@ -2441,9 +2444,8 @@ static struct ggml_cgraph * llm_build_falcon( ggml_repeat(ctx0, model.layers[il].attn_norm_2, cur), cur), ggml_repeat(ctx0, model.layers[il].attn_norm_2_b, cur)); - } - else { // Falcon 7B - cur = layernorm_output; + } else { // Falcon 7B + cur = attn_norm; } // compute QKV @@ -2563,8 +2565,8 @@ static struct ggml_cgraph * llm_build_falcon( } } - struct ggml_tensor* inpFF = layernorm_output; - struct ggml_tensor* attn_out = ggml_cpy( + struct ggml_tensor * inpFF = attn_norm; + struct ggml_tensor * attn_out = ggml_cpy( ctx0, cur, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); { @@ -2607,7 +2609,7 @@ static struct ggml_cgraph * llama_build_graph( const float * embd, int n_tokens, int n_past) { - const auto & model = lctx.model; + const auto & model = lctx.model; struct ggml_cgraph * result = NULL; @@ -2669,8 +2671,8 @@ static bool llama_eval_internal( GGML_ASSERT(!!kv_self.ctx); - const int64_t n_embd = hparams.n_embd; - const int64_t n_vocab = hparams.n_vocab; + const int64_t n_embd = hparams.n_embd; + const int64_t n_vocab = hparams.n_vocab; ggml_allocr_reset(lctx.alloc); From 0ec27ad66c56a4831f62a8106b7873a6818bf051 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 22 Aug 2023 23:11:41 +0300 Subject: [PATCH 12/36] falcon : minor --- ggml-alloc.c | 4 ++-- ggml-alloc.h | 2 +- llama.cpp | 26 ++++++++++++-------------- 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/ggml-alloc.c b/ggml-alloc.c index f06f9a3c1d97b..547ec0399fdb5 100644 --- a/ggml-alloc.c +++ b/ggml-alloc.c @@ -238,7 +238,7 @@ static void ggml_allocator_free_tensor(struct ggml_allocr * 
alloc, struct ggml_t alloc->n_free_blocks++; } -void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, int * list, int n) { +void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n) { int pos = 0; for (int i = 0; i < n; i++) { if (list[i] != -1) { @@ -547,7 +547,7 @@ static size_t ggml_allocator_alloc_graph_tensors_n( struct ggml_tensor * view_src = get_view_source(parent); struct hash_node * view_src_hn = hash_get(ht, view_src); view_src_hn->n_views -= 1; - AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src->n_children, view_src->n_views); + AT_PRINTF("view_src %s\n", view_src->name); if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) { ggml_allocator_free_tensor(alloc, view_src); } diff --git a/ggml-alloc.h b/ggml-alloc.h index 14a4350ac2e96..9559da75871a6 100644 --- a/ggml-alloc.h +++ b/ggml-alloc.h @@ -12,7 +12,7 @@ GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment); // tell the allocator to parse nodes following the order described in the list // you should call this if your graph are optimized to execute out-of-order -GGML_API void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, int * list, int n); +GGML_API void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n); GGML_API void ggml_allocr_free(struct ggml_allocr * alloc); GGML_API bool ggml_allocr_is_measure(struct ggml_allocr * alloc); diff --git a/llama.cpp b/llama.cpp index d1a6e1be2bd19..b65bf9461f0eb 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2436,7 +2436,7 @@ static struct ggml_cgraph * llm_build_falcon( attn_norm), ggml_repeat(ctx0, model.layers[il].attn_norm_b, attn_norm)); - if (hparams.n_head_kv == 8) { // Falcon-40B + if (model.layers[il].attn_norm_2) { // Falcon-40B cur = ggml_norm(ctx0, inpL); cur = ggml_add(ctx0, @@ -2461,23 +2461,25 @@ static struct ggml_cgraph * llm_build_falcon( // trickery when trying to accurately dump these views for // debugging. 
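+ // note (assumed clarification, not in the original commit): derive the view strides from the
+ // element size of cur itself rather than hard-coding GGML_TYPE_F32, keeping the offsets below
+ // in sync with the tensor's actual type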
+ const size_t wsize = ggml_type_size(cur->type); + struct ggml_tensor * Qcur = ggml_view_3d( ctx0, cur, n_embd_head, n_head, N, - n_embd_head * ggml_type_size(GGML_TYPE_F32), - n_embd_head * (n_head + 2 * n_head_kv) * ggml_type_size(GGML_TYPE_F32), + wsize * n_embd_head, + wsize * n_embd_head * (n_head + 2 * n_head_kv), 0); struct ggml_tensor * Kcur = ggml_view_3d( ctx0, cur, n_embd_head, n_head_kv, N, - n_embd_head * ggml_type_size(GGML_TYPE_F32), - n_embd_head * (n_head + 2 * n_head_kv) * ggml_type_size(GGML_TYPE_F32), - n_embd_head * n_head * ggml_type_size(GGML_TYPE_F32)); + wsize * n_embd_head, + wsize * n_embd_head * (n_head + 2 * n_head_kv), + wsize * n_embd_head * n_head); struct ggml_tensor * Vcur = ggml_view_3d( ctx0, cur, n_embd_head, n_head_kv, N, - n_embd_head * ggml_type_size(GGML_TYPE_F32), - n_embd_head * (n_head + 2 * n_head_kv) * ggml_type_size(GGML_TYPE_F32), - n_embd_head * (n_head + n_head_kv) * ggml_type_size(GGML_TYPE_F32)); + wsize * n_embd_head, + wsize * n_embd_head * (n_head + 2 * n_head_kv), + wsize * n_embd_head * (n_head + n_head_kv)); // using mode = 2 for neox mode Qcur = ggml_rope_inplace(ctx0, Qcur, n_past, n_embd_head, 2, 0); @@ -2518,11 +2520,7 @@ static struct ggml_cgraph * llm_build_falcon( struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); // KQ_scaled = KQ / sqrt(n_embd/n_head) - struct ggml_tensor * KQ_scaled = - ggml_scale_inplace(ctx0, - KQ, - ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd_head))) - ); + struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); // KQ_masked = mask_past(KQ_scaled) struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); From 7bbbf38c32f0b9492f79da7a3dff0e3a5f431538 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 22 Aug 2023 23:26:16 +0300 Subject: [PATCH 13/36] llama : minor updates ggml-ci --- llama.cpp | 78 ++++++++++++++++++++++++++++++------------------------- 1 file changed, 43 insertions(+), 35 deletions(-) diff --git a/llama.cpp b/llama.cpp index b65bf9461f0eb..ed4d6b366447d 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1109,11 +1109,11 @@ static bool llama_kv_cache_init( // model loading and saving // -enum llama_file_version { +enum llama_fver { GGUF_FILE_VERSION_V1 = 1, }; -static const char * llama_file_version_name(llama_file_version version) { +static const char * llama_file_version_name(llama_fver version) { switch (version) { case GGUF_FILE_VERSION_V1: return "GGUF V1 (latest)"; } @@ -1148,9 +1148,9 @@ struct llama_model_loader { bool use_mmap = false; - llama_file file; + llama_file file; llama_ftype ftype; - llama_file_version fver; + llama_fver fver; std::unique_ptr mapping; @@ -1171,7 +1171,7 @@ struct llama_model_loader { n_kv = gguf_get_n_kv(ctx_gguf); n_tensors = gguf_get_n_tensors(ctx_gguf); - fver = (enum llama_file_version) gguf_get_version(ctx_gguf); + fver = (enum llama_fver ) gguf_get_version(ctx_gguf); for (int i = 0; i < n_tensors; i++) { const char * name = gguf_get_tensor_name(ctx_gguf, i); @@ -1268,6 +1268,21 @@ struct llama_model_loader { } } + std::string get_arch_name() const { + const auto kv = LLM_KV(LLM_ARCH_UNKNOWN); + + std::string arch_name; + GGUF_GET_KEY(ctx_gguf, arch_name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_ARCHITECTURE)); + + return arch_name; + } + + enum llm_arch get_arch() const { + const std::string arch_name = get_arch_name(); + + return llm_arch_from_string(arch_name); + } + const char * get_tensor_name(int i) const { return gguf_get_tensor_name(ctx_gguf, i); } @@ -1480,16 +1495,9 @@ static 
const char * llama_model_type_name(e_model type) { } static void llm_load_arch(llama_model_loader & ml, llama_model & model) { - struct gguf_context * ctx = ml.ctx_gguf; - - const auto kv = LLM_KV(LLM_ARCH_UNKNOWN); - - std::string arch_name; - GGUF_GET_KEY(ctx, arch_name, gguf_get_val_str, GGUF_TYPE_STRING, true, kv(LLM_KV_GENERAL_ARCHITECTURE)); - - model.arch = llm_arch_from_string(arch_name); + model.arch = ml.get_arch(); if (model.arch == LLM_ARCH_UNKNOWN) { - throw std::runtime_error("unknown model architecture: '" + arch_name + "'"); + throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'"); } } @@ -4048,13 +4056,13 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s nthread = std::thread::hardware_concurrency(); } - std::unique_ptr model_loader(new llama_model_loader(fname_inp, /*use_mmap*/ false)); + std::unique_ptr ml(new llama_model_loader(fname_inp, /*use_mmap*/ false)); const size_t align = GGUF_DEFAULT_ALIGNMENT; struct gguf_context * ctx_out = gguf_init_empty(); // copy the KV pairs from the input file - gguf_set_kv (ctx_out, model_loader->ctx_gguf); + gguf_set_kv (ctx_out, ml->ctx_gguf); gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION); gguf_set_val_u32(ctx_out, "general.file_type", ftype); @@ -4062,8 +4070,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s int n_attention_wv = 0; int n_feed_forward_w2 = 0; - for (int i = 0; i < model_loader->n_tensors; ++i) { - struct ggml_tensor * meta = model_loader->get_tensor_meta(i); + for (int i = 0; i < ml->n_tensors; ++i) { + struct ggml_tensor * meta = ml->get_tensor_meta(i); const std::string name = ggml_get_name(meta); @@ -4097,8 +4105,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s std::vector work; // populate the original tensors so we get an initial meta data - for (int i = 0; i < model_loader->n_tensors; ++i) { - struct ggml_tensor * meta = model_loader->get_tensor_meta(i); + for (int i = 0; i < ml->n_tensors; ++i) { + struct ggml_tensor * meta = ml->get_tensor_meta(i); gguf_add_tensor(ctx_out, meta); } @@ -4111,17 +4119,17 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s // placeholder for the meta data ::zeros(fout, meta_size); - for (int i = 0; i < model_loader->n_tensors; ++i) { - struct ggml_tensor * tensor = model_loader->get_tensor_meta(i); + for (int i = 0; i < ml->n_tensors; ++i) { + struct ggml_tensor * tensor = ml->get_tensor_meta(i); const std::string name = ggml_get_name(tensor); read_data.resize(ggml_nbytes(tensor)); tensor->data = read_data.data(); - model_loader->load_data_for(tensor); + ml->load_data_for(tensor); LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ", - ++idx, model_loader->n_tensors, + ++idx, ml->n_tensors, ggml_get_name(tensor), llama_format_tensor_shape(tensor).c_str(), ggml_type_name(tensor->type)); @@ -4147,7 +4155,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s new_type = quantized_type; #ifdef GGML_USE_K_QUANTS // TODO: avoid hardcoded tensor names - use the TN_* constants - const auto tn = LLM_TN(LLM_ARCH_LLAMA); + const auto tn = LLM_TN(ml->get_arch()); if (name == tn(LLM_TENSOR_OUTPUT, "weight")) { int nx = tensor->ne[0]; @@ -4386,28 +4394,28 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const } // load base model - std::unique_ptr model_loader; + std::unique_ptr ml; ggml_context * base_ctx = NULL; std::vector base_buf; if 
(path_base_model) { LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model); - model_loader.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true)); + ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true)); size_t ctx_size; size_t mmapped_size; - model_loader->calc_sizes(ctx_size, mmapped_size); + ml->calc_sizes(ctx_size, mmapped_size); base_buf.resize(ctx_size); ggml_init_params base_params; base_params.mem_size = base_buf.size(); base_params.mem_buffer = base_buf.data(); - base_params.no_alloc = model_loader->use_mmap; + base_params.no_alloc = ml->use_mmap; base_ctx = ggml_init(base_params); // maybe this should in llama_model_loader - if (model_loader->use_mmap) { - model_loader->mapping.reset(new llama_mmap(&model_loader->file, /* prefetch */ 0, ggml_is_numa())); + if (ml->use_mmap) { + ml->mapping.reset(new llama_mmap(&ml->file, /* prefetch */ 0, ggml_is_numa())); } } @@ -4511,8 +4519,8 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const #endif // GGML_USE_CUBLAS ggml_tensor * base_t; - if (model_loader) { - struct gguf_context * ctx_gguf = model_loader->ctx_gguf; + if (ml) { + struct gguf_context * ctx_gguf = ml->ctx_gguf; // load from base model if (gguf_find_tensor(ctx_gguf, base_name.c_str()) < 0) { @@ -4522,8 +4530,8 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const } // TODO: not tested!! maybe not working! - base_t = model_loader->create_tensor(base_ctx, base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] }, GGML_BACKEND_CPU); - model_loader->load_data_for(base_t); + base_t = ml->create_tensor(base_ctx, base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] }, GGML_BACKEND_CPU); + ml->load_data_for(base_t); } else { base_t = dest_t; } From 9853f2cfb26bdbf71aa475fa848a7555286dec4a Mon Sep 17 00:00:00 2001 From: klosax <131523366+klosax@users.noreply.github.com> Date: Tue, 22 Aug 2023 22:29:11 +0200 Subject: [PATCH 14/36] convert-falcon-hf-to-gguf.py : fix special token mapping --- convert-falcon-hf-to-gguf.py | 49 +++++++++++++----------------------- 1 file changed, 17 insertions(+), 32 deletions(-) diff --git a/convert-falcon-hf-to-gguf.py b/convert-falcon-hf-to-gguf.py index 8d879c499d18d..608f8c455e0df 100644 --- a/convert-falcon-hf-to-gguf.py +++ b/convert-falcon-hf-to-gguf.py @@ -164,38 +164,23 @@ def count_model_parts(dir_model: str) -> int: gguf_writer.add_token_scores(scores) gguf_writer.add_token_types(toktypes) - if "added_tokens" in tokenizer_json and Path(dir_model + "/tokenizer_config.json").is_file(): - print("gguf: get special token ids") - - with open(dir_model + "/tokenizer_config.json", "r", encoding="utf-8") as f: - tokenizer_config = json.load(f) - - # find special token ids - - if "bos_token" in tokenizer_config: - for key in tokenizer_json["added_tokens"]: - if key["content"] == tokenizer_config["bos_token"]: - gguf_writer.add_bos_token_id(key["id"]) - - if "eos_token" in tokenizer_config: - for key in tokenizer_json["added_tokens"]: - if key["content"] == tokenizer_config["eos_token"]: - gguf_writer.add_eos_token_id(key["id"]) - - if "unk_token" in tokenizer_config: - for key in tokenizer_json["added_tokens"]: - if key["content"] == tokenizer_config["unk_token"]: - gguf_writer.add_unk_token_id(key["id"]) - - if "sep_token" in tokenizer_config: - for key in tokenizer_json["added_tokens"]: - if key["content"] == tokenizer_config["sep_token"]: - gguf_writer.add_sep_token_id(key["id"]) - - if "pad_token" in tokenizer_config: - for 
key in tokenizer_json["added_tokens"]: - if key["content"] == tokenizer_config["pad_token"]: - gguf_writer.add_pad_token_id(key["id"]) +print("gguf: get special token ids") +# Look for special tokens in config.json + +if "bos_token_id" in hparams and hparams["bos_token_id"] != None: + gguf_writer.add_bos_token_id(hparams["bos_token_id"]) + +if "eos_token_id" in hparams and hparams["eos_token_id"] != None: + gguf_writer.add_eos_token_id(hparams["eos_token_id"]) + +if "unk_token_id" in hparams and hparams["unk_token_id"] != None: + gguf_writer.add_unk_token_id(hparams["unk_token_id"]) + +if "sep_token_id" in hparams and hparams["sep_token_id"] != None: + gguf_writer.add_sep_token_id(hparams["sep_token_id"]) + +if "pad_token_id" in hparams and hparams["pad_token_id"] != None: + gguf_writer.add_pad_token_id(hparams["pad_token_id"]) # TENSORS From ffa5099c6d19bf9e783e545ecd3040b75d06c06c Mon Sep 17 00:00:00 2001 From: klosax <131523366+klosax@users.noreply.github.com> Date: Tue, 22 Aug 2023 22:34:03 +0200 Subject: [PATCH 15/36] llama.cpp : llama default UNK token = id 0 --- llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index ed4d6b366447d..4b539516c3c21 100644 --- a/llama.cpp +++ b/llama.cpp @@ -929,7 +929,7 @@ struct llama_vocab { // default LLaMA special tokens id special_bos_id = 1; id special_eos_id = 2; - id special_unk_id = -1; + id special_unk_id = 0; id special_sep_id = -1; id special_pad_id = -1; From a95ae7526a63d496c7aebff3787e9720ce837298 Mon Sep 17 00:00:00 2001 From: klosax <131523366+klosax@users.noreply.github.com> Date: Wed, 23 Aug 2023 00:02:13 +0200 Subject: [PATCH 16/36] llama.cpp : fix bpe tokenizer --- llama.cpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/llama.cpp b/llama.cpp index 4b539516c3c21..61e7341791416 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2839,10 +2839,14 @@ static bool llama_is_pad_token(const llama_vocab & vocab, llama_token id ) { } static uint8_t llama_token_to_byte(const llama_vocab & vocab, llama_token id) { - GGML_ASSERT(llama_is_byte_token(vocab, id)); - const auto& token_data = vocab.id_to_token.at(id); - auto buf = token_data.text.substr(3, 2); - return strtol(buf.c_str(), NULL, 16); + if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { + char buf[7]; + int result = snprintf(buf, sizeof(buf), "<0x%02X>", ch); + GGML_ASSERT(0 <= result && result < 7); + return vocab.token_to_id.at(buf); + } + // vocab.type == LLAMA_VOCAB_TYPE_BPE + return vocab.token_to_id.at(std::string(1, ch)); } static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) { From d561b7f7244af50ea5ce05fe05a1ea71e94859ad Mon Sep 17 00:00:00 2001 From: klosax <131523366+klosax@users.noreply.github.com> Date: Wed, 23 Aug 2023 00:06:53 +0200 Subject: [PATCH 17/36] llama.cpp : fix the fix of bpe tokenizer --- llama.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/llama.cpp b/llama.cpp index 61e7341791416..17681a8a571fe 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2839,6 +2839,13 @@ static bool llama_is_pad_token(const llama_vocab & vocab, llama_token id ) { } static uint8_t llama_token_to_byte(const llama_vocab & vocab, llama_token id) { + GGML_ASSERT(llama_is_byte_token(vocab, id)); + const auto& token_data = vocab.id_to_token.at(id); + auto buf = token_data.text.substr(3, 2); + return strtol(buf.c_str(), NULL, 16); +} + +static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) { if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { char buf[7]; int 
result = snprintf(buf, sizeof(buf), "<0x%02X>", ch); @@ -2849,13 +2856,6 @@ static uint8_t llama_token_to_byte(const llama_vocab & vocab, llama_token id) { return vocab.token_to_id.at(std::string(1, ch)); } -static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) { - char buf[7]; - int result = snprintf(buf, sizeof(buf), "<0x%02X>", ch); - GGML_ASSERT(0 <= result && result < 7); - return vocab.token_to_id.at(buf); -} - static std::string llama_escape_whitespace(const std::string& text) { std::string result; bool escaping = false; From e3c52bd9905599b6fd0cba9393b399e4dd37b252 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 23 Aug 2023 10:40:58 +0300 Subject: [PATCH 18/36] ggml : pass eps to ggml_norm --- ggml-metal.m | 3 ++- ggml.c | 16 ++++++++----- ggml.h | 7 +++--- llama.cpp | 63 +++++++++++++++++++++++++++------------------------- 4 files changed, 49 insertions(+), 40 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 835c5f297cf95..a133b2236117d 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -938,7 +938,8 @@ void ggml_metal_graph_compute( } break; case GGML_OP_NORM: { - const float eps = 1e-5f; + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); const int nth = 256; diff --git a/ggml.c b/ggml.c index dffb977313584..2180ea0e2dda1 100644 --- a/ggml.c +++ b/ggml.c @@ -5789,6 +5789,7 @@ struct ggml_tensor * ggml_silu_back( static struct ggml_tensor * ggml_norm_impl( struct ggml_context * ctx, struct ggml_tensor * a, + float eps, bool inplace) { bool is_node = false; @@ -5799,7 +5800,7 @@ static struct ggml_tensor * ggml_norm_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - // TODO: maybe store epsilon here? + ggml_set_op_params(result, &eps, sizeof(eps)); result->op = GGML_OP_NORM; result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; @@ -5810,14 +5811,16 @@ static struct ggml_tensor * ggml_norm_impl( struct ggml_tensor * ggml_norm( struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_norm_impl(ctx, a, false); + struct ggml_tensor * a, + float eps) { + return ggml_norm_impl(ctx, a, eps, false); } struct ggml_tensor * ggml_norm_inplace( struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_norm_impl(ctx, a, true); + struct ggml_tensor * a, + float eps) { + return ggml_norm_impl(ctx, a, eps, true); } // ggml_rms_norm @@ -10619,7 +10622,8 @@ static void ggml_compute_forward_norm_f32( GGML_TENSOR_UNARY_OP_LOCALS; - const float eps = 1e-5f; // TODO: make this a parameter + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { diff --git a/ggml.h b/ggml.h index 3c48fd27fab39..421c0df60c579 100644 --- a/ggml.h +++ b/ggml.h @@ -909,14 +909,15 @@ extern "C" { struct ggml_tensor * b); // normalize along rows - // TODO: eps is hardcoded to 1e-5 for now GGML_API struct ggml_tensor * ggml_norm( struct ggml_context * ctx, - struct ggml_tensor * a); + struct ggml_tensor * a, + float eps); GGML_API struct ggml_tensor * ggml_norm_inplace( struct ggml_context * ctx, - struct ggml_tensor * a); + struct ggml_tensor * a, + float eps); GGML_API struct ggml_tensor * ggml_rms_norm( struct ggml_context * ctx, diff --git a/llama.cpp b/llama.cpp index 17681a8a571fe..e801eb333153b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -830,6 +830,7 @@ struct llama_hparams { uint32_t n_rot = 64; uint32_t n_ff = 11008; + float f_norm_eps = 1e-5; float f_norm_rms_eps = 1e-5; float rope_freq_base = 10000.0f; @@ -1557,6 +1558,7 @@ static void llm_load_hparams( } break; case LLM_ARCH_FALCON: { + GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS)); } break; default: (void)0; }; @@ -1672,28 +1674,29 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { const auto & vocab = model.vocab; // hparams - LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver)); - LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str()); - LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix - LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab); - LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); - LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, hparams.n_ctx); - LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); - LLAMA_LOG_INFO("%s: n_head = %u\n", __func__, hparams.n_head); - LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv); - LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); - LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. 
n_embd_head, n_head_dim - LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa()); - LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); - LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff); - LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base); - LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale); - LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type)); - LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str()); - LLAMA_LOG_INFO("%s: model size = %.2f B\n", __func__, ml.n_elements*1e-9); + LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver)); + LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str()); + LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix + LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab); + LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); + LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, hparams.n_ctx); + LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); + LLAMA_LOG_INFO("%s: n_head = %u\n", __func__, hparams.n_head); + LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv); + LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); + LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim + LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa()); + LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); + LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); + LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff); + LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base); + LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale); + LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type)); + LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str()); + LLAMA_LOG_INFO("%s: model size = %.2f B\n", __func__, ml.n_elements*1e-9); // general kv - LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, model.name.c_str()); + LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, model.name.c_str()); // special tokens if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); } @@ -1899,8 +1902,7 @@ static void llm_load_tensors( mmapped_size - vram_weights; // weights in VRAM not in memory // this is the memory required by one llama_state - const size_t mem_required_state = - scale*hparams.kv_size(); + const size_t mem_required_state = scale*hparams.kv_size(); LLAMA_LOG_INFO("%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__, mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0); @@ -2383,6 +2385,10 @@ static struct ggml_cgraph * llm_build_falcon( GGML_ASSERT(n_embd_head == hparams.n_rot); + const float freq_base = hparams.rope_freq_base; + const float freq_scale = hparams.rope_freq_scale; + const float norm_eps = hparams.f_norm_eps; + auto & buf_compute = lctx.buf_compute; struct ggml_init_params params = { @@ -2436,7 +2442,7 @@ static struct ggml_cgraph * llm_build_falcon( // self-attention { - attn_norm = ggml_norm(ctx0, inpL); + attn_norm = ggml_norm(ctx0, inpL, norm_eps); attn_norm = ggml_add(ctx0, ggml_mul(ctx0, @@ 
-2445,7 +2451,7 @@ static struct ggml_cgraph * llm_build_falcon( ggml_repeat(ctx0, model.layers[il].attn_norm_b, attn_norm)); if (model.layers[il].attn_norm_2) { // Falcon-40B - cur = ggml_norm(ctx0, inpL); + cur = ggml_norm(ctx0, inpL, norm_eps); cur = ggml_add(ctx0, ggml_mul(ctx0, @@ -2490,8 +2496,8 @@ static struct ggml_cgraph * llm_build_falcon( wsize * n_embd_head * (n_head + n_head_kv)); // using mode = 2 for neox mode - Qcur = ggml_rope_inplace(ctx0, Qcur, n_past, n_embd_head, 2, 0); - Kcur = ggml_rope_inplace(ctx0, Kcur, n_past, n_embd_head, 2, 0); + Qcur = ggml_rope_custom_inplace(ctx0, Qcur, n_past, n_embd_head, 2, 0, freq_base, freq_scale); + Kcur = ggml_rope_custom_inplace(ctx0, Kcur, n_past, n_embd_head, 2, 0, freq_base, freq_scale); // store key and value to memory { @@ -2522,8 +2528,6 @@ static struct ggml_cgraph * llm_build_falcon( // K * Q -// K = ggml_cont(ctx0, ggml_repeat2(ctx0, K, repeat_dummy)); - struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); @@ -2549,7 +2553,6 @@ static struct ggml_cgraph * llm_build_falcon( n_embd_head, n_head_kv, n_past + N), 0, 2, 1, 3); -// V = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_repeat2(ctx0, V, repeat_dummy))); V = ggml_cont(ctx0, ggml_transpose(ctx0, V)); // KQV = transpose(V) * KQ_soft_max @@ -2589,7 +2592,7 @@ static struct ggml_cgraph * llm_build_falcon( // norm { - cur = ggml_norm(ctx0, inpL); + cur = ggml_norm(ctx0, inpL, norm_eps); cur = ggml_add(ctx0, ggml_mul(ctx0, From 99bb26078f6baec8bcb995bf09c61dee8b0d4c75 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 23 Aug 2023 10:41:35 +0300 Subject: [PATCH 19/36] metal : implement RoPE (mode = 2) + avoid ggml_repeat --- ggml-metal.metal | 20 +++++++++++++++++++- llama.cpp | 18 ++++++------------ 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index ce3541f4bb55f..53604a250af46 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -571,7 +571,25 @@ kernel void kernel_rope( dst_data[1] = x0*sin_theta + x1*cos_theta; } } else { - // TODO: implement + for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { + for (int64_t ic = 0; ic < n_dims; ic += 2) { + const float cos_theta = cos(theta); + const float sin_theta = sin(theta); + + theta *= theta_scale; + + const int64_t i0 = ib*n_dims + ic/2; + + device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + } + } } } diff --git a/llama.cpp b/llama.cpp index e801eb333153b..3268524293a63 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2445,19 +2445,15 @@ static struct ggml_cgraph * llm_build_falcon( attn_norm = ggml_norm(ctx0, inpL, norm_eps); attn_norm = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, model.layers[il].attn_norm, attn_norm), - attn_norm), - ggml_repeat(ctx0, model.layers[il].attn_norm_b, attn_norm)); + ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm), + model.layers[il].attn_norm_b); if (model.layers[il].attn_norm_2) { // Falcon-40B cur = ggml_norm(ctx0, inpL, norm_eps); cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, model.layers[il].attn_norm_2, cur), - cur), - ggml_repeat(ctx0, model.layers[il].attn_norm_2_b, cur)); + ggml_mul(ctx0, cur, model.layers[il].attn_norm_2), + 
model.layers[il].attn_norm_2_b); } else { // Falcon 7B cur = attn_norm; } @@ -2595,10 +2591,8 @@ static struct ggml_cgraph * llm_build_falcon( cur = ggml_norm(ctx0, inpL, norm_eps); cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, model.output_norm, cur), - cur), - ggml_repeat(ctx0, model.output_norm_b, cur)); + ggml_mul(ctx0, cur, model.output_norm), + model.output_norm_b); ggml_set_name(cur, "result_norm"); } From af4bbcc87387884d3e0974a1992f9dfb76a2160d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 23 Aug 2023 10:42:02 +0300 Subject: [PATCH 20/36] ggml : ggml_repeat always creates new tensor --- ggml.c | 4 ---- 1 file changed, 4 deletions(-) diff --git a/ggml.c b/ggml.c index 2180ea0e2dda1..c7cc72d48c40d 100644 --- a/ggml.c +++ b/ggml.c @@ -5555,10 +5555,6 @@ struct ggml_tensor * ggml_repeat( is_node = true; } - if (ggml_are_same_shape(a, b) && !is_node) { - return a; - } - struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne); result->op = GGML_OP_REPEAT; From b34ab74094b979dd9e4b719e938721bdc3eba783 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 23 Aug 2023 11:04:26 +0300 Subject: [PATCH 21/36] falcon : copy-paste self-attention from LLaMA --- llama.cpp | 117 ++++++++++++++++++++++++------------------------------ 1 file changed, 52 insertions(+), 65 deletions(-) diff --git a/llama.cpp b/llama.cpp index 3268524293a63..38593ac018988 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2201,10 +2201,7 @@ static struct ggml_cgraph * llm_build_llama( ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); } - struct ggml_tensor * Q = - ggml_permute(ctx0, - Qcur, - 0, 2, 1, 3); + struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); offload_func_kq(Q); ggml_set_name(Q, "Q"); @@ -2381,7 +2378,7 @@ static struct ggml_cgraph * llm_build_falcon( const int64_t n_head = hparams.n_head; const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_embd_head = hparams.n_embd_head(); - //const int64_t n_embd_gqa = hparams.n_embd_gqa(); + const int64_t n_embd_gqa = hparams.n_embd_gqa(); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -2441,6 +2438,7 @@ static struct ggml_cgraph * llm_build_falcon( struct ggml_tensor * attn_norm; // self-attention + // TODO: refactor into common function (shared with LLaMA) { attn_norm = ggml_norm(ctx0, inpL, norm_eps); @@ -2473,115 +2471,104 @@ static struct ggml_cgraph * llm_build_falcon( const size_t wsize = ggml_type_size(cur->type); - struct ggml_tensor * Qcur = ggml_view_3d( + struct ggml_tensor * tmpq = ggml_view_3d( ctx0, cur, n_embd_head, n_head, N, wsize * n_embd_head, wsize * n_embd_head * (n_head + 2 * n_head_kv), 0); - struct ggml_tensor * Kcur = ggml_view_3d( + struct ggml_tensor * tmpk = ggml_view_3d( ctx0, cur, n_embd_head, n_head_kv, N, wsize * n_embd_head, wsize * n_embd_head * (n_head + 2 * n_head_kv), - wsize * n_embd_head * n_head); + wsize * n_embd_head * n_head); - struct ggml_tensor * Vcur = ggml_view_3d( + struct ggml_tensor * tmpv = ggml_view_3d( ctx0, cur, n_embd_head, n_head_kv, N, wsize * n_embd_head, wsize * n_embd_head * (n_head + 2 * n_head_kv), wsize * n_embd_head * (n_head + n_head_kv)); // using mode = 2 for neox mode - Qcur = ggml_rope_custom_inplace(ctx0, Qcur, n_past, n_embd_head, 2, 0, freq_base, freq_scale); - Kcur = ggml_rope_custom_inplace(ctx0, Kcur, n_past, n_embd_head, 2, 0, freq_base, freq_scale); + struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, tmpq, n_past, n_embd_head, 2, 0, freq_base, freq_scale); + struct ggml_tensor * Kcur = 
ggml_rope_custom_inplace(ctx0, tmpk, n_past, n_embd_head, 2, 0, freq_base, freq_scale); - // store key and value to memory { - struct ggml_tensor* k = ggml_view_1d( - ctx0, kv_self.k, N * n_head_kv * n_embd_head, - (ggml_element_size(kv_self.k) * n_head_kv * n_embd_head) * - (il * n_ctx + n_past)); - struct ggml_tensor* v = ggml_view_1d( - ctx0, kv_self.v, N * n_head_kv * n_embd_head, - (ggml_element_size(kv_self.v) * n_head_kv * n_embd_head) * - (il * n_ctx + n_past)); + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, N)); + ggml_set_name(Vcur, "Vcur"); + + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past)); + ggml_set_name(k, "k"); + + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa, + ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); } - struct ggml_tensor * K = ggml_permute( - ctx0, - ggml_reshape_3d( - ctx0, - ggml_view_1d(ctx0, kv_self.k, (n_past + N) * n_head_kv * n_embd_head, - il * n_ctx * - ggml_element_size(kv_self.k) * - n_head_kv * - n_embd_head), - n_embd_head, n_head_kv, n_past + N), - 0, 2, 1, 3); + struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + ggml_set_name(Q, "Q"); - // K * Q + struct ggml_tensor * K = + ggml_view_3d(ctx0, kv_self.k, + n_embd_head, n_past + N, n_head_kv, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); + ggml_set_name(K, "K"); - struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + ggml_set_name(KQ, "KQ"); - // KQ_scaled = KQ / sqrt(n_embd/n_head) struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + ggml_set_name(KQ_scaled, "KQ_scaled"); - // KQ_masked = mask_past(KQ_scaled) struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + ggml_set_name(KQ_masked, "KQ_masked"); - // KQ = soft_max(KQ_masked) struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + ggml_set_name(KQ_soft_max, "KQ_soft_max"); + + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_self.v, + n_past + N, n_embd_head, n_head_kv, + ggml_element_size(kv_self.v)*n_ctx, + ggml_element_size(kv_self.v)*n_ctx*n_embd_head, + ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); + ggml_set_name(V, "V"); - // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() - struct ggml_tensor* V = ggml_permute( - ctx0, - ggml_reshape_3d( - ctx0, - ggml_view_1d(ctx0, kv_self.v, (n_past + N) * n_head_kv * n_embd_head, - il * n_ctx * - ggml_element_size(kv_self.v) * - n_head_kv * - n_embd_head), - n_embd_head, n_head_kv, n_past + N), - 0, 2, 1, 3); - - V = ggml_cont(ctx0, ggml_transpose(ctx0, V)); - - // KQV = transpose(V) * KQ_soft_max struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + ggml_set_name(KQV, "KQV"); - // KQV_merged = KQV.permute(0, 2, 1, 3) struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + ggml_set_name(KQV_merged, "KQV_merged"); - // cur = KQV_merged.contiguous().view(n_embd, N) cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + ggml_set_name(cur, "KQV_merged_contiguous"); - // projection - { - cur = 
ggml_mul_mat(ctx0, - model.layers[il].wo, - cur); - } + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + + cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); + ggml_set_name(cur, "result_wo"); } struct ggml_tensor * inpFF = attn_norm; struct ggml_tensor * attn_out = ggml_cpy( ctx0, cur, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); - { - cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF); - cur = ggml_gelu(ctx0, cur); - cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur); - } + cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF); + cur = ggml_gelu(ctx0, cur); + cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur); cur = ggml_add(ctx0, cur, attn_out); cur = ggml_add(ctx0, cur, inpL); + // input for next layer inpL = cur; } From a0dc47a50154518f8e18c8fc9e2a731506bb8d2d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 23 Aug 2023 11:25:26 +0300 Subject: [PATCH 22/36] metal : print extra compute pipeline info --- ggml-metal.m | 73 ++++++++++++++++++++++++++-------------------------- 1 file changed, 37 insertions(+), 36 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index a133b2236117d..1c7930a185dfa 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -167,7 +167,9 @@ @implementation GGMLMetalClass #define GGML_METAL_ADD_KERNEL(name) \ ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \ ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \ - fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); \ + fprintf(stderr, "%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \ + (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \ + (int) ctx->pipeline_##name.threadExecutionWidth); \ if (error) { \ fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ return NULL; \ @@ -218,12 +220,12 @@ @implementation GGMLMetalClass #undef GGML_METAL_ADD_KERNEL } - fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); - fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false"); + fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); + fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? 
"true" : "false"); if (ctx->device.maxTransferRate != 0) { - fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0); + fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0); } else { - fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__); + fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__); } return ctx; @@ -744,32 +746,31 @@ void ggml_metal_graph_compute( [ctx->device supportsFamily:MTLGPUFamilyApple7] && ne00%32 == 0 && ne11 > 1) { - switch (src0->type) { - case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break; - case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break; - case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break; - case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break; - case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break; - case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break; - case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break; - case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break; - default: GGML_ASSERT(false && "MUL MAT-MAT not implemented"); - } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9]; - [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10]; - [encoder setThreadgroupMemoryLength:8192 atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + switch (src0->type) { + case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break; + case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break; + case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break; + case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break; + case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break; + case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break; + case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break; + case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break; + default: GGML_ASSERT(false && "MUL MAT-MAT not implemented"); } - else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; + [encoder setBytes:&ne12 length:sizeof(ne12) 
atIndex:7]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9]; + [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10]; + [encoder setThreadgroupMemoryLength:8192 atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + } else { int nth0 = 32; int nth1 = 1; @@ -868,24 +869,24 @@ void ggml_metal_graph_compute( [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15]; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; - [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17]; + [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_Q3_K) { #ifdef GGML_QKK_64 - [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; #else - [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; #endif } else if (src0t == GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; From e2d23bed1b52b740ef78e880e7e220ce6a048662 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 23 Aug 2023 12:25:49 +0300 Subject: [PATCH 23/36] falcon : minor changes (still chasing the Metal problem) --- ggml-metal.m | 74 +++++++++++++++++++++++++++------------------------- ggml.c | 6 ++--- llama.cpp | 23 +++++++--------- 3 files changed, 51 insertions(+), 52 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 1c7930a185dfa..1c1810da4694a 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -757,17 +757,17 @@ void ggml_metal_graph_compute( case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break; default: GGML_ASSERT(false && "MUL MAT-MAT not implemented"); } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; - [encoder setBytes:&ne12 
length:sizeof(ne12) atIndex:7]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9]; - [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9]; + [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10]; [encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { @@ -945,11 +945,11 @@ void ggml_metal_graph_compute( const int nth = 256; [encoder setComputePipelineState:ctx->pipeline_norm]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; + [encoder setBytes:&eps length:sizeof( float) atIndex:4]; [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0]; const int64_t nrows = ggml_nrows(src0); @@ -992,7 +992,9 @@ void ggml_metal_graph_compute( [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; [encoder setBytes:&m0 length:sizeof( float) atIndex:18]; + const int nth = 32; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case GGML_OP_ROPE: @@ -1007,8 +1009,8 @@ void ggml_metal_graph_compute( memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); [encoder setComputePipelineState:ctx->pipeline_rope]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; @@ -1059,24 +1061,24 @@ void ggml_metal_graph_compute( default: GGML_ASSERT(false && "not implemented"); } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder 
setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; diff --git a/ggml.c b/ggml.c index c7cc72d48c40d..6839d03213559 100644 --- a/ggml.c +++ b/ggml.c @@ -3554,9 +3554,9 @@ inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; } inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 
x[i] : 0.f; } -static const float GELU_COEF_A = 0.044715f; -static const float GELU_QUICK_COEF = -1.702f; -static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; +static const float GELU_COEF_A = 0.044715f; +static const float GELU_QUICK_COEF = -1.702f; +static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; inline static float ggml_gelu_f32(float x) { return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); diff --git a/llama.cpp b/llama.cpp index 38593ac018988..a2ee80ab2bbb5 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2545,26 +2545,23 @@ static struct ggml_cgraph * llm_build_falcon( struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); ggml_set_name(KQV_merged, "KQV_merged"); - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); ggml_set_name(cur, "KQV_merged_contiguous"); - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); - cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); ggml_set_name(cur, "result_wo"); } - struct ggml_tensor * inpFF = attn_norm; - struct ggml_tensor * attn_out = ggml_cpy( - ctx0, cur, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + struct ggml_tensor * attn_out = cur; - cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF); - cur = ggml_gelu(ctx0, cur); - cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur); + // feed forward + { + struct ggml_tensor * inpFF = attn_norm; + + cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF); + cur = ggml_gelu(ctx0, cur); + cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur); + } cur = ggml_add(ctx0, cur, attn_out); cur = ggml_add(ctx0, cur, inpL); From b693000c2e388cb02b1ab5a8e0e7df8942a76f17 Mon Sep 17 00:00:00 2001 From: klosax <131523366+klosax@users.noreply.github.com> Date: Wed, 23 Aug 2023 13:22:41 +0200 Subject: [PATCH 24/36] llama.cpp : fix linefeed token --- llama.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/llama.cpp b/llama.cpp index a2ee80ab2bbb5..61446cdf77c29 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1596,6 +1596,9 @@ static void llm_load_hparams( hparams.rope_freq_scale = rope_freq_scale; } +// TODO: This should probably be in llama.h +static std::vector llama_tokenize_internal(const llama_vocab & vocab, const std::string & raw_text, bool bos, bool escape); + static void llm_load_vocab( llama_model_loader & ml, llama_model & model) { @@ -1655,12 +1658,11 @@ static void llm_load_vocab( token_data.score = scores[i]; token_data.type = (llama_token_type) toktypes[i]; - // determine the newline token: 0x0A == 10 == '\n' - if (token_data.text == "<0x0A>") { - vocab.linefeed_id = i; - } } + // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' + vocab.linefeed_id = llama_tokenize_internal(vocab, "\n", false, false)[0]; + // special tokens GGUF_GET_KEY(ctx, vocab.special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_BOS_ID)); GGUF_GET_KEY(ctx, vocab.special_eos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_EOS_ID)); From 0a85ae73977e69b0fbeb9345361ce8e358919416 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 23 Aug 2023 15:04:53 +0300 Subject: [PATCH 25/36] metal : fix GELU kernel numerical stability by using precise::tanh --- ggml-metal.m | 4 ++-- ggml-metal.metal | 7 ++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 
1c1810da4694a..969cf7daa74c5 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -539,8 +539,8 @@ void ggml_metal_graph_compute( id encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; - const int node_start = (cb_idx + 0) * n_nodes_per_cb; - const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb; + const int node_start = (cb_idx + 0) * n_nodes_per_cb; + const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes); for (int ind = node_start; ind < node_end; ++ind) { const int i = has_concur ? ctx->concur_list[ind] : ind; diff --git a/ggml-metal.metal b/ggml-metal.metal index 53604a250af46..7bc3fdf371897 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -87,7 +87,12 @@ kernel void kernel_gelu( device float * dst, uint tpig[[thread_position_in_grid]]) { float x = src0[tpig]; - dst[tpig] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); + + // BEWARE !!! + // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! + // This was observed with Falcon 7B and 40B models + // + dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); } kernel void kernel_soft_max( From 854ae5d030b4dbeebe19815c1a94c518f8ee8abd Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 23 Aug 2023 15:25:31 +0300 Subject: [PATCH 26/36] metal : temporary workaround for the concurrency optimization bug --- llama.cpp | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/llama.cpp b/llama.cpp index 61446cdf77c29..112843d202418 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2333,9 +2333,11 @@ static struct ggml_cgraph * llm_build_llama( inpL = cur; } + cur = inpL; + // norm { - cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps); + cur = ggml_rms_norm(ctx0, cur, norm_rms_eps); offload_func_nr(cur); ggml_set_name(cur, "rms_norm_2"); @@ -2436,7 +2438,6 @@ static struct ggml_cgraph * llm_build_falcon( ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * cur; struct ggml_tensor * attn_norm; // self-attention @@ -2561,6 +2562,12 @@ static struct ggml_cgraph * llm_build_falcon( struct ggml_tensor * inpFF = attn_norm; cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF); + + // TODO: this is temporary needed to introduce artificial dependency between FF and ATTN + // adding this, because there seems to be a bug in the Metal concurrency optimization + // without this line, the results are non-deterministic and wrong + cur->src[2] = attn_out; + cur = ggml_gelu(ctx0, cur); cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur); } @@ -2572,9 +2579,11 @@ static struct ggml_cgraph * llm_build_falcon( inpL = cur; } + cur = inpL; + // norm { - cur = ggml_norm(ctx0, inpL, norm_eps); + cur = ggml_norm(ctx0, cur, norm_eps); cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.output_norm), From e7299656bd5a21ed49f448186110a473854ab13d Mon Sep 17 00:00:00 2001 From: slaren Date: Wed, 23 Aug 2023 14:51:30 +0200 Subject: [PATCH 27/36] falcon : add CUDA offloading (#2739) --- llama.cpp | 112 ++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 101 insertions(+), 11 deletions(-) diff --git a/llama.cpp b/llama.cpp index 112843d202418..10cf75e0dde2c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1860,31 +1860,54 @@ static void llm_load_tensors( // output { - model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, GGML_BACKEND_CPU); - model.output_norm_b = ml.create_tensor(ctx, 
tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, GGML_BACKEND_CPU); - model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + ggml_backend backend_norm; + ggml_backend backend_output; + + if (n_gpu_layers > int(n_layer)) { + // norm is not performance relevant on its own but keeping it in VRAM reduces data copying + // on Windows however this is detrimental unless everything is on the GPU +#ifndef _WIN32 + backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; +#else + backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; +#endif // _WIN32 + + backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; + } else { + backend_norm = GGML_BACKEND_CPU; + backend_output = GGML_BACKEND_CPU; + } + + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); } const uint32_t n_ff = hparams.n_ff; + const int i_gpu_start = n_layer - n_gpu_layers; + model.layers.resize(n_layer); for (uint32_t i = 0; i < n_layer; ++i) { + const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT + const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT + auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, GGML_BACKEND_CPU); - layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, GGML_BACKEND_CPU); + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); if (gguf_find_tensor(ml.ctx_gguf, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i).c_str()) >= 0) { - layer.attn_norm_2 = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, GGML_BACKEND_CPU); - layer.attn_norm_2_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, GGML_BACKEND_CPU); + layer.attn_norm_2 = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, backend); + layer.attn_norm_2_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, backend); } - layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, GGML_BACKEND_CPU); - layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, GGML_BACKEND_CPU); + layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); - layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, GGML_BACKEND_CPU); - layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, GGML_BACKEND_CPU); + layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); } } break; default: @@ -2390,6 +2413,8 @@ static struct ggml_cgraph * llm_build_falcon( const float freq_scale = hparams.rope_freq_scale; const float norm_eps = 
hparams.f_norm_eps; + const int n_gpu_layers = model.n_gpu_layers; + auto & buf_compute = lctx.buf_compute; struct ggml_init_params params = { @@ -2430,6 +2455,30 @@ static struct ggml_cgraph * llm_build_falcon( } } + const int i_gpu_start = n_layer - n_gpu_layers; + (void) i_gpu_start; + + // offload functions set the tensor output backend to GPU + // tensors are GPU-accelerated if any input or the output has been offloaded + // + // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal + // in that case ggml_cuda_assign_buffers has no effect + offload_func_t offload_func_nr = llama_nop; // nr = non-repeating + offload_func_t offload_func_kq = llama_nop; + offload_func_t offload_func_v = llama_nop; + +#ifdef GGML_USE_CUBLAS + if (n_gpu_layers > n_layer) { + offload_func_nr = ggml_cuda_assign_buffers_no_alloc; + } + if (n_gpu_layers > n_layer + 1) { + offload_func_v = ggml_cuda_assign_buffers_no_alloc; + } + if (n_gpu_layers > n_layer + 2) { + offload_func_kq = ggml_cuda_assign_buffers_no_alloc; + } +#endif // GGML_USE_CUBLAS + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); ggml_allocr_alloc(lctx.alloc, KQ_scale); if (!ggml_allocr_is_measure(lctx.alloc)) { @@ -2440,21 +2489,35 @@ static struct ggml_cgraph * llm_build_falcon( for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * attn_norm; + offload_func_t offload_func = llama_nop; + +#ifdef GGML_USE_CUBLAS + if (il >= i_gpu_start) { + offload_func = ggml_cuda_assign_buffers_no_alloc; + } +#endif // GGML_USE_CUBLAS + // self-attention // TODO: refactor into common function (shared with LLaMA) { attn_norm = ggml_norm(ctx0, inpL, norm_eps); + offload_func(attn_norm); attn_norm = ggml_add(ctx0, ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm), model.layers[il].attn_norm_b); + offload_func(attn_norm->src[0]); + offload_func(attn_norm); if (model.layers[il].attn_norm_2) { // Falcon-40B cur = ggml_norm(ctx0, inpL, norm_eps); + offload_func(cur); cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].attn_norm_2), model.layers[il].attn_norm_2_b); + offload_func(cur->src[0]); + offload_func(cur); } else { // Falcon 7B cur = attn_norm; } @@ -2462,6 +2525,7 @@ static struct ggml_cgraph * llm_build_falcon( // compute QKV cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + offload_func_kq(cur); // Note that the strides for Kcur, Vcur are set up so that the // resulting views are misaligned with the tensor's storage @@ -2479,39 +2543,49 @@ static struct ggml_cgraph * llm_build_falcon( wsize * n_embd_head, wsize * n_embd_head * (n_head + 2 * n_head_kv), 0); + offload_func_kq(tmpq); struct ggml_tensor * tmpk = ggml_view_3d( ctx0, cur, n_embd_head, n_head_kv, N, wsize * n_embd_head, wsize * n_embd_head * (n_head + 2 * n_head_kv), wsize * n_embd_head * n_head); + offload_func_kq(tmpk); struct ggml_tensor * tmpv = ggml_view_3d( ctx0, cur, n_embd_head, n_head_kv, N, wsize * n_embd_head, wsize * n_embd_head * (n_head + 2 * n_head_kv), wsize * n_embd_head * (n_head + n_head_kv)); + offload_func_v(tmpv); // using mode = 2 for neox mode struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, tmpq, n_past, n_embd_head, 2, 0, freq_base, freq_scale); + offload_func_kq(Qcur); struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, tmpk, n_past, n_embd_head, 2, 0, freq_base, freq_scale); + offload_func_kq(Kcur); { struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, N)); + offload_func_v(Vcur); + offload_func_v(Vcur->src[0]->src[0]); 
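                // Sketch of the KV cache layout these views assume (illustrative comment,
                // simplified; follows the usual llama.cpp conventions of this code):
                //
                //   kv_self.k : per layer, one row of n_embd_gqa values appended per token,
                //               so the K view below can read [n_embd_head, n_past + N, n_head_kv]
                //               with a row stride of n_embd_gqa elements
                //
                //   kv_self.v : per layer, stored transposed with a stride of n_ctx per row,
                //               so the V view below is already time-contiguous and
                //               KQV = V * softmax(KQ) needs no extra ggml_cont/transpose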
ggml_set_name(Vcur, "Vcur"); struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past)); + offload_func_kq(k); ggml_set_name(k, "k"); struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa, ( n_ctx)*ggml_element_size(kv_self.v), (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v)); + offload_func_v(v); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); } struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + offload_func_kq(Q); ggml_set_name(Q, "Q"); struct ggml_tensor * K = @@ -2520,18 +2594,23 @@ static struct ggml_cgraph * llm_build_falcon( ggml_element_size(kv_self.k)*n_embd_gqa, ggml_element_size(kv_self.k)*n_embd_head, ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); + offload_func_kq(K); ggml_set_name(K, "K"); struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + offload_func_kq(KQ); ggml_set_name(KQ, "KQ"); struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + offload_func_kq(KQ_scaled); ggml_set_name(KQ_scaled, "KQ_scaled"); struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + offload_func_kq(KQ_masked); ggml_set_name(KQ_masked, "KQ_masked"); struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + offload_func_v(KQ_soft_max); ggml_set_name(KQ_soft_max, "KQ_soft_max"); struct ggml_tensor * V = @@ -2540,18 +2619,23 @@ static struct ggml_cgraph * llm_build_falcon( ggml_element_size(kv_self.v)*n_ctx, ggml_element_size(kv_self.v)*n_ctx*n_embd_head, ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); + offload_func_v(V); ggml_set_name(V, "V"); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + offload_func_v(KQV); ggml_set_name(KQV, "KQV"); struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + offload_func_v(KQV_merged); ggml_set_name(KQV_merged, "KQV_merged"); cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + offload_func_v(cur); ggml_set_name(cur, "KQV_merged_contiguous"); cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); + offload_func(cur); ggml_set_name(cur, "result_wo"); } @@ -2567,13 +2651,18 @@ static struct ggml_cgraph * llm_build_falcon( // adding this, because there seems to be a bug in the Metal concurrency optimization // without this line, the results are non-deterministic and wrong cur->src[2] = attn_out; + offload_func(cur); cur = ggml_gelu(ctx0, cur); + offload_func(cur); cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur); + offload_func(cur); } cur = ggml_add(ctx0, cur, attn_out); + offload_func(cur); cur = ggml_add(ctx0, cur, inpL); + offload_func(cur); // input for next layer inpL = cur; @@ -2584,6 +2673,7 @@ static struct ggml_cgraph * llm_build_falcon( // norm { cur = ggml_norm(ctx0, cur, norm_eps); + offload_func_nr(cur); cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.output_norm), From 176ea716b355cff5a2b6a97c7648cee138183818 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 23 Aug 2023 15:53:41 +0300 Subject: [PATCH 28/36] llama : better model naming and size reporting --- convert-falcon-hf-to-gguf.py | 2 +- convert.py | 6 ++++- llama.cpp | 50 ++++++++++++++++-------------------- 3 files changed, 28 insertions(+), 30 deletions(-) diff --git a/convert-falcon-hf-to-gguf.py b/convert-falcon-hf-to-gguf.py index 608f8c455e0df..27beeb542d038 100644 --- a/convert-falcon-hf-to-gguf.py +++ b/convert-falcon-hf-to-gguf.py @@ 
-94,7 +94,7 @@ def count_model_parts(dir_model: str) -> int: block_count = hparams["n_layer"] -gguf_writer.add_name(last_dir) +gguf_writer.add_name("Falcon") gguf_writer.add_context_length(2048) # not in config.json gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform gguf_writer.add_embedding_length(hparams["hidden_size"]) diff --git a/convert.py b/convert.py index 71978d6716183..b30657a3b270e 100644 --- a/convert.py +++ b/convert.py @@ -733,7 +733,11 @@ def __init__(self, fname_out: Path) -> None: self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH]) def add_meta_arch(self, params: Params) -> None: - self.gguf.add_name ("LLaMA") + ver = None + if (params.n_ctx == 4096): + ver = "v2" + + self.gguf.add_name ("LLaMA" if ver == None else "LLaMA " + ver) self.gguf.add_context_length (params.n_ctx) self.gguf.add_embedding_length (params.n_embd) self.gguf.add_block_count (params.n_layer) diff --git a/llama.cpp b/llama.cpp index 10cf75e0dde2c..1bdea9de37ccd 100644 --- a/llama.cpp +++ b/llama.cpp @@ -811,6 +811,7 @@ enum e_model { MODEL_7B, MODEL_13B, MODEL_30B, + MODEL_40B, MODEL_65B, MODEL_70B, }; @@ -1489,9 +1490,10 @@ static const char * llama_model_type_name(e_model type) { case MODEL_7B: return "7B"; case MODEL_13B: return "13B"; case MODEL_30B: return "30B"; + case MODEL_40B: return "40B"; case MODEL_65B: return "65B"; case MODEL_70B: return "70B"; - default: GGML_ASSERT(false); + default: return "?B"; } } @@ -1555,40 +1557,29 @@ static void llm_load_hparams( case LLM_ARCH_LLAMA: { GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS)); + + switch (hparams.n_layer) { + case 26: model.type = e_model::MODEL_3B; break; + case 32: model.type = e_model::MODEL_7B; break; + case 40: model.type = e_model::MODEL_13B; break; + case 60: model.type = e_model::MODEL_30B; break; + case 80: model.type = hparams.n_head == hparams.n_head_kv ? 
e_model::MODEL_65B : e_model::MODEL_70B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } } break; case LLM_ARCH_FALCON: { GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS)); + + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_7B; break; + case 60: model.type = e_model::MODEL_40B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } } break; default: (void)0; }; - // TODO: generalize to non-LLaMA models - switch (hparams.n_layer) { - case 26: model.type = e_model::MODEL_3B; break; - case 32: model.type = e_model::MODEL_7B; break; - case 40: model.type = e_model::MODEL_13B; break; - case 60: model.type = e_model::MODEL_30B; break; - case 80: model.type = e_model::MODEL_65B; break; - default: - { - if (hparams.n_layer < 32) { - model.type = e_model::MODEL_7B; - } - } break; - } - - // LLaMAv2 - // TODO: probably not needed - { - const auto n_gqa = hparams.n_gqa(); - - if (model.type == e_model::MODEL_65B && n_gqa == 8) { - LLAMA_LOG_WARN("%s: assuming 70B model based on GQA == %d\n", __func__, n_gqa); - model.type = e_model::MODEL_70B; - } - } - model.ftype = ml.ftype; hparams.n_ctx = n_ctx; @@ -5015,7 +5006,10 @@ int llama_model_n_embd(const struct llama_model * model) { } int llama_model_type(const struct llama_model * model, char * buf, size_t buf_size) { - return snprintf(buf, buf_size, "LLaMA %s %s", llama_model_type_name(model->type), llama_model_ftype_name(model->ftype).c_str()); + return snprintf(buf, buf_size, "%s %s %s", + model->name.c_str(), + llama_model_type_name(model->type), + llama_model_ftype_name(model->ftype).c_str()); } int llama_model_quantize( From c3f8a6e49fa3bce98084a38cff7377767c612974 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 23 Aug 2023 19:08:44 +0300 Subject: [PATCH 29/36] llama : prep new tokenizer support --- llama.cpp | 122 +++++++++++++++++++++++++++--------------------------- 1 file changed, 62 insertions(+), 60 deletions(-) diff --git a/llama.cpp b/llama.cpp index 016e291e63e6c..5e7a7600fdc17 100644 --- a/llama.cpp +++ b/llama.cpp @@ -106,6 +106,12 @@ static void llama_log_callback_default(llama_log_level level, const char * text, // helpers // +static size_t utf8_len(char src) { + const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + uint8_t highbits = static_cast(src) >> 4; + return lookup[highbits]; +} + static void zeros(std::ofstream & file, size_t n) { char zero = 0; for (size_t i = 0; i < n; ++i) { @@ -2948,47 +2954,41 @@ static std::string llama_unescape_whitespace(const std::string& word) { return word; } -static size_t utf8_len(char src) { - const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; - uint8_t highbits = static_cast(src) >> 4; - return lookup[highbits]; -} - -struct llama_sp_symbol { - using index = int; - index prev; - index next; - const char * text; - size_t n; -}; +// original implementation: +// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4 +struct llama_tokenizer { + struct sp_symbol { + using index = int; + index prev; + index next; + const char * text; + size_t n; + }; -static_assert(std::is_trivially_copyable::value, "llama_sp_symbol is not trivially copyable"); + static_assert(std::is_trivially_copyable::value, "sp_symbol is not trivially copyable"); -struct llama_sp_bigram { - struct comparator { - bool operator()(llama_sp_bigram & l, llama_sp_bigram & r) { - return (l.score < r.score) || (l.score == r.score && l.left > 
r.left); - } + struct sp_bigram { + struct comparator { + bool operator()(sp_bigram & l, sp_bigram & r) { + return (l.score < r.score) || (l.score == r.score && l.left > r.left); + } + }; + using queue_storage = std::vector; + using queue = std::priority_queue; + sp_symbol::index left; + sp_symbol::index right; + float score; + size_t size; }; - using queue_storage = std::vector; - using queue = std::priority_queue; - llama_sp_symbol::index left; - llama_sp_symbol::index right; - float score; - size_t size; -}; -// original implementation: -// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4 -struct llama_tokenizer { - llama_tokenizer(const llama_vocab & vocab): vocab_(vocab) {} + llama_tokenizer(const llama_vocab & vocab): vocab(vocab) {} void tokenize(const std::string & text, std::vector & output) { // split string into utf8 chars int index = 0; size_t offs = 0; while (offs < text.size()) { - llama_sp_symbol sym; + sp_symbol sym; size_t len = utf8_len(text[offs]); GGML_ASSERT(offs + len <= text.size()); sym.text = text.c_str() + offs; @@ -2997,21 +2997,21 @@ struct llama_tokenizer { sym.prev = index - 1; sym.next = offs == text.size() ? -1 : index + 1; index++; - symbols_.emplace_back(sym); + symbols.emplace_back(sym); } // seed the work queue with all possible 2-character tokens. - for (size_t i = 1; i < symbols_.size(); ++i) { + for (size_t i = 1; i < symbols.size(); ++i) { try_add_bigram(i - 1, i); } // keep substituting the highest frequency pairs for as long as we can. - while (!work_queue_.empty()) { - auto bigram = work_queue_.top(); - work_queue_.pop(); + while (!work_queue.empty()) { + auto bigram = work_queue.top(); + work_queue.pop(); - auto & left_sym = symbols_[bigram.left]; - auto & right_sym = symbols_[bigram.right]; + auto & left_sym = symbols[bigram.left]; + auto & right_sym = symbols[bigram.right]; // if one of the symbols already got merged, skip it. if (left_sym.n == 0 || right_sym.n == 0 || @@ -3028,7 +3028,7 @@ struct llama_tokenizer { // remove the right sym from the chain left_sym.next = right_sym.next; if (right_sym.next >= 0) { - symbols_[right_sym.next].prev = bigram.left; + symbols[right_sym.next].prev = bigram.left; } // find more substitutions @@ -3036,19 +3036,19 @@ struct llama_tokenizer { try_add_bigram(bigram.left, left_sym.next); } - for (int i = 0; i != -1; i = symbols_[i].next) { - auto & symbol = symbols_[i]; + for (int i = 0; i != -1; i = symbols[i].next) { + auto & symbol = symbols[i]; resegment(symbol, output); } } private: - void resegment(llama_sp_symbol &symbol, std::vector &output) { + void resegment(sp_symbol & symbol, std::vector & output) { auto text = std::string(symbol.text, symbol.n); - auto token = vocab_.token_to_id.find(text); + auto token = vocab.token_to_id.find(text); // Do we need to support is_unused? - if (token != vocab_.token_to_id.end()) { + if (token != vocab.token_to_id.end()) { output.push_back((*token).second); return; } @@ -3058,14 +3058,14 @@ struct llama_tokenizer { if (p == rev_merge.end()) { // output any symbols that did not form tokens as bytes. 
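// For reference: the refactored tokenizer's first pass splits the raw text into UTF-8
// characters via the utf8_len lookup before any bigram merging happens. A minimal
// standalone sketch of that splitting pass (illustrative only, not part of the patch):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

// number of bytes in a UTF-8 sequence, keyed by the high nibble of the lead byte;
// continuation bytes (0x8..0xB) fall back to 1 so malformed input cannot stall the loop
static size_t utf8_len(char src) {
    const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
    return lookup[static_cast<uint8_t>(src) >> 4];
}

int main() {
    const std::string text = "héllo"; // 'é' occupies two bytes
    std::vector<std::string> symbols;
    for (size_t offs = 0; offs < text.size(); ) {
        const size_t len = std::min(utf8_len(text[offs]), text.size() - offs);
        symbols.emplace_back(text.substr(offs, len));
        offs += len;
    }
    for (const auto & s : symbols) {
        printf("'%s' (%zu byte(s))\n", s.c_str(), s.size()); // 5 symbols for 6 bytes of input
    }
    return 0;
}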
for (int j = 0; j < (int)symbol.n; ++j) { - llama_vocab::id token_id = llama_byte_to_token(vocab_, symbol.text[j]); + llama_vocab::id token_id = llama_byte_to_token(vocab, symbol.text[j]); output.push_back(token_id); } return; } - resegment(symbols_[p->second.first], output); - resegment(symbols_[p->second.second], output); + resegment(symbols[p->second.first], output); + resegment(symbols[p->second.second], output); } void try_add_bigram(int left, int right) { @@ -3073,34 +3073,36 @@ struct llama_tokenizer { return; } - const std::string text = std::string(symbols_[left].text, symbols_[left].n + symbols_[right].n); - auto token = vocab_.token_to_id.find(text); + const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n); + auto token = vocab.token_to_id.find(text); - if (token == vocab_.token_to_id.end()) { + if (token == vocab.token_to_id.end()) { return; } - if (static_cast((*token).second) >= vocab_.id_to_token.size()) { + if (static_cast((*token).second) >= vocab.id_to_token.size()) { return; } - const auto &tok_data = vocab_.id_to_token[(*token).second]; + const auto & tok_data = vocab.id_to_token[(*token).second]; - llama_sp_bigram bigram; - bigram.left = left; + sp_bigram bigram; + bigram.left = left; bigram.right = right; bigram.score = tok_data.score; - bigram.size = text.size(); - work_queue_.push(bigram); + bigram.size = text.size(); + + work_queue.push(bigram); // Do we need to support is_unused? rev_merge[text] = std::make_pair(left, right); } - const llama_vocab & vocab_; - std::vector symbols_; - llama_sp_bigram::queue work_queue_; - std::map > rev_merge; + const llama_vocab & vocab; + + std::vector symbols; + sp_bigram::queue work_queue; + std::map> rev_merge; }; static std::vector llama_tokenize_internal(const llama_vocab & vocab, const std::string & raw_text, bool bos, bool escape) { From 3bfb72064283914e68dbac8e80ff0411b790be7b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 23 Aug 2023 20:11:45 +0300 Subject: [PATCH 30/36] llama : advanced BPE tokenizer based on ggllm.cpp imlpementation --- examples/main/main.cpp | 14 +- examples/perplexity/perplexity.cpp | 27 ++- llama.cpp | 340 +++++++++++++++++++++++++---- llama.h | 2 + 4 files changed, 326 insertions(+), 57 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 388e1f7d7fe02..f271303c1e98c 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -43,7 +43,7 @@ static bool is_interacting = false; void sigint_handler(int signo) { if (signo == SIGINT) { if (!is_interacting) { - is_interacting=true; + is_interacting = true; } else { console::cleanup(); printf("\n"); @@ -189,10 +189,12 @@ int main(int argc, char ** argv) { } } + const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM; + // tokenize the prompt std::vector embd_inp; if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) { - embd_inp = ::llama_tokenize(ctx, params.prompt, true); + embd_inp = ::llama_tokenize(ctx, params.prompt, is_spm); } else { embd_inp = session_tokens; } @@ -203,9 +205,9 @@ int main(int argc, char ** argv) { int original_prompt_len = 0; if (ctx_guidance) { params.cfg_negative_prompt.insert(0, 1, ' '); - guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, true); + guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, is_spm); - std::vector original_inp = ::llama_tokenize(ctx, params.prompt, true); + std::vector original_inp = ::llama_tokenize(ctx, params.prompt, 
is_spm); original_prompt_len = original_inp.size(); guidance_offset = (int)guidance_inp.size() - original_prompt_len; } @@ -252,8 +254,8 @@ int main(int argc, char ** argv) { } // prefix & suffix for instruct mode - const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true); - const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false); + const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", is_spm); + const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false); // in instruct mode, we inject a prefix and a suffix to each input by the user if (params.instruct) { diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index e89725efc3db6..9a39529d50875 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -28,7 +28,6 @@ std::vector softmax(const std::vector& logits) { } void perplexity_v2(llama_context * ctx, const gpt_params & params) { - // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw` // Output: `perplexity: 13.5106 [114/114]` @@ -38,7 +37,11 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) { fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride); return; } - auto tokens = ::llama_tokenize(ctx, params.prompt, true); + + const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM; + const bool add_bos = is_spm; + + auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos); const int calc_chunk = params.n_ctx; @@ -86,7 +89,7 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) { const auto token_org = tokens[batch_start]; // add BOS token for the first batch of each chunk - if (j == 0) { + if (add_bos && j == 0) { tokens[batch_start] = llama_token_bos(ctx); } @@ -136,7 +139,6 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) { } void perplexity(llama_context * ctx, const gpt_params & params) { - if (params.ppl_stride > 0) { perplexity_v2(ctx, params); return; @@ -146,7 +148,11 @@ void perplexity(llama_context * ctx, const gpt_params & params) { // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw` // Output: `perplexity: 13.5106 [114/114]` // BOS tokens will be added for each chunk before eval - auto tokens = ::llama_tokenize(ctx, params.prompt, true); + + const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM; + const bool add_bos = is_spm; + + auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos); const int n_chunk_max = tokens.size() / params.n_ctx; @@ -177,7 +183,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) { const auto token_org = tokens[batch_start]; // add BOS token for the first batch of each chunk - if (j == 0) { + if (add_bos && j == 0) { tokens[batch_start] = llama_token_bos(ctx); } @@ -295,8 +301,10 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) { size_t hs_task_count = prompt_lines.size()/6; fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count); + const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM; + // This is needed as usual for LLaMA models - bool prepend_bos = true; + const bool add_bos = is_spm; // Number of tasks to use when computing the score if ( params.hellaswag_tasks < hs_task_count ) { @@ -352,14 +360,13 @@ void hellaswag_score(llama_context * ctx, const gpt_params & 
params) { std::vector tok_logits(n_vocab); for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) { - // Tokenize the context to count tokens - std::vector context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, prepend_bos); + std::vector context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, add_bos); size_t context_size = context_embd.size(); // Do the 1st ending // In this case we include the context when evaluating - auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], prepend_bos); + auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], add_bos); auto query_size = query_embd.size(); //printf("First query: %d\n",(int)query_size); diff --git a/llama.cpp b/llama.cpp index 5e7a7600fdc17..11dad7c8a0140 100644 --- a/llama.cpp +++ b/llama.cpp @@ -72,6 +72,7 @@ #include #include #include +#include #include #include #include @@ -112,6 +113,15 @@ static size_t utf8_len(char src) { return lookup[highbits]; } +void replace_all(std::string & s, const std::string & search, const std::string & replace) { + for (size_t pos = 0; ; pos += replace.length()) { + pos = s.find(search, pos); + if (pos == std::string::npos) break; + s.erase(pos, search.length()); + s.insert(pos, replace); + } +} + static void zeros(std::ofstream & file, size_t n) { char zero = 0; for (size_t i = 0; i < n; ++i) { @@ -929,11 +939,13 @@ struct llama_vocab { ttype type; }; - llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; + enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; std::unordered_map token_to_id; std::vector id_to_token; + std::map, int> bpe_ranks; + // default LLaMA special tokens id special_bos_id = 1; id special_eos_id = 2; @@ -942,6 +954,20 @@ struct llama_vocab { id special_pad_id = -1; id linefeed_id = 13; + + int find_bpe_rank(std::string token_left, std::string token_right) const { + replace_all(token_left, " ", "Ġ"); + replace_all(token_left, "\n", "Ċ"); + replace_all(token_right, " ", "Ġ"); + replace_all(token_right, "\n", "Ċ"); + + auto it = bpe_ranks.find(std::make_pair(token_left, token_right)); + if (it == bpe_ranks.end()) { + return -1; + } + + return it->second; + } }; struct llama_model { @@ -1634,6 +1660,30 @@ static void llm_load_vocab( vocab.type = LLAMA_VOCAB_TYPE_SPM; } else if (tokenizer_name == "gpt2") { vocab.type = LLAMA_VOCAB_TYPE_BPE; + + const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); + if (merges_keyidx == -1) { + throw std::runtime_error("cannot find tokenizer merges in model file\n"); + } + + const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); + + for (int i = 0; i < n_merges; i++) { + const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); + + std::string first; + std::string second; + + const size_t pos = word.find(' ', 1); + + if (pos != std::string::npos) { + first = word.substr(0, pos); + second = word.substr(pos + 1); + } + + // populate bpe ranks + vocab.bpe_ranks.emplace(std::make_pair(first, second), i); + } } else { LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str()); LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__); @@ -1654,7 +1704,6 @@ static void llm_load_vocab( token_data.text = std::move(word); token_data.score = scores[i]; token_data.type = (llama_token_type) toktypes[i]; - } // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' @@ -1677,6 +1726,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { 
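The merges read above are exactly what find_bpe_rank() consults at tokenization time, after normalizing spaces and newlines to the GPT-2 byte-level markers "Ġ" and "Ċ". A small self-contained sketch of that lookup with a toy merge list (illustrative only, not the patch's code):

#include <cstdio>
#include <map>
#include <string>
#include <utility>
#include <vector>

// replace every occurrence of `search` in `s` with `replace`
static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
    for (size_t pos = 0; (pos = s.find(search, pos)) != std::string::npos; pos += replace.length()) {
        s.replace(pos, search.length(), replace);
    }
}

int main() {
    // toy merge list: the line index in the merge list is the rank (lower merges first)
    const std::vector<std::pair<std::string, std::string>> merges = {
        { "Ġ",  "t"  }, // rank 0
        { "Ġt", "he" }, // rank 1
        { "h",  "e"  }, // rank 2
    };

    std::map<std::pair<std::string, std::string>, int> bpe_ranks;
    for (size_t i = 0; i < merges.size(); ++i) {
        bpe_ranks.emplace(merges[i], (int) i);
    }

    // rank lookup for a candidate pair, normalizing whitespace the same way as find_bpe_rank
    std::string left  = " t"; // leading space becomes "Ġ"
    std::string right = "he";
    for (std::string * tok : { &left, &right }) {
        replace_all(*tok, " ",  "Ġ");
        replace_all(*tok, "\n", "Ċ");
    }

    const auto it = bpe_ranks.find(std::make_pair(left, right));
    printf("rank = %d\n", it == bpe_ranks.end() ? -1 : it->second); // prints: rank = 1
    return 0;
}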
LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str()); LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab); + LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) vocab.bpe_ranks.size()); LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, hparams.n_ctx); LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); @@ -2954,41 +3004,43 @@ static std::string llama_unescape_whitespace(const std::string& word) { return word; } +struct llm_symbol { + using index = int; + index prev; + index next; + const char * text; + size_t n; +}; + +static_assert(std::is_trivially_copyable::value, "llm_symbol is not trivially copyable"); + +// SPM tokenizer // original implementation: // https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4 -struct llama_tokenizer { - struct sp_symbol { - using index = int; - index prev; - index next; - const char * text; - size_t n; - }; - static_assert(std::is_trivially_copyable::value, "sp_symbol is not trivially copyable"); - - struct sp_bigram { - struct comparator { - bool operator()(sp_bigram & l, sp_bigram & r) { - return (l.score < r.score) || (l.score == r.score && l.left > r.left); - } - }; - using queue_storage = std::vector; - using queue = std::priority_queue; - sp_symbol::index left; - sp_symbol::index right; - float score; - size_t size; +struct llm_bigram_spm { + struct comparator { + bool operator()(llm_bigram_spm & l, llm_bigram_spm & r) { + return (l.score < r.score) || (l.score == r.score && l.left > r.left); + } }; + using queue_storage = std::vector; + using queue = std::priority_queue; + llm_symbol::index left; + llm_symbol::index right; + float score; + size_t size; +}; - llama_tokenizer(const llama_vocab & vocab): vocab(vocab) {} +struct llm_tokenizer_spm { + llm_tokenizer_spm(const llama_vocab & vocab): vocab(vocab) {} void tokenize(const std::string & text, std::vector & output) { // split string into utf8 chars int index = 0; size_t offs = 0; while (offs < text.size()) { - sp_symbol sym; + llm_symbol sym; size_t len = utf8_len(text[offs]); GGML_ASSERT(offs + len <= text.size()); sym.text = text.c_str() + offs; @@ -3043,7 +3095,7 @@ struct llama_tokenizer { } private: - void resegment(sp_symbol & symbol, std::vector & output) { + void resegment(llm_symbol & symbol, std::vector & output) { auto text = std::string(symbol.text, symbol.n); auto token = vocab.token_to_id.find(text); @@ -3086,7 +3138,7 @@ struct llama_tokenizer { const auto & tok_data = vocab.id_to_token[(*token).second]; - sp_bigram bigram; + llm_bigram_spm bigram; bigram.left = left; bigram.right = right; bigram.score = tok_data.score; @@ -3100,31 +3152,233 @@ struct llama_tokenizer { const llama_vocab & vocab; - std::vector symbols; - sp_bigram::queue work_queue; + std::vector symbols; + llm_bigram_spm::queue work_queue; + std::map> rev_merge; }; +// BPE tokenizer +// adapted from https://github.com/cmp-nct/ggllm.cpp [MIT License] +// tried to simplify unicode stuff, so most likely does not work 100% correctly! 
+ +// TODO: there are a lot of common parts between spm and bpe tokenizers, should be refactored and reused + +struct llm_bigram_bpe { + struct comparator { + bool operator()(llm_bigram_bpe & l, llm_bigram_bpe & r) { + return l.rank > r.rank || (l.rank == r.rank && l.left > r.left); + } + }; + + using queue_storage = std::vector; + using queue = std::priority_queue; + llm_symbol::index left; + llm_symbol::index right; + std::string text; + int rank; + size_t size; +}; + +struct llm_tokenizer_bpe { + llm_tokenizer_bpe(const llama_vocab & vocab, bool g2ws): vocab(vocab) { flag_g2ws = g2ws; } + + void tokenize(const std::string & text, std::vector & output) { + int final_prev_index = -1; + auto word_collection = bpe_gpt2_preprocess(text); + + symbols_final.clear(); + + for (auto & word : word_collection) { + work_queue = llm_bigram_bpe::queue(); + symbols.clear(); + + int index = 0; + size_t offset = 0; + + while (offset < word.size()) { + llm_symbol sym; + size_t char_len = std::min(word.size() - offset, (size_t) ::utf8_len(word[offset])); + sym.text = word.c_str() + offset; + sym.n = 1; + sym.n = char_len; + offset += sym.n; + sym.prev = index - 1; + sym.next = offset == word.size() ? -1 : index + 1; + index++; + symbols.emplace_back(sym); + } + for (size_t i = 1; i < symbols.size(); ++i) { + add_new_bigram(i - 1, i); + } + + // build token(s) + while (!work_queue.empty()) { + auto bigram = work_queue.top(); + work_queue.pop(); + + auto & left_symbol = symbols[bigram.left]; + auto & right_symbol = symbols[bigram.right]; + + if (left_symbol.n == 0 || right_symbol.n == 0) { + continue; + } + std::string left_token = std::string(left_symbol.text, left_symbol.n); + std::string right_token = std::string(right_symbol.text, right_symbol.n); + if (left_token + right_token != bigram.text) { + continue; // Skip this bigram if it's outdated + } + + // merge the right sym into the left one + left_symbol.n += right_symbol.n; + right_symbol.n = 0; + + // remove the right sym from the chain + left_symbol.next = right_symbol.next; + if (right_symbol.next >= 0) { + symbols[right_symbol.next].prev = bigram.left; + } + + add_new_bigram(left_symbol.prev, bigram.left); // left side of current symbol + add_new_bigram(bigram.left, left_symbol.next); // right side of current symbol + } + + // add the fnished tokens to the final list keeping correct order for next and prev + for (auto & sym : symbols) { + if (sym.n > 0) { + sym.prev = final_prev_index; + sym.next = -1; + if (final_prev_index != -1) { + symbols_final[final_prev_index].next = symbols_final.size(); + } + symbols_final.emplace_back(sym); + final_prev_index = symbols_final.size() - 1; + } + } + } + + symbols = symbols_final; + + if (!symbols.empty()) { + for (int i = 0; i != -1; i = symbols[i].next) { + auto & symbol = symbols[i]; + if (symbol.n == 0) { + continue; + } + + const std::string str = std::string(symbol.text, symbol.n); + const auto token = vocab.token_to_id.find(str); + + if (token == vocab.token_to_id.end()) { + for (auto j = str.begin(); j != str.end(); ++j) { + std::string byte_str(1, *j); + auto token_multibyte = vocab.token_to_id.find(byte_str); + if (token_multibyte == vocab.token_to_id.end()) { + fprintf(stderr,"ERROR: byte not found in vocab: '%s'\n", byte_str.c_str()); + } + output.push_back((*token_multibyte).second); + } + } else { + output.push_back((*token).second); + } + } + } + } + +private: + void add_new_bigram(int left, int right) { + if (left == -1 || right == -1) { + return; + } + + std::string left_token = 
std::string(symbols[left].text, symbols[left].n); + std::string right_token = std::string(symbols[right].text, symbols[right].n); + + int rank_found = -1; + + rank_found = vocab.find_bpe_rank(left_token, right_token); + + if (rank_found < 0) { + return; + } + + llm_bigram_bpe bigram; + + bigram.left = left; + bigram.right = right; + bigram.text = left_token + right_token; + bigram.size = left_token.size() + right_token.size(); + bigram.rank = rank_found; + + work_queue.push(bigram); + } + + // probably not 100% correct + static std::vector bpe_gpt2_preprocess(std::string text) { + std::vector words; + + // ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53 + const std::string pattern = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"; + const std::regex re(pattern); + std::smatch m; + + while (std::regex_search(text, m, re)) { + for (auto x : m) { + words.push_back(x); + } + text = m.suffix(); + } + + return words; + } + + bool flag_g2ws = false; + + const llama_vocab & vocab; + + std::vector symbols; + std::vector symbols_final; + + llm_bigram_bpe::queue work_queue; +}; + static std::vector llama_tokenize_internal(const llama_vocab & vocab, const std::string & raw_text, bool bos, bool escape) { - llama_tokenizer tokenizer(vocab); std::vector output; if (raw_text.empty()) { return output; } - if (bos) { - output.push_back(vocab.special_bos_id); - } + switch (vocab.type) { + case LLAMA_VOCAB_TYPE_SPM: + { + llm_tokenizer_spm tokenizer(vocab); - std::string text; - if (escape) { - text = llama_escape_whitespace(raw_text); - } else { - text = raw_text; - } + if (bos) { + output.push_back(vocab.special_bos_id); + } + + std::string text; + if (escape) { + text = llama_escape_whitespace(raw_text); + } else { + text = raw_text; + } + + tokenizer.tokenize(text, output); + } break; + case LLAMA_VOCAB_TYPE_BPE: + { + llm_tokenizer_bpe tokenizer(vocab, escape); + + if (bos && vocab.special_bos_id != -1) { + output.push_back(vocab.special_bos_id); + } + + tokenizer.tokenize(raw_text, output); + } break; + }; - tokenizer.tokenize(text, output); return output; } @@ -4988,6 +5242,10 @@ int llama_n_embd(const struct llama_context * ctx) { return ctx->model.hparams.n_embd; } +enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx) { + return ctx->model.vocab.type; +} + int llama_model_n_vocab(const struct llama_model * model) { return model->vocab.id_to_token.size(); } diff --git a/llama.h b/llama.h index 7ce478d5452a7..608e41bf431d3 100644 --- a/llama.h +++ b/llama.h @@ -247,6 +247,8 @@ extern "C" { LLAMA_API int llama_n_ctx (const struct llama_context * ctx); LLAMA_API int llama_n_embd (const struct llama_context * ctx); + LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx); + LLAMA_API int llama_model_n_vocab(const struct llama_model * model); LLAMA_API int llama_model_n_ctx (const struct llama_model * model); LLAMA_API int llama_model_n_embd (const struct llama_model * model); From 2424e1d08ed8171a51d4e7ead52b3943824e3a80 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 23 Aug 2023 20:16:40 +0300 Subject: [PATCH 31/36] llama : remove oboslete comment ggml-ci --- llama.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/llama.cpp b/llama.cpp index 11dad7c8a0140..9d3ae4e97a2da 100644 --- a/llama.cpp +++ b/llama.cpp @@ -925,10 +925,6 @@ struct llama_kv_cache { }; struct llama_vocab { - // TODO: - // - add a vector of merges - // so that we can 
pass it to different types of tokenizers with a common interface - using id = int32_t; using token = std::string; using ttype = llama_token_type; From 596e1094fbe4d0b6040cbcebc1efb327d535cad1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 23 Aug 2023 20:31:03 +0300 Subject: [PATCH 32/36] common : remove obsolete BPE API + disable test-tokenizer-1 --- common/common.cpp | 32 -------------------------------- common/common.h | 9 --------- llama.cpp | 32 -------------------------------- llama.h | 13 ------------- tests/CMakeLists.txt | 3 ++- tests/test-tokenizer-1.cpp | 16 ++++------------ 6 files changed, 6 insertions(+), 99 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 88a962ae385de..53002ba306b57 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -744,35 +744,3 @@ std::string llama_token_to_str(const struct llama_context * ctx, llama_token tok return std::string(result.data(), result.size()); } - -std::vector llama_tokenize_bpe( - struct llama_context * ctx, - const std::string & text, - bool add_bos) { - int n_tokens = text.length() + add_bos; - std::vector result(n_tokens); - n_tokens = llama_tokenize_bpe(ctx, text.c_str(), result.data(), result.size(), add_bos); - if (n_tokens < 0) { - result.resize(-n_tokens); - int check = llama_tokenize_bpe(ctx, text.c_str(), result.data(), result.size(), add_bos); - GGML_ASSERT(check == -n_tokens); - } else { - result.resize(n_tokens); - } - return result; -} - -std::string llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token) { - std::vector result(8, 0); - const int n_tokens = llama_token_to_str_bpe(ctx, token, result.data(), result.size()); - if (n_tokens < 0) { - result.resize(-n_tokens); - const int check = llama_token_to_str_bpe(ctx, token, result.data(), result.size()); - GGML_ASSERT(check == -n_tokens); - } else { - result.resize(n_tokens); - } - - return std::string(result.data(), result.size()); -} - diff --git a/common/common.h b/common/common.h index d68a8ef88c97c..17d271e6750e2 100644 --- a/common/common.h +++ b/common/common.h @@ -120,15 +120,6 @@ std::vector llama_tokenize( const std::string & text, bool add_bos); -std::vector llama_tokenize_bpe( - struct llama_context * ctx, - const std::string & text, - bool add_bos); - std::string llama_token_to_str( const struct llama_context * ctx, llama_token token); - -std::string llama_token_to_str_bpe( - const struct llama_context * ctx, - llama_token token); diff --git a/llama.cpp b/llama.cpp index 9d3ae4e97a2da..58b7b70d32cdd 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5779,26 +5779,6 @@ int llama_tokenize( return llama_tokenize_with_model(&ctx->model, text, tokens, n_max_tokens, add_bos); } -int llama_tokenize_bpe( - struct llama_context * ctx, - const char * text, - llama_token * tokens, - int n_max_tokens, - bool add_bos) { - auto res = llama_tokenize_internal(ctx->model.vocab, text, add_bos, false); - - if (n_max_tokens < (int) res.size()) { - LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); - return -((int) res.size()); - } - - for (size_t i = 0; i < res.size(); i++) { - tokens[i] = res[i]; - } - - return res.size(); -} - int llama_tokenize_with_model( const struct llama_model * model, const char * text, @@ -5824,18 +5804,6 @@ int llama_token_to_str(const struct llama_context * ctx, llama_token token, char return llama_token_to_str_with_model(&ctx->model, token, buf, length); } -int llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token, char * buf, int length) { - if (0 <= token && token < 
llama_model_n_vocab(&ctx->model)) { - std::string result = ctx->model.vocab.id_to_token[token].text; - if (length < (int) result.length()) { - return -result.length(); - } - memcpy(buf, result.c_str(), result.length()); - return result.length(); - } - return 0; -} - // does not write null-terminator to str int llama_token_to_str_with_model(const struct llama_model * model, llama_token token, char * buf, int length) { if (0 <= token && token < llama_model_n_vocab(model)) { diff --git a/llama.h b/llama.h index 608e41bf431d3..4e7638c042de9 100644 --- a/llama.h +++ b/llama.h @@ -370,13 +370,6 @@ extern "C" { int n_max_tokens, bool add_bos); - LLAMA_API int llama_tokenize_bpe( - struct llama_context * ctx, - const char * text, - llama_token * tokens, - int n_max_tokens, - bool add_bos); - LLAMA_API int llama_tokenize_with_model( const struct llama_model * model, const char * text, @@ -392,12 +385,6 @@ extern "C" { char * buf, int length); - LLAMA_API int llama_token_to_str_bpe( - const struct llama_context * ctx, - llama_token token, - char * buf, - int length); - LLAMA_API int llama_token_to_str_with_model( const struct llama_model * model, llama_token token, diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 4ccefe9322322..2afaf86b11450 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -28,7 +28,8 @@ llama_build_and_test_executable(test-sampling.cpp) llama_build_executable(test-tokenizer-0.cpp) llama_test_executable (test-tokenizer-0.llama test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf) llama_build_executable(test-tokenizer-1.cpp) -llama_test_executable (test-tokenizer-1.llama test-tokenizer-1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf) +# test-tokenizer-1 requires a BPE vocab. re-enable when we have one. 
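With the *_bpe entry points gone, the updated test-tokenizer-1 (below) goes through the generic API instead: for every vocab id it converts the token to text and tokenizes it back, expecting a single-token result with the same id. A rough sketch of that round trip using the common helpers (assumes an already loaded llama_context with a BPE vocab):

#include <cstdio>
#include <string>
#include <vector>

#include "common.h"
#include "llama.h"

// returns 0 if every single-token piece round-trips to the same id
static int check_round_trip(llama_context * ctx) {
    const int n_vocab = llama_n_vocab(ctx);
    for (int i = 0; i < n_vocab; ++i) {
        const std::string piece = llama_token_to_str(ctx, i);                      // id -> text
        const std::vector<llama_token> toks = ::llama_tokenize(ctx, piece, false); // text -> ids, no BOS
        if (toks.size() == 1 && toks[0] != i) {
            fprintf(stderr, "round trip failed for token %d ('%s')\n", i, piece.c_str());
            return 2;
        }
    }
    return 0;
}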
+#llama_test_executable (test-tokenizer-1.llama test-tokenizer-1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf) #llama_test_executable(test-tokenizer-1.aquila test-tokenizer-1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf) llama_build_and_test_executable(test-grammar-parser.cpp) llama_build_and_test_executable(test-llama-grammar.cpp) diff --git a/tests/test-tokenizer-1.cpp b/tests/test-tokenizer-1.cpp index 993d17f1833d3..bd607d12bb1cd 100644 --- a/tests/test-tokenizer-1.cpp +++ b/tests/test-tokenizer-1.cpp @@ -67,11 +67,13 @@ int main(int argc, char **argv) { } } + GGML_ASSERT(llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_BPE); + const int n_vocab = llama_n_vocab(ctx); for (int i = 0; i < n_vocab; ++i) { - std::string forward = llama_token_to_str_bpe(ctx, i); - std::vector tokens = llama_tokenize_bpe(ctx, forward, false); + std::string forward = llama_token_to_str(ctx, i); + std::vector tokens = llama_tokenize(ctx, forward, false); if (tokens.size() == 1) { if (i != tokens[0]) { std::string backward = llama_token_to_str(ctx, tokens[0]); @@ -79,16 +81,6 @@ int main(int argc, char **argv) { __func__, i, llama_token_to_str(ctx, i).c_str(), tokens[0], backward.c_str()); return 2; } - } else { - llama_token_type type = llama_token_get_type(ctx, i); - if (type == LLAMA_TOKEN_TYPE_UNKNOWN || type == LLAMA_TOKEN_TYPE_CONTROL || type == LLAMA_TOKEN_TYPE_BYTE) { - fprintf(stderr, "%s : info: token %d is string %s and bpe returns tokens %s\n", - __func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens).c_str()); - } else { - fprintf(stderr, "%s : error: token %d is string %s but bpe returns tokens %s\n", - __func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens).c_str()); - return 2; - } } } From f8ee54bd2c09aad29bcfd3b619e66e5979717836 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 23 Aug 2023 20:39:24 +0300 Subject: [PATCH 33/36] llama : revert BPE special-case in llama_byte_to_token() --- llama.cpp | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/llama.cpp b/llama.cpp index 58b7b70d32cdd..d3adda11bcb3a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2971,14 +2971,10 @@ static uint8_t llama_token_to_byte(const llama_vocab & vocab, llama_token id) { } static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) { - if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { - char buf[7]; - int result = snprintf(buf, sizeof(buf), "<0x%02X>", ch); - GGML_ASSERT(0 <= result && result < 7); - return vocab.token_to_id.at(buf); - } - // vocab.type == LLAMA_VOCAB_TYPE_BPE - return vocab.token_to_id.at(std::string(1, ch)); + char buf[7]; + int result = snprintf(buf, sizeof(buf), "<0x%02X>", ch); + GGML_ASSERT(0 <= result && result < 7); + return vocab.token_to_id.at(buf); } static std::string llama_escape_whitespace(const std::string& text) { From 8c6d3939c793329850a08c141e612132c74012c6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 23 Aug 2023 21:32:12 +0300 Subject: [PATCH 34/36] cuda : add TODOs for RoPE NeoX implementation --- ggml-cuda.cu | 29 ++++++++++++++++++++++++++++- ggml.c | 4 ++-- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 70a950bb58b9b..868b7a7b905a2 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -3907,6 +3907,29 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c dst[i + 1] = x0*sin_theta + x1*cos_theta; } +// TODO: this implementation is wrong! 
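The NeoX ("rotate half") RoPE variant pairs feature i with feature i + n_dims/2 instead of rotating adjacent (2i, 2i+1) pairs like the interleaved kernel, which is why the draft below stays disabled. A hedged CPU reference for a single row under that split-half assumption (not the patch's code, and not validated against ggml):

#include <cmath>
#include <cstdio>
#include <vector>

// NeoX-style RoPE on one row of length n_dims at position p (assumption: split-half pairing)
static void rope_neox_ref(std::vector<float> & x, int n_dims, float p, float freq_base) {
    const float theta_scale = powf(freq_base, -2.0f / n_dims);

    float theta = p; // theta_i = p * freq_base^(-2*i/n_dims)
    for (int i = 0; i < n_dims / 2; ++i) {
        const float cos_theta = cosf(theta);
        const float sin_theta = sinf(theta);

        const float x0 = x[i];
        const float x1 = x[i + n_dims / 2];

        x[i]              = x0 * cos_theta - x1 * sin_theta;
        x[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;

        theta *= theta_scale;
    }
}

int main() {
    std::vector<float> row = { 1.0f, 0.0f, 0.0f, 1.0f }; // n_dims = 4
    rope_neox_ref(row, 4, /*p =*/ 1.0f, /*freq_base =*/ 10000.0f);
    for (float v : row) {
        printf("%.4f ", v); // 0.5403 -0.0100 0.8415 1.0000
    }
    printf("\n");
    return 0;
}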
+//static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0, +// const float p_delta, const int p_delta_rows, const float theta_scale) { +// const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); +// +// if (col >= ncols) { +// return; +// } +// +// const int row = blockDim.x*blockIdx.x + threadIdx.x; +// const int i = row*ncols + col/2; +// +// const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2); +// const float sin_theta = sinf(theta); +// const float cos_theta = cosf(theta); +// +// const float x0 = x[i + 0]; +// const float x1 = x[i + ncols/2]; +// +// dst[i + 0] = x0*cos_theta - x1*sin_theta; +// dst[i + ncols/2] = x0*sin_theta + x1*cos_theta; +//} + static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) { const int col = blockDim.x*blockIdx.x + threadIdx.x; const int half_n_dims = ncols/4; @@ -5515,7 +5538,8 @@ inline void ggml_cuda_op_rope( const float theta_scale = powf(freq_base, -2.0f/n_dims); - const bool is_glm = mode & 4; + const bool is_neox = mode & 2; + const bool is_glm = mode & 4; // compute if (is_glm) { @@ -5523,6 +5547,9 @@ inline void ggml_cuda_op_rope( const float id_p = min(p, n_ctx - 2.f); const float block_p = max(p - (n_ctx - 2.f), 0.f); rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main); + } else if (is_neox) { + GGML_ASSERT(false && "RoPE NeoX not implemented yet"); +#pragma message("TODO: implement RoPE NeoX for CUDA") } else { const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale; rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main); diff --git a/ggml.c b/ggml.c index 6839d03213559..8cb5c404f285d 100644 --- a/ggml.c +++ b/ggml.c @@ -12537,7 +12537,7 @@ static void ggml_compute_forward_rope_f32( dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta; } } else { - // TODO: this is probably wrong, but I can't figure it out .. + // TODO: this might be wrong for ne0 != n_dims - need double check // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { for (int64_t ic = 0; ic < n_dims; ic += 2) { @@ -12666,7 +12666,7 @@ static void ggml_compute_forward_rope_f16( dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); } } else { - // TODO: this is probably wrong, but I can't figure it out .. 
+ // TODO: this might be wrong for ne0 != n_dims - need double check // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { for (int64_t ic = 0; ic < n_dims; ic += 2) { From 630d8b408adc3fff6ed796ac26e70a2725f4fb3c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 23 Aug 2023 21:39:09 +0300 Subject: [PATCH 35/36] llama : default special tokens based on vocab type --- llama.cpp | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index d3adda11bcb3a..bdf2437015b52 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1654,9 +1654,17 @@ static void llm_load_vocab( if (tokenizer_name == "llama") { vocab.type = LLAMA_VOCAB_TYPE_SPM; + + // default special tokens + vocab.special_bos_id = 1; + vocab.special_eos_id = 2; + vocab.special_unk_id = 0; + vocab.special_sep_id = -1; + vocab.special_pad_id = -1; } else if (tokenizer_name == "gpt2") { vocab.type = LLAMA_VOCAB_TYPE_BPE; + // read bpe merges and populate bpe ranks const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); if (merges_keyidx == -1) { throw std::runtime_error("cannot find tokenizer merges in model file\n"); @@ -1677,12 +1685,19 @@ static void llm_load_vocab( second = word.substr(pos + 1); } - // populate bpe ranks vocab.bpe_ranks.emplace(std::make_pair(first, second), i); } + + // default special tokens + vocab.special_bos_id = 11; + vocab.special_eos_id = 11; + vocab.special_unk_id = -1; + vocab.special_sep_id = -1; + vocab.special_pad_id = -1; } else { LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str()); LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__); + vocab.type = LLAMA_VOCAB_TYPE_SPM; } } From fae8faa135942918a7c9ecc8c2fc26be7f61140d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 23 Aug 2023 22:56:50 +0300 Subject: [PATCH 36/36] perplexity : add log for start of tokenization --- examples/perplexity/perplexity.cpp | 4 ++++ llama.cpp | 1 + 2 files changed, 5 insertions(+) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 9a39529d50875..a7bd9db2a3fd3 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -41,6 +41,8 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) { const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM; const bool add_bos = is_spm; + fprintf(stderr, "%s: tokenizing the input ..\n", __func__); + auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos); const int calc_chunk = params.n_ctx; @@ -152,6 +154,8 @@ void perplexity(llama_context * ctx, const gpt_params & params) { const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM; const bool add_bos = is_spm; + fprintf(stderr, "%s: tokenizing the input ..\n", __func__); + auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos); const int n_chunk_max = tokens.size() / params.n_ctx; diff --git a/llama.cpp b/llama.cpp index bdf2437015b52..f2dc4da1db344 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3321,6 +3321,7 @@ struct llm_tokenizer_bpe { } // probably not 100% correct + // TODO: this is quite slow - how to make it more efficient? static std::vector bpe_gpt2_preprocess(std::string text) { std::vector words;
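The "quite slow" part flagged just above is the repeated regex_search plus suffix copy in bpe_gpt2_preprocess. A simplified self-contained sketch of that pre-tokenization loop on ASCII input (the \s+(?!\S) lookahead alternative of the original GPT-2 pattern is dropped here, and Unicode categories are ignored):

#include <cstdio>
#include <regex>
#include <string>
#include <vector>

int main() {
    // contractions, letter runs, digit runs, other-symbol runs, remaining whitespace
    const std::regex re(R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+)");

    std::string text = "I'll buy 2 llamas, maybe 3!";
    std::vector<std::string> words;

    std::smatch m;
    while (std::regex_search(text, m, re)) {
        words.push_back(m.str(0)); // keep the match
        text = m.suffix();         // re-scan the remainder (this copy is the slow part)
    }

    for (const auto & w : words) {
        printf("[%s]", w.c_str());
    }
    printf("\n"); // [I]['ll][ buy][ 2][ llamas][,][ maybe][ 3][!]
    return 0;
}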