diff --git a/clip.hpp b/clip.hpp
index 2307ee3c..059b7d0d 100644
--- a/clip.hpp
+++ b/clip.hpp
@@ -661,6 +661,7 @@ class CLIPTextModel : public GGMLBlock {
         if (version == OPEN_CLIP_VIT_BIGG_14) {
             enum ggml_type wtype = GGML_TYPE_F32;  // tensor_types.find(prefix + "text_projection") != tensor_types.end() ? tensor_types[prefix + "text_projection"] : GGML_TYPE_F32;
             params["text_projection"] = ggml_new_tensor_2d(ctx, wtype, projection_dim, hidden_size);
+            ggml_set_name(params["text_projection"], (prefix + "text_projection").c_str());
         }
     }
@@ -812,6 +813,7 @@ class CLIPProjection : public UnaryBlock {
         } else {
             params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features);
         }
+        ggml_set_name(params["weight"], (prefix + "weight").c_str());
     }

 public:
diff --git a/docs/imatrix.md b/docs/imatrix.md
new file mode 100644
index 00000000..e07b85ab
--- /dev/null
+++ b/docs/imatrix.md
@@ -0,0 +1,59 @@
+# Importance Matrix (imatrix) Quantization
+
+## What is an Importance Matrix?
+
+Quantization reduces the precision of a model's weights, decreasing its size and computational requirements. However, this can lead to a loss of quality. An importance matrix helps mitigate this by identifying which weights are *most* important for the model's performance. During quantization, these important weights are preserved with higher precision, while less important weights are quantized more aggressively. This allows for better overall quality at a given quantization level.
+
+This originates from work done with language models in [llama.cpp](https://github.com/ggml-org/llama.cpp/blob/master/examples/imatrix/README.md).
+
+## Usage
+
+The imatrix feature involves two main steps: *training* the matrix and *using* it during quantization.
+
+### Training the Importance Matrix
+
+To generate an imatrix, run stable-diffusion.cpp with the `--imat-out` flag, specifying the output filename. This process runs alongside normal image generation.
+
+```bash
+sd.exe [same exact parameters as normal generation] --imat-out imatrix.dat
+```
+
+* **`[same exact parameters as normal generation]`**: Use the same command-line arguments you would normally use for image generation (e.g., prompt, dimensions, sampling method, etc.).
+* **`--imat-out imatrix.dat`**: Specifies the output file for the generated imatrix.
+
+You can generate multiple images at once using the `-b` flag to speed up the training process.
+
+### Continuing Training an Existing Matrix
+
+If you want to refine an existing imatrix, use the `--imat-in` flag *in addition* to `--imat-out`. This will load the existing matrix and continue training it.
+
+```bash
+sd.exe [same exact parameters as normal generation] --imat-out imatrix.dat --imat-in imatrix.dat
+```
+This way you can train and refine the imatrix while generating images as you normally would.
+
+### Using Multiple Matrices
+
+You can load and merge multiple imatrices together:
+
+```bash
+sd.exe [same exact parameters as normal generation] --imat-out imatrix.dat --imat-in imatrix.dat --imat-in imatrix2.dat
+```
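+
+### How the Matrix Is Computed
+
+During generation, a callback on the backend graph intercepts every matrix multiplication and accumulates, for each weight column, the squared activations that column is multiplied with (see `imatrix.cpp`). A simplified sketch of the accumulation step (the `accumulate` helper is hypothetical, not the actual implementation):
+
+```cpp
+// x points at n_rows activation rows of n_cols floats each;
+// e.values[j] accumulates the squared activations seen by weight column j.
+void accumulate(Stats& e, const float* x, int n_rows, int n_cols) {
+    e.values.resize(n_cols, 0.0f);
+    e.counts.resize(n_cols, 0);
+    for (int row = 0; row < n_rows; ++row) {
+        for (int j = 0; j < n_cols; ++j) {
+            e.values[j] += x[row * n_cols + j] * x[row * n_cols + j];
+            e.counts[j]++;
+        }
+    }
+}
+```
+
+Columns that consistently see large activations are treated as more important, and the quantizer preserves them more accurately.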
+
+### Quantizing with an Importance Matrix
+
+To quantize a model using a trained imatrix, use the `-M convert` option (or equivalent quantization command) and the `--imat-in` flag, specifying the imatrix file.
+
+```bash
+sd.exe -M convert [same exact parameters as normal quantization] --imat-in imatrix.dat
+```
+
+* **`[same exact parameters as normal quantization]`**: Use the same command-line arguments you would normally use for quantization (e.g., target quantization method, input/output filenames).
+* **`--imat-in imatrix.dat`**: Specifies the imatrix file to use during quantization. You can specify multiple `--imat-in` flags to combine multiple matrices.
+
+## Important Considerations
+
+* The quality of the imatrix depends on the prompts and settings used during training. Use prompts and settings representative of the types of images you intend to generate for the best results.
+* Experiment with different training parameters (e.g., number of images, prompt variations) to optimize the imatrix for your specific use case.
+* The performance impact of training an imatrix during image generation, or of using an imatrix for quantization, is negligible.
+* Training the imatrix from an already quantized model seems to work fine.
\ No newline at end of file
diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index af6b2bbd..15cce84d 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -129,6 +129,12 @@ struct SDParams {
     float slg_scale = 0.f;
     float skip_layer_start = 0.01f;
     float skip_layer_end = 0.2f;
+
+    /* Imatrix params */
+
+    std::string imatrix_out = "";
+
+    std::vector<std::string> imatrix_in = {};
 };

 void print_params(SDParams params) {
@@ -204,6 +210,8 @@ void print_usage(int argc, const char* argv[]) {
     printf("  --upscale-repeats                  Run the ESRGAN upscaler this many times (default 1)\n");
     printf("  --type [TYPE]                      weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)\n");
     printf("                                     If not specified, the default is the type of the weight file\n");
+    printf("  --imat-out [PATH]                  If set, compute the imatrix for this run and save it to the provided path\n");
+    printf("  --imat-in [PATH]                   Use imatrix for quantization.\n");
     printf("  --lora-model-dir [DIR]             lora model directory\n");
     printf("  -i, --init-img [IMAGE]             path to the input image, required by img2img\n");
     printf("  --mask [MASK]                      path to the mask image, required by img2img with mask\n");
@@ -250,6 +258,7 @@ void print_usage(int argc, const char* argv[]) {
 void parse_args(int argc, const char** argv, SDParams& params) {
     bool invalid_arg = false;
     std::string arg;
+    std::string type = "";
     for (int i = 1; i < argc; i++) {
         arg = argv[i];
@@ -355,32 +364,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
                 invalid_arg = true;
                 break;
             }
-            std::string type = argv[i];
-            bool found = false;
-            std::string valid_types = "";
-            for (size_t i = 0; i < SD_TYPE_COUNT; i++) {
-                auto trait = ggml_get_type_traits((ggml_type)i);
-                std::string name(trait->type_name);
-                if (name == "f32" || trait->to_float && trait->type_size) {
-                    if (i)
-                        valid_types += ", ";
-                    valid_types += name;
-                    if (type == name) {
-                        if (ggml_quantize_requires_imatrix((ggml_type)i)) {
-                            printf("\033[35;1m[WARNING]\033[0m: type %s requires imatrix to work properly. A dummy imatrix will be used, expect poor quality.\n", trait->type_name);
-                        }
-                        params.wtype = (enum sd_type_t)i;
-                        found = true;
-                        break;
-                    }
-                }
-            }
-            if (!found) {
-                fprintf(stderr, "error: invalid weight format %s, must be one of [%s]\n",
-                        type.c_str(),
-                        valid_types.c_str());
-                exit(1);
-            }
+            type = argv[i];
         } else if (arg == "--lora-model-dir") {
             if (++i >= argc) {
                 invalid_arg = true;
                 break;
             }
@@ -629,12 +613,60 @@ void parse_args(int argc, const char** argv, SDParams& params) {
                 break;
             }
             params.skip_layer_end = std::stof(argv[i]);
+        } else if (arg == "--imat-out") {
+            if (++i >= argc) {
+                invalid_arg = true;
+                break;
+            }
+            params.imatrix_out = argv[i];
+        } else if (arg == "--imat-in") {
+            if (++i >= argc) {
+                invalid_arg = true;
+                break;
+            }
+            params.imatrix_in.push_back(std::string(argv[i]));
         } else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             print_usage(argc, argv);
             exit(1);
         }
     }
+    if (type != "") {
+        bool found = false;
+        std::string valid_types = "";
+        for (size_t i = 0; i < SD_TYPE_COUNT; i++) {
+            auto trait = ggml_get_type_traits((ggml_type)i);
+            std::string name(trait->type_name);
+            if (name == "f32" || trait->to_float && trait->type_size) {
+                if (i)
+                    valid_types += ", ";
+                valid_types += name;
+                if (type == name) {
+                    if (ggml_quantize_requires_imatrix((ggml_type)i) && params.imatrix_in.size() == 0) {
+                        printf("\033[35;1m[WARNING]\033[0m: type %s requires imatrix to work properly. A dummy imatrix will be used, expect poor quality.\n", trait->type_name);
+                    }
+                    params.wtype = (enum sd_type_t)i;
+                    found = true;
+                    break;
+                }
+            }
+        }
+        if (!found) {
+            fprintf(stderr, "error: invalid weight format %s, must be one of [%s]\n",
+                    type.c_str(),
+                    valid_types.c_str());
+            exit(1);
+        }
+    }
+
+    if (params.imatrix_out.size() > 0 && std::ifstream(params.imatrix_out).good()) {
+        // imatrix file already exists
+        if (std::find(params.imatrix_in.begin(), params.imatrix_in.end(), params.imatrix_out) == params.imatrix_in.end()) {
+            printf("\n IMPORTANT: imatrix file %s already exists, but wasn't found in the imatrix inputs.\n", params.imatrix_out.c_str());
+            printf("%s will get overwritten!\n", params.imatrix_out.c_str());
+        }
+    }
+
     if (invalid_arg) {
         fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
         print_usage(argc, argv);
@@ -799,8 +831,20 @@ int main(int argc, const char* argv[]) {
         printf("%s", sd_get_system_info());
     }

+    if (params.imatrix_out != "") {
+        enableImatrixCollection();
+    }
+    if (params.imatrix_out != "" || params.mode == CONVERT || params.wtype != SD_TYPE_COUNT) {
+        for (const auto& in_file : params.imatrix_in) {
+            printf("loading imatrix from '%s'\n", in_file.c_str());
+            if (!loadImatrix(in_file.c_str())) {
+                printf("Failed to load %s\n", in_file.c_str());
+            }
+        }
+    }
+
     if (params.mode == CONVERT) {
-        bool success = convert(params.model_path.c_str(), params.vae_path.c_str(), params.output_path.c_str(), params.wtype);
+        bool success = convert(params.model_path.c_str(), params.clip_l_path.c_str(), params.clip_g_path.c_str(), params.t5xxl_path.c_str(), params.diffusion_model_path.c_str(), params.vae_path.c_str(), params.output_path.c_str(), params.wtype);
         if (!success) {
             fprintf(stderr,
                     "convert '%s'/'%s' to '%s' failed\n",
@@ -1075,11 +1119,11 @@ int main(int argc, const char* argv[]) {
     std::string dummy_name, ext, lc_ext;
     bool is_jpg;
-    size_t last = params.output_path.find_last_of(".");
+    size_t last = params.output_path.find_last_of(".");
     size_t last_path = std::min(params.output_path.find_last_of("/"),
                                  params.output_path.find_last_of("\\"));
-    if (last != std::string::npos  // filename has extension
-        && (last_path == std::string::npos || last > last_path)) {
+    if (last != std::string::npos  // filename has extension
+        && (last_path == std::string::npos || last > last_path)) {
         dummy_name = params.output_path.substr(0, last);
         ext = lc_ext = params.output_path.substr(last);
         std::transform(ext.begin(), ext.end(), lc_ext.begin(), ::tolower);
     } else {
         dummy_name = params.output_path;
         ext = lc_ext = "";
-        is_jpg = false;
+        is_jpg = false;
     }
     // appending ".png" to absent or unknown extension
     if (!is_jpg && lc_ext != ".png") {
@@ -1099,7 +1143,7 @@ int main(int argc, const char* argv[]) {
             continue;
         }
         std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ext : dummy_name + ext;
-        if(is_jpg) {
+        if (is_jpg) {
             stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
                            results[i].data, 90, get_image_params(params, params.seed + i).c_str());
             printf("save result JPEG image to '%s'\n", final_image_path.c_str());
@@ -1111,6 +1155,9 @@ int main(int argc, const char* argv[]) {
         free(results[i].data);
         results[i].data = NULL;
     }
+    if (params.imatrix_out != "") {
+        saveImatrix(params.imatrix_out.c_str());
+    }
     free(results);
     free_sd_ctx(sd_ctx);
     free(control_image_buffer);
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index c5913be4..74c1ea32 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -23,9 +23,11 @@
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
 #include "ggml-cpu.h"
+#include "ggml/src/ggml-impl.h"
 #include "ggml.h"

 #include "model.h"
+#include "util.h"

 #ifdef SD_USE_CUDA
 #include "ggml-cuda.h"
@@ -117,13 +119,6 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_kronecker(ggml_context* ctx, struct g
                               b);
 }

-__STATIC_INLINE__ void ggml_log_callback_default(ggml_log_level level, const char* text, void* user_data) {
-    (void)level;
-    (void)user_data;
-    fputs(text, stderr);
-    fflush(stderr);
-}
-
 __STATIC_INLINE__ void ggml_tensor_set_f32_randn(struct ggml_tensor* tensor, std::shared_ptr<RNG> rng) {
     uint32_t n = (uint32_t)ggml_nelements(tensor);
     std::vector<float> random_numbers = rng->randn(n);
@@ -1241,7 +1236,39 @@ struct GGMLRunner {
             ggml_backend_cpu_set_n_threads(backend, n_threads);
         }

-        ggml_backend_graph_compute(backend, gf);
+        auto callback_eval = get_callback_eval();
+
+        if (!callback_eval) {
+            ggml_backend_graph_compute(backend, gf);
+        } else {
+            void* callback_eval_user_data = get_callback_eval_user_data();
+            for (int j0 = 0; j0 < gf->n_nodes; j0++) {
+                struct ggml_tensor* t = gf->nodes[j0];
+
+                // check if the user needs data from this node
+                bool need = callback_eval(t, true, callback_eval_user_data);
+
+                int j1 = j0;
+
+                // determine the range [j0, j1] of nodes that can be computed together
+                while (!need && j1 < gf->n_nodes - 1) {
+                    t = gf->nodes[++j1];
+                    need = callback_eval(t, true, callback_eval_user_data);
+                }
+
+                struct ggml_cgraph gv = ggml_graph_view(gf, j0, j1 + 1);
+
+                ggml_backend_graph_compute_async(backend, &gv);
+
+                // let the callback inspect the computed node; returning false aborts evaluation
+                if (need && !callback_eval(t, false, callback_eval_user_data)) {
+                    break;
+                }
+
+                j0 = j1;
+            }
+            ggml_backend_synchronize(backend);
+        }
+
 #ifdef GGML_PERF
         ggml_graph_print(gf);
 #endif
@@ -1345,6 +1372,7 @@ class Linear : public UnaryBlock {
             wtype = GGML_TYPE_F32;
         }
         params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features);
+        ggml_set_name(params["weight"], (prefix + "weight").c_str());
         if (bias) {
             enum ggml_type wtype = GGML_TYPE_F32;  //(tensor_types.find(prefix + "bias") != tensor_types.end()) ? tensor_types[prefix + "bias"] : GGML_TYPE_F32;
             params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_features);
@@ -1508,6 +1536,8 @@ class LayerNorm : public UnaryBlock {
         if (elementwise_affine) {
             enum ggml_type wtype = GGML_TYPE_F32;  //(tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32;
             params["weight"] = ggml_new_tensor_1d(ctx, wtype, normalized_shape);
+            ggml_set_name(params["weight"], (prefix + "weight").c_str());
+
             if (bias) {
                 enum ggml_type wtype = GGML_TYPE_F32;  //(tensor_types.find(prefix + "bias") != tensor_types.end()) ? tensor_types[prefix + "bias"] : GGML_TYPE_F32;
                 params["bias"] = ggml_new_tensor_1d(ctx, wtype, normalized_shape);
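The eval callback wired through `GGMLRunner` above follows the two-phase contract used by llama.cpp's scheduler callback: each graph node is first offered with `ask == true` (return `true` to request its data), then offered again with `ask == false` once it has been computed. A minimal sketch of a client-side callback (the `trace_mul_mat` function is hypothetical, not part of this patch), registered through the `sd_set_backend_eval_callback` hook declared further below:

```cpp
#include <cstdio>
#include "ggml.h"
#include "stable-diffusion.h"  // for sd_set_backend_eval_callback (added by this patch)

static bool trace_mul_mat(struct ggml_tensor* t, bool ask, void* user_data) {
    (void)user_data;
    if (ask) {
        // probe phase: request data only for matrix multiplications
        return t->op == GGML_OP_MUL_MAT;
    }
    // collect phase: the tensor has been computed and can be inspected;
    // returning false would abort evaluation of the rest of the graph
    printf("computed %s\n", t->name);
    return true;
}

// during setup:
// sd_set_backend_eval_callback(trace_mul_mat, NULL);
```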
diff --git a/imatrix.cpp b/imatrix.cpp
new file mode 100644
index 00000000..db786fdf
--- /dev/null
+++ b/imatrix.cpp
@@ -0,0 +1,281 @@
+#include "imatrix.hpp"
+
+/*Stolen from llama.cpp (credits: Kawrakow)*/
+
+#include "ggml-backend.h"
+#include "ggml.h"
+#include "util.h"
+
+#include <cstring>
+
+// remove any prefix and suffixes from the name
+// CUDA0#blk.0.attn_k.weight#0 => blk.0.attn_k.weight
+static std::string filter_tensor_name(const char* name) {
+    std::string wname;
+    const char* p = strchr(name, '#');
+    if (p != NULL) {
+        p = p + 1;
+        const char* q = strchr(p, '#');
+        if (q != NULL) {
+            wname = std::string(p, q - p);
+        } else {
+            wname = p;
+        }
+    } else {
+        wname = name;
+    }
+    return wname;
+}
+
+bool IMatrixCollector::collect_imatrix(struct ggml_tensor* t, bool ask, void* user_data) {
+    GGML_UNUSED(user_data);
+    const struct ggml_tensor* src0 = t->src[0];
+    const struct ggml_tensor* src1 = t->src[1];
+    std::string wname = filter_tensor_name(src0->name);
+
+    // when ask is true, the scheduler wants to know if we are interested in data from this tensor
+    // if we return true, a follow-up call will be made with ask=false in which we can do the actual collection
+    if (ask) {
+        if (t->op == GGML_OP_MUL_MAT_ID)
+            return true;  // collect all indirect matrix multiplications
+        if (t->op != GGML_OP_MUL_MAT)
+            return false;
+        // why are small batches ignored (<16 tokens)?
+        // if (src1->ne[1] < 16 || src1->type != GGML_TYPE_F32) return false;
+        if (!(wname.substr(0, 6) == "model." || wname.substr(0, 17) == "cond_stage_model." || wname.substr(0, 14) == "text_encoders."))
+            return false;
+        return true;
+    }
+    // LOG_DEBUG("%s", wname.c_str());
+
+    std::lock_guard<std::mutex> lock(m_mutex);
+
+    // copy the data from the GPU memory if needed
+    const bool is_host = src1->buffer == NULL || ggml_backend_buffer_is_host(src1->buffer);
+
+    if (!is_host) {
+        m_src1_data.resize(ggml_nelements(src1));
+        ggml_backend_tensor_get(src1, m_src1_data.data(), 0, ggml_nbytes(src1));
+    }
+
+    const float* data = is_host ? (const float*)src1->data : m_src1_data.data();
+
+    // this has been adapted to the new format of storing merged experts in a single 3d tensor
+    // ref: https://github.com/ggml-org/llama.cpp/pull/6387
+    if (t->op == GGML_OP_MUL_MAT_ID) {
+        // ids -> [n_experts_used, n_tokens]
+        // src1 -> [cols, n_expert_used, n_tokens]
+        const ggml_tensor* ids = t->src[2];
+        const int n_as = src0->ne[2];
+        const int n_ids = ids->ne[0];
+
+        // the top-k selected expert ids are stored in the ids tensor
+        // for simplicity, always copy ids to host, because it is small
+        // take into account that ids is not contiguous!
+
+        GGML_ASSERT(ids->ne[1] == src1->ne[2]);
+
+        m_ids.resize(ggml_nbytes(ids));
+        ggml_backend_tensor_get(ids, m_ids.data(), 0, ggml_nbytes(ids));
+
+        auto& e = m_stats[wname];
+
+        ++e.ncall;
+
+        if (e.values.empty()) {
+            e.values.resize(src1->ne[0] * n_as, 0);
+            e.counts.resize(src1->ne[0] * n_as, 0);
+        } else if (e.values.size() != (size_t)src1->ne[0] * n_as) {
+            LOG_ERROR("inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0] * n_as);
+            exit(1);  // GGML_ABORT("fatal error");
+        }
+        // LOG_DEBUG("%s[%d]: %32s, %s, %5d x %5d, %d\n", m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[2], (int)src1->type);
+        // loop over all possible experts, regardless if they are used or not in the batch
+        for (int ex = 0; ex < n_as; ++ex) {
+            size_t e_start = ex * src1->ne[0];
+
+            for (int idx = 0; idx < n_ids; ++idx) {
+                for (int row = 0; row < (int)src1->ne[2]; ++row) {
+                    const int excur = *(const int32_t*)(m_ids.data() + row * ids->nb[1] + idx * ids->nb[0]);
+
+                    GGML_ASSERT(excur >= 0 && excur < n_as);  // sanity check
+
+                    if (excur != ex)
+                        continue;
+
+                    const int64_t i11 = idx % src1->ne[1];
+                    const int64_t i12 = row;
+                    const float* x = (const float*)((const char*)data + i11 * src1->nb[1] + i12 * src1->nb[2]);
+
+                    for (int j = 0; j < (int)src1->ne[0]; ++j) {
+                        e.values[e_start + j] += x[j] * x[j];
+                        e.counts[e_start + j]++;
+                        if (!std::isfinite(e.values[e_start + j])) {
+                            printf("\n");
+                            LOG_ERROR("%f detected in %s\n", e.values[e_start + j], wname.c_str());
+                            exit(1);
+                        }
+                    }
+                }
+            }
+        }
+    } else {
+        auto& e = m_stats[wname];
+        if (e.values.empty()) {
+            e.values.resize(src1->ne[0], 0);
+            e.counts.resize(src1->ne[0], 0);
+        } else if (e.values.size() != (size_t)src1->ne[0]) {
+            LOG_WARN("inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]);
+            exit(1);  // GGML_ABORT("fatal error");
+        }
+
+        ++e.ncall;
+        // LOG_DEBUG("%s[%d]: %32s, %s, %5d x %5d, %d\n", m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type);
+        for (int row = 0; row < (int)src1->ne[1]; ++row) {
+            const float* x = data + row * src1->ne[0];
+            for (int j = 0; j < (int)src1->ne[0]; ++j) {
+                e.values[j] += x[j] * x[j];
+                e.counts[j]++;
+                if (!std::isfinite(e.values[j])) {
+                    LOG_WARN("%f detected in %s\n", e.values[j], wname.c_str());
+                    exit(1);
+                }
+            }
+        }
+    }
+    return true;
+}
+
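+// On-disk imatrix format, as written by save_imatrix() below and read back by
+// load_imatrix() (integers are raw native-endian int32, floats raw float32):
+//   int32         n_entries
+//   n_entries times:
+//     int32       len          -- length of the tensor name
+//     char[len]   name
+//     int32       ncall        -- number of accumulation calls for this entry
+//     int32       nval         -- number of values
+//     float[nval] values       -- per-column mean squared activation, scaled by ncall
+//   int32         m_last_call  -- call count at the time of saving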
+void IMatrixCollector::save_imatrix(std::string fname, int ncall) const {
+    LOG_INFO("SAVING_IMATRIX to %s\n", fname.c_str());
+
+    if (ncall > 0) {
+        fname += ".at_";
+        fname += std::to_string(ncall);
+    }
+    // avoid writing imatrix entries that do not have full data
+    // this can happen with MoE models where some of the experts end up not being exercised by the provided training data
+
+    int n_entries = 0;
+    std::vector<std::string> to_store;
+
+    bool is_first = true;  // for printing
+    for (const auto& kv : m_stats) {
+        const int n_all = kv.second.counts.size();
+
+        if (n_all == 0) {
+            continue;
+        }
+
+        int n_zeros = 0;
+        for (const int c : kv.second.counts) {
+            if (c == 0) {
+                n_zeros++;
+            }
+        }
+
+        if (n_zeros != 0 && is_first) {
+            printf("\n");
+            is_first = false;
+        }
+
+        if (n_zeros == n_all) {
+            LOG_WARN("entry '%40s' has no data - skipping\n", kv.first.c_str());
+            continue;
+        }
+
+        if (n_zeros > 0) {
+            LOG_WARN("entry '%40s' has partial data (%.2f%%) - skipping\n", kv.first.c_str(), 100.0f * (n_all - n_zeros) / n_all);
+            continue;
+        }
+
+        n_entries++;
+        to_store.push_back(kv.first);
+    }
+
+    if (to_store.size() < m_stats.size()) {
+        LOG_WARN("storing only %zu out of %zu entries\n", to_store.size(), m_stats.size());
+    }
+
+    std::ofstream out(fname, std::ios::binary);
+    out.write((const char*)&n_entries, sizeof(n_entries));
+    for (const auto& name : to_store) {
+        const auto& stat = m_stats.at(name);
+        int len = name.size();
+        out.write((const char*)&len, sizeof(len));
+        out.write(name.c_str(), len);
+        out.write((const char*)&stat.ncall, sizeof(stat.ncall));
+        int nval = stat.values.size();
+        out.write((const char*)&nval, sizeof(nval));
+        if (nval > 0) {
+            std::vector<float> tmp(nval);
+            for (int i = 0; i < nval; i++) {
+                tmp[i] = (stat.values[i] / static_cast<float>(stat.counts[i])) * static_cast<float>(stat.ncall);
+            }
+            out.write((const char*)tmp.data(), nval * sizeof(float));
+        }
+    }
+
+    // Write the number of calls the matrix was computed with
+    out.write((const char*)&m_last_call, sizeof(m_last_call));
+
+    // LOG_DEBUG("\n");
+    // LOG_DEBUG("stored collected data after %d chunks in %s\n", m_last_call, fname.c_str());
+}
+
+bool IMatrixCollector::load_imatrix(const char* fname) {
+    std::ifstream in(fname, std::ios::binary);
+    if (!in) {
+        LOG_ERROR("failed to open %s\n", fname);
+        return false;
+    }
+    int n_entries;
+    in.read((char*)&n_entries, sizeof(n_entries));
+    if (in.fail() || n_entries < 1) {
+        LOG_ERROR("no data in file %s\n", fname);
+        return false;
+    }
+    for (int i = 0; i < n_entries; ++i) {
+        int len;
+        in.read((char*)&len, sizeof(len));
+        std::vector<char> name_as_vec(len + 1);
+        in.read((char*)name_as_vec.data(), len);
+        if (in.fail()) {
+            LOG_ERROR("failed reading name for entry %d from %s\n", i + 1, fname);
+            return false;
+        }
+        name_as_vec[len] = 0;
+        std::string name{name_as_vec.data()};
+        auto& e = m_stats[std::move(name)];
+        int ncall;
+        in.read((char*)&ncall, sizeof(ncall));
+        int nval;
+        in.read((char*)&nval, sizeof(nval));
+        if (in.fail() || nval < 1) {
+            LOG_ERROR("failed reading number of values for entry %d\n", i);
+            m_stats = {};
+            return false;
+        }
+
+        if (e.values.empty()) {
+            e.values.resize(nval, 0);
+            e.counts.resize(nval, 0);
+        }
+
+        std::vector<float> tmp(nval);
+        in.read((char*)tmp.data(), nval * sizeof(float));
+        if (in.fail()) {
+            LOG_ERROR("failed reading data for entry %d\n", i);
+            m_stats = {};
+            return false;
+        }
+
+        // Recreate the state as expected by save_imatrix(), and correct for weighted sum.
+        for (int i = 0; i < nval; i++) {
+            e.values[i] += tmp[i];
+            e.counts[i] += ncall;
+        }
+        e.ncall += ncall;
+    }
+    return true;
+}
\ No newline at end of file
diff --git a/imatrix.hpp b/imatrix.hpp
new file mode 100644
index 00000000..bcfd4e42
--- /dev/null
+++ b/imatrix.hpp
@@ -0,0 +1,39 @@
+#ifndef IMATRIX_HPP
+#define IMATRIX_HPP
+#include <fstream>
+#include <mutex>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+/*Stolen from llama.cpp (credits: Kawrakow)*/
+
+struct Stats {
+    std::vector<float> values{};
+    std::vector<int> counts{};
+    int ncall = 0;
+};
+
+class IMatrixCollector {
+public:
+    IMatrixCollector() = default;
+    bool collect_imatrix(struct ggml_tensor* t, bool ask, void* user_data);
+    void save_imatrix(std::string fname, int ncall = -1) const;
+    bool load_imatrix(const char* fname);
+    std::vector<float> get_values(const std::string& key) const {
+        auto it = m_stats.find(key);
+        if (it != m_stats.end()) {
+            return it->second.values;
+        } else {
+            return {};
+        }
+    }
+
+private:
+    std::unordered_map<std::string, Stats> m_stats = {};
+    std::mutex m_mutex;
+    int m_last_call = 0;
+    std::vector<float> m_src1_data;
+    std::vector<char> m_ids;  // the expert ids from ggml_mul_mat_id
+};
+
+#endif
\ No newline at end of file
diff --git a/model.cpp b/model.cpp
index 24da39f6..7673313d 100644
--- a/model.cpp
+++ b/model.cpp
@@ -16,6 +16,7 @@
 #include "ggml-cpu.h"
 #include "ggml.h"
+#include "imatrix.hpp"
 #include "stable-diffusion.h"

 #ifdef SD_USE_METAL
@@ -28,6 +29,8 @@
 #define ST_HEADER_SIZE_LEN 8

+static IMatrixCollector imatrix_collector;
+
 uint64_t read_u64(uint8_t* buffer) {
     // little endian
     uint64_t value = 0;
@@ -737,7 +740,8 @@ void convert_tensor(void* src,
                     void* dst,
                     ggml_type dst_type,
                     int nrows,
-                    int n_per_row) {
+                    int n_per_row,
+                    std::vector<float> imatrix = {}) {
     int n = nrows * n_per_row;
     if (src_type == dst_type) {
         size_t nbytes = n * ggml_type_size(src_type) / ggml_blck_size(src_type);
@@ -746,7 +750,10 @@
         if (dst_type == GGML_TYPE_F16) {
             ggml_fp32_to_fp16_row((float*)src, (ggml_fp16_t*)dst, n);
         } else {
-            std::vector<float> imatrix(n_per_row, 1.0f);  // dummy importance matrix
+            // if(imatrix.size() != 0){
+            //     LOG_INFO("using imatrix");
+            // }
+            imatrix.resize(n_per_row, 1.0f);
             const float* im = imatrix.data();
             ggml_quantize_chunk(dst_type, (float*)src, dst, 0, nrows, n_per_row, im);
         }
@@ -776,7 +783,10 @@
         if (dst_type == GGML_TYPE_F16) {
             ggml_fp32_to_fp16_row((float*)src_data_f32, (ggml_fp16_t*)dst, n);
         } else {
-            std::vector<float> imatrix(n_per_row, 1.0f);  // dummy importance matrix
+            // if(imatrix.size() != 0){
+            //     LOG_INFO("using imatrix");
+            // }
+            imatrix.resize(n_per_row, 1.0f);
             const float* im = imatrix.data();
             ggml_quantize_chunk(dst_type, (float*)src_data_f32, dst, 0, nrows, n_per_row, im);
         }
@@ -1830,8 +1840,12 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                     f8_e5m2_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
                 }

+                auto processed_name = convert_tensor_name(tensor_storage.name);
+                // LOG_DEBUG("%s",processed_name.c_str());
+                std::vector<float> imatrix = imatrix_collector.get_values(processed_name);
+
                 convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data,
-                               dst_tensor->type, (int)tensor_storage.nelements() / (int)tensor_storage.ne[0], (int)tensor_storage.ne[0]);
+                               dst_tensor->type, (int)tensor_storage.nelements() / (int)tensor_storage.ne[0], (int)tensor_storage.ne[0], imatrix);
             }
         } else {
             read_buffer.resize(tensor_storage.nbytes());
@@ -1853,10 +1867,14 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
                 ggml_backend_tensor_set(dst_tensor, read_buffer.data(), 0, ggml_nbytes(dst_tensor));
             } else {
                 // convert first, then copy to device memory
+                auto processed_name = convert_tensor_name(tensor_storage.name);
+                // LOG_DEBUG("%s",processed_name.c_str());
+                std::vector<float> imatrix = imatrix_collector.get_values(processed_name);
+
                 convert_buffer.resize(ggml_nbytes(dst_tensor));
                 convert_tensor((void*)read_buffer.data(), tensor_storage.type,
                                (void*)convert_buffer.data(), dst_tensor->type,
-                               (int)tensor_storage.nelements() / (int)tensor_storage.ne[0], (int)tensor_storage.ne[0]);
+                               (int)tensor_storage.nelements() / (int)tensor_storage.ne[0], (int)tensor_storage.ne[0], imatrix);
                 ggml_backend_tensor_set(dst_tensor, convert_buffer.data(), 0, ggml_nbytes(dst_tensor));
             }
         }
@@ -2051,12 +2069,41 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type)
     return mem_size;
 }

-bool convert(const char* input_path, const char* vae_path, const char* output_path, sd_type_t output_type) {
+bool convert(const char* model_path, const char* clip_l_path, const char* clip_g_path, const char* t5xxl_path, const char* diffusion_model_path, const char* vae_path, const char* output_path, sd_type_t output_type) {
     ModelLoader model_loader;

-    if (!model_loader.init_from_file(input_path)) {
-        LOG_ERROR("init model loader from file failed: '%s'", input_path);
-        return false;
+    if (model_path != NULL && strlen(model_path) > 0) {
+        if (!model_loader.init_from_file(model_path)) {
+            LOG_ERROR("init model loader from file failed: '%s'", model_path);
+            return false;
+        }
+    }
+
+    if (clip_l_path != NULL && strlen(clip_l_path) > 0) {
+        if (!model_loader.init_from_file(clip_l_path, "text_encoders.clip_l.transformer.")) {
+            LOG_ERROR("init model loader from file failed: '%s'", clip_l_path);
+            return false;
+        }
+    }
+
+    if (clip_g_path != NULL && strlen(clip_g_path) > 0) {
+        if (!model_loader.init_from_file(clip_g_path, "text_encoders.clip_g.transformer.")) {
+            LOG_ERROR("init model loader from file failed: '%s'", clip_g_path);
+            return false;
+        }
+    }
+    if (t5xxl_path != NULL && strlen(t5xxl_path) > 0) {
+        if (!model_loader.init_from_file(t5xxl_path, "text_encoders.t5xxl.transformer.")) {
+            LOG_ERROR("init model loader from file failed: '%s'", t5xxl_path);
+            return false;
+        }
+    }
+
+    if (diffusion_model_path != NULL && strlen(diffusion_model_path) > 0) {
+        if (!model_loader.init_from_file(diffusion_model_path, "model.diffusion_model.")) {
+            LOG_ERROR("init model loader from file failed: '%s'", diffusion_model_path);
+            return false;
+        }
     }

     if (vae_path != NULL && strlen(vae_path) > 0) {
@@ -2065,6 +2112,23 @@ bool convert(const char* input_path, const char* vae_path, const char* output_pa
             return false;
         }
     }
+
     bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type);
     return success;
 }
+
+bool loadImatrix(const char* imatrix_path) {
+    return imatrix_collector.load_imatrix(imatrix_path);
+}
+void saveImatrix(const char* imatrix_path) {
+    imatrix_collector.save_imatrix(imatrix_path);
+}
+static bool collect_imatrix(struct ggml_tensor* t, bool ask, void* user_data) {
+    return imatrix_collector.collect_imatrix(t, ask, user_data);
+}
+void enableImatrixCollection() {
+    sd_set_backend_eval_callback((sd_graph_eval_callback_t)collect_imatrix, NULL);
+}
+void disableImatrixCollection() {
+    sd_set_backend_eval_callback(NULL, NULL);
+}
\ No newline at end of file
diff --git a/stable-diffusion.h b/stable-diffusion.h
index 52dcc848..ea567510 100644
--- a/stable-diffusion.h
+++ b/stable-diffusion.h
@@ -114,9 +114,11 @@ enum sd_log_level_t {

 typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data);
 typedef void (*sd_progress_cb_t)(int step, int steps, float time, void* data);
+typedef bool (*sd_graph_eval_callback_t)(struct ggml_tensor* t, bool ask, void* user_data);

 SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data);
 SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data);
+SD_API void sd_set_backend_eval_callback(sd_graph_eval_callback_t cb, void* data);
 SD_API int32_t get_num_physical_cores();
 SD_API const char* sd_get_system_info();

@@ -228,7 +230,7 @@ SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx);

 SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor);

-SD_API bool convert(const char* input_path, const char* vae_path, const char* output_path, enum sd_type_t output_type);
+SD_API bool convert(const char* model_path, const char* clip_l_path, const char* clip_g_path, const char* t5xxl_path, const char* diffusion_model_path, const char* vae_path, const char* output_path, enum sd_type_t output_type);

 SD_API uint8_t* preprocess_canny(uint8_t* img,
                                  int width,
@@ -239,6 +241,11 @@ SD_API uint8_t* preprocess_canny(uint8_t* img,
                                  float strong,
                                  bool inverse);

+SD_API bool loadImatrix(const char* imatrix_path);
+SD_API void saveImatrix(const char* imatrix_path);
+SD_API void enableImatrixCollection();
+SD_API void disableImatrixCollection();
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/util.cpp b/util.cpp
index da11a14d..3c177954 100644
--- a/util.cpp
+++ b/util.cpp
@@ -247,6 +247,9 @@ int32_t get_num_physical_cores() {
 static sd_progress_cb_t sd_progress_cb = NULL;
 void* sd_progress_cb_data = NULL;

+static ggml_graph_eval_callback callback_eval = NULL;
+void* callback_eval_user_data = NULL;
+
 std::u32string utf8_to_utf32(const std::string& utf8_str) {
     std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
     return converter.from_bytes(utf8_str);
@@ -420,6 +423,20 @@ void sd_set_progress_callback(sd_progress_cb_t cb, void* data) {
     sd_progress_cb = cb;
     sd_progress_cb_data = data;
 }
+
+void sd_set_backend_eval_callback(ggml_graph_eval_callback cb, void* data) {
+    callback_eval = cb;
+    callback_eval_user_data = data;
+}
+
+ggml_graph_eval_callback get_callback_eval() {
+    return callback_eval;
+}
+
+void* get_callback_eval_user_data() {
+    return callback_eval_user_data;
+}
+
 const char* sd_get_system_info() {
     static char buffer[1024];
     std::stringstream ss;
diff --git a/util.h b/util.h
index 14fa812e..f23cc5de 100644
--- a/util.h
+++ b/util.h
@@ -7,6 +7,8 @@

 #include "stable-diffusion.h"

+typedef bool (*ggml_graph_eval_callback)(struct ggml_tensor* t, bool ask, void* user_data);
+
 bool ends_with(const std::string& str, const std::string& ending);
 bool starts_with(const std::string& str, const std::string& start);
 bool contains(const std::string& str, const std::string& substr);
@@ -53,6 +55,8 @@ void log_printf(sd_log_level_t level, const char* file, int line, const char* fo

 std::string trim(const std::string& s);
 std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::string& text);
+ggml_graph_eval_callback get_callback_eval();
+void* get_callback_eval_user_data();

 #define LOG_DEBUG(format, ...) log_printf(SD_LOG_DEBUG, __FILE__, __LINE__, format, ##__VA_ARGS__)
 #define LOG_INFO(format, ...) log_printf(SD_LOG_INFO, __FILE__, __LINE__, format, ##__VA_ARGS__)
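For reference, the new public API added by this patch is exercised by the CLI roughly as follows. This is a minimal sketch under assumptions (file names and the chosen quantization type are hypothetical), not part of the patch itself:

```cpp
#include "stable-diffusion.h"

int main() {
    // Phase 1: collect activation statistics while generating images as usual.
    enableImatrixCollection();
    // ... run normal txt2img / img2img generation here ...
    saveImatrix("imatrix.dat");
    disableImatrixCollection();

    // Phase 2: quantize using the collected statistics.
    loadImatrix("imatrix.dat");
    return convert("model.safetensors",  // model_path
                   "", "", "", "",       // clip_l, clip_g, t5xxl, diffusion model
                   "",                   // vae
                   "model-q4_k.gguf",    // output
                   SD_TYPE_Q4_K) ? 0 : 1;
}
```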