From 54dad2037eb41b6afcd8b11665927ee24cbe2296 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?=
Date: Fri, 28 Feb 2025 15:28:24 +0100
Subject: [PATCH] Fix: preprocess tensor names in tensor types map

---
 model.cpp | 51 ++++++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 40 insertions(+), 11 deletions(-)

diff --git a/model.cpp b/model.cpp
index dcbaae5bc..f3c5ae576 100644
--- a/model.cpp
+++ b/model.cpp
@@ -558,6 +558,26 @@ std::string convert_tensor_name(std::string name) {
     return new_name;
 }
 
+void add_preprocess_tensor_storage_types(std::map<std::string, enum ggml_type>& tensor_storages_types, std::string name, enum ggml_type type) {
+    std::string new_name = convert_tensor_name(name);
+
+    if (new_name.find("cond_stage_model") != std::string::npos && ends_with(new_name, "attn.in_proj_weight")) {
+        size_t prefix_size = new_name.find("attn.in_proj_weight");
+        std::string prefix = new_name.substr(0, prefix_size);
+        tensor_storages_types[prefix + "self_attn.q_proj.weight"] = type;
+        tensor_storages_types[prefix + "self_attn.k_proj.weight"] = type;
+        tensor_storages_types[prefix + "self_attn.v_proj.weight"] = type;
+    } else if (new_name.find("cond_stage_model") != std::string::npos && ends_with(new_name, "attn.in_proj_bias")) {
+        size_t prefix_size = new_name.find("attn.in_proj_bias");
+        std::string prefix = new_name.substr(0, prefix_size);
+        tensor_storages_types[prefix + "self_attn.q_proj.bias"] = type;
+        tensor_storages_types[prefix + "self_attn.k_proj.bias"] = type;
+        tensor_storages_types[prefix + "self_attn.v_proj.bias"] = type;
+    } else {
+        tensor_storages_types[new_name] = type;
+    }
+}
+
 void preprocess_tensor(TensorStorage tensor_storage,
                        std::vector<TensorStorage>& processed_tensor_storages) {
     std::vector<TensorStorage> result;
@@ -927,7 +947,7 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s
         GGML_ASSERT(ggml_nbytes(dummy) == tensor_storage.nbytes());
 
         tensor_storages.push_back(tensor_storage);
-        tensor_storages_types[tensor_storage.name] = tensor_storage.type;
+        add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
     }
 
     gguf_free(ctx_gguf_);
@@ -1072,7 +1092,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
         }
 
         tensor_storages.push_back(tensor_storage);
-        tensor_storages_types[tensor_storage.name] = tensor_storage.type;
+        add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
 
         // LOG_DEBUG("%s %s", tensor_storage.to_string().c_str(), dtype.c_str());
     }
@@ -1403,7 +1423,7 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer,
             // printf(" ZIP got tensor %s \n ", reader.tensor_storage.name.c_str());
             reader.tensor_storage.name = prefix + reader.tensor_storage.name;
             tensor_storages.push_back(reader.tensor_storage);
-            tensor_storages_types[reader.tensor_storage.name] = reader.tensor_storage.type;
+            add_preprocess_tensor_storage_types(tensor_storages_types, reader.tensor_storage.name, reader.tensor_storage.type);
             // LOG_DEBUG("%s", reader.tensor_storage.name.c_str());
 
             // reset
@@ -1461,10 +1481,10 @@ SDVersion ModelLoader::get_sd_version() {
     TensorStorage token_embedding_weight, input_block_weight;
     bool input_block_checked = false;
 
-    bool has_multiple_encoders = false;
-    bool is_unet = false;
+    bool has_multiple_encoders = false;
+    bool is_unet               = false;
 
-    bool is_xl = false;
+    bool is_xl   = false;
     bool is_flux = false;
 
 #define found_family (is_xl || is_flux)
@@ -1481,7 +1501,7 @@
         }
         if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) {
             is_unet = true;
-            if(has_multiple_encoders){
+            if (has_multiple_encoders) {
                 is_xl = true;
                 if (input_block_checked) {
                     break;
@@ -1490,7 +1510,7 @@
         }
        if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos) {
             has_multiple_encoders = true;
-            if(is_unet){
+            if (is_unet) {
                 is_xl = true;
                 if (input_block_checked) {
                     break;
@@ -1635,11 +1655,20 @@ ggml_type ModelLoader::get_vae_wtype() {
 void ModelLoader::set_wtype_override(ggml_type wtype, std::string prefix) {
     for (auto& pair : tensor_storages_types) {
         if (prefix.size() < 1 || pair.first.substr(0, prefix.size()) == prefix) {
+            bool found = false;
             for (auto& tensor_storage : tensor_storages) {
-                if (tensor_storage.name == pair.first) {
-                    if (tensor_should_be_converted(tensor_storage, wtype)) {
-                        pair.second = wtype;
+                std::map<std::string, enum ggml_type> temp;
+                add_preprocess_tensor_storage_types(temp, tensor_storage.name, tensor_storage.type);
+                for (auto& preprocessed_name : temp) {
+                    if (preprocessed_name.first == pair.first) {
+                        if (tensor_should_be_converted(tensor_storage, wtype)) {
+                            pair.second = wtype;
+                        }
+                        found = true;
+                        break;
                     }
+                }
+                if (found) {
                     break;
                 }
             }