
Commit e0e912f

llama : add option to override model tensor buffers (#11397)
* llama : add option to override tensor buffers
* ggml : fix possible underflow in ggml_nbytes
1 parent a10b36c commit e0e912f

12 files changed, +108 −9 lines
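In short: the new --override-tensor (-ot) flag accepts a comma-separated list of <tensor name pattern>=<buffer type> pairs. Each pattern is a regex matched against tensor names while the model is loaded, and the buffer type name must be one reported by ggml_backend_buft_name (an unknown name prints the available types). For example, passing -ot "ffn_.*=CPU" together with -ngl 99 keeps every tensor whose name matches ffn_.* in the CPU buffer type while the remaining weights are offloaded as usual; the exact buffer type names (CPU, CUDA0, ...) depend on the build, so treat that invocation as illustrative.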

common/arg.cpp  +40

@@ -1,6 +1,7 @@
 #include "gguf.h" // for reading GGUF splits
 #include "arg.h"

+#include "common.h"
 #include "log.h"
 #include "sampling.h"
 #include "chat.h"
@@ -848,6 +849,10 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         params.kv_overrides.back().key[0] = 0;
     }

+    if (!params.tensor_buft_overrides.empty()) {
+        params.tensor_buft_overrides.push_back({nullptr, nullptr});
+    }
+
     if (params.reranking && params.embedding) {
         throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
     }
@@ -2180,6 +2185,41 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             exit(0);
         }
     ));
+    add_opt(common_arg(
+        {"--override-tensor", "-ot"}, "<tensor name pattern>=<buffer type>,...",
+        "override tensor buffer type", [](common_params & params, const std::string & value) {
+            /* static */ std::map<std::string, ggml_backend_buffer_type_t> buft_list;
+            if (buft_list.empty()) {
+                // enumerate all the devices and add their buffer types to the list
+                for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+                    auto * dev = ggml_backend_dev_get(i);
+                    auto * buft = ggml_backend_dev_buffer_type(dev);
+                    if (buft) {
+                        buft_list[ggml_backend_buft_name(buft)] = buft;
+                    }
+                }
+            }
+
+            for (const auto & override : string_split<std::string>(value, ',')) {
+                std::string::size_type pos = override.find('=');
+                if (pos == std::string::npos) {
+                    throw std::invalid_argument("invalid value");
+                }
+                std::string tensor_name = override.substr(0, pos);
+                std::string buffer_type = override.substr(pos + 1);
+
+                if (buft_list.find(buffer_type) == buft_list.end()) {
+                    printf("Available buffer types:\n");
+                    for (const auto & it : buft_list) {
+                        printf("  %s\n", ggml_backend_buft_name(it.second));
+                    }
+                    throw std::invalid_argument("unknown buffer type");
+                }
+                // FIXME: this leaks memory
+                params.tensor_buft_overrides.push_back({strdup(tensor_name.c_str()), buft_list.at(buffer_type)});
+            }
+        }
+    ));
     add_opt(common_arg(
         {"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
         "number of layers to store in VRAM",

common/common.cpp  +10

@@ -1042,22 +1042,32 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
     if (!params.devices.empty()) {
         mparams.devices = params.devices.data();
     }
+
     if (params.n_gpu_layers != -1) {
         mparams.n_gpu_layers = params.n_gpu_layers;
     }
+
     mparams.main_gpu = params.main_gpu;
     mparams.split_mode = params.split_mode;
     mparams.tensor_split = params.tensor_split;
     mparams.use_mmap = params.use_mmap;
     mparams.use_mlock = params.use_mlock;
     mparams.check_tensors = params.check_tensors;
+
     if (params.kv_overrides.empty()) {
         mparams.kv_overrides = NULL;
     } else {
         GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key");
         mparams.kv_overrides = params.kv_overrides.data();
     }

+    if (params.tensor_buft_overrides.empty()) {
+        mparams.tensor_buft_overrides = NULL;
+    } else {
+        GGML_ASSERT(params.tensor_buft_overrides.back().pattern == nullptr && "Tensor buffer overrides not terminated with empty pattern");
+        mparams.tensor_buft_overrides = params.tensor_buft_overrides.data();
+    }
+
     return mparams;
 }

common/common.h  +1

@@ -279,6 +279,7 @@ struct common_params {
     std::vector<std::string> in_files;   // all input files
     std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
     std::vector<llama_model_kv_override> kv_overrides;
+    std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;

     bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_adapter_lora_apply)
     std::vector<common_adapter_lora_info> lora_adapters; // lora adapter path with user defined scale

ggml/src/ggml.c  +6

@@ -1159,6 +1159,12 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
 }

 size_t ggml_nbytes(const struct ggml_tensor * tensor) {
+    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+        if (tensor->ne[i] <= 0) {
+            return 0;
+        }
+    }
+
     size_t nbytes;
     const size_t blck_size = ggml_blck_size(tensor->type);
     if (blck_size == 1) {
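The guard matters because the size computation that follows accumulates (ne[i] - 1) * nb[i] terms in unsigned size_t arithmetic, so a zero-sized (or negative) dimension would wrap around to a huge value rather than yield 0. A minimal illustration of that wraparound, not part of the commit:

#include <cstdint>
#include <cstdio>

int main() {
    // sketch of the failure mode the new check prevents:
    // a non-positive dimension feeds a negative value into unsigned size arithmetic
    int64_t ne0 = 0;                  // zero-sized dimension
    size_t  nb0 = 4;                  // stride in bytes
    size_t  bytes = (ne0 - 1) * nb0;  // (-1) * 4 wraps around in size_t
    std::printf("%zu\n", bytes);      // enormous value on a 64-bit target, not 0
}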

include/llama.h  +8

@@ -280,10 +280,18 @@ extern "C" {
         };
     };

+    struct llama_model_tensor_buft_override {
+        const char * pattern;
+        ggml_backend_buffer_type_t buft;
+    };
+
     struct llama_model_params {
         // NULL-terminated list of devices to use for offloading (if NULL, all available devices are used)
         ggml_backend_dev_t * devices;

+        // NULL-terminated list of buffer types to use for tensors that match a pattern
+        const struct llama_model_tensor_buft_override * tensor_buft_overrides;
+
         int32_t n_gpu_layers; // number of layers to store in VRAM
         enum llama_split_mode split_mode; // how to split the model across multiple GPUs
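For programs that call the C API directly rather than going through common, a minimal sketch of how the new field might be populated. Everything beyond this diff is an assumption: the "CPU" buffer type name, llama_model_load_from_file()/llama_model_free() from the current llama.h, and the placeholder pattern and model path.

#include "llama.h"
#include "ggml-backend.h"

#include <cstring>

int main() {
    // look up a buffer type by name, the same way common/arg.cpp builds its list
    ggml_backend_buffer_type_t cpu_buft = nullptr;
    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
        ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(ggml_backend_dev_get(i));
        if (buft && std::strcmp(ggml_backend_buft_name(buft), "CPU") == 0) {
            cpu_buft = buft;
        }
    }
    if (cpu_buft == nullptr) {
        return 1;
    }

    // NULL-terminated list; each pattern is a regex applied to tensor names with std::regex_search
    const llama_model_tensor_buft_override overrides[] = {
        { "ffn_.*", cpu_buft },  // keep matching weights in the CPU buffer type
        { nullptr,  nullptr },   // terminator, as the loader and common_model_params_to_llama expect
    };

    llama_model_params mparams = llama_model_default_params();
    mparams.tensor_buft_overrides = overrides;

    llama_model * model = llama_model_load_from_file("model.gguf", mparams);
    if (model == nullptr) {
        return 1;
    }
    llama_model_free(model);
    return 0;
}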

src/llama-context.cpp  +2 −1

@@ -255,7 +255,8 @@ llama_context::llama_context(
         model.n_devices() > 1 &&
         model.params.n_gpu_layers > (int) model.hparams.n_layer &&
         model.params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
-        cparams.offload_kqv;
+        cparams.offload_kqv &&
+        !model.has_tensor_overrides();

     // pipeline parallelism requires support for async compute and events in all devices
     if (pipeline_parallel) {

src/llama-model-loader.cpp  +4 −1

@@ -445,7 +445,8 @@ llama_model_loader::llama_model_loader(
         std::vector<std::string> & splits,
         bool use_mmap,
         bool check_tensors,
-        const struct llama_model_kv_override * param_overrides_p) {
+        const llama_model_kv_override * param_overrides_p,
+        const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) {
     int trace = 0;
     if (getenv("LLAMA_TRACE")) {
         trace = atoi(getenv("LLAMA_TRACE"));
@@ -457,6 +458,8 @@ llama_model_loader::llama_model_loader(
         }
     }

+    tensor_buft_overrides = param_tensor_buft_overrides_p;
+
     // Load the main GGUF
     struct ggml_context * ctx = NULL;
     struct gguf_init_params params = {

src/llama-model-loader.h  +5 −3

@@ -77,8 +77,9 @@ struct llama_model_loader {

     llama_mmaps mappings;

-    std::map<std::string, struct llama_tensor_weight, weight_name_comparer> weights_map;
-    std::unordered_map<std::string, struct llama_model_kv_override> kv_overrides;
+    std::map<std::string, llama_tensor_weight, weight_name_comparer> weights_map;
+    std::unordered_map<std::string, llama_model_kv_override> kv_overrides;
+    const llama_model_tensor_buft_override * tensor_buft_overrides;

     gguf_context_ptr meta;
     std::vector<ggml_context_ptr> contexts;
@@ -95,7 +96,8 @@ struct llama_model_loader {
         std::vector<std::string> & splits, // optional, only need if the split does not follow naming scheme
         bool use_mmap,
         bool check_tensors,
-        const struct llama_model_kv_override * param_overrides_p);
+        const llama_model_kv_override * param_overrides_p,
+        const llama_model_tensor_buft_override * param_tensor_buft_overrides_p);

     template<typename T>
     typename std::enable_if<std::is_integral<T>::value, bool>::type

src/llama-model.cpp  +28 −2

@@ -17,6 +17,7 @@
 #include <cmath>
 #include <functional>
 #include <map>
+#include <regex>
 #include <sstream>
 #include <stdexcept>

@@ -378,9 +379,12 @@ struct llama_model::impl {
     layer_dev dev_input = {};
     layer_dev dev_output = {};
     std::vector<layer_dev> dev_layer;
+
+    bool has_tensor_overrides;
 };

 llama_model::llama_model(const llama_model_params & params) : params(params), pimpl(std::make_unique<impl>()) {
+    pimpl->has_tensor_overrides = params.tensor_buft_overrides && params.tensor_buft_overrides[0].pattern;
 }

 llama_model::~llama_model() {}
@@ -1571,9 +1575,26 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str());
             }

-            ggml_backend_buffer_type_t buft = select_weight_buft(hparams, t_meta, op, *buft_list);
+            ggml_backend_buffer_type_t buft = nullptr;
+
+            // check overrides
+            if (ml.tensor_buft_overrides) {
+                std::string tensor_name = tn.str();
+                for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) {
+                    std::regex pattern(overrides->pattern);
+                    if (std::regex_search(tensor_name, pattern)) {
+                        LLAMA_LOG_DEBUG("tensor %s buffer type overriden to %s\n", tensor_name.c_str(), ggml_backend_buft_name(overrides->buft));
+                        buft = overrides->buft;
+                        break;
+                    }
+                }
+            }
+
             if (!buft) {
-                throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str()));
+                buft = select_weight_buft(hparams, t_meta, op, *buft_list);
+                if (!buft) {
+                    throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str()));
+                }
             }

             // avoid using a host buffer when using mmap
@@ -4151,6 +4172,10 @@ ggml_backend_buffer_type_t llama_model::select_buft(int il) const {
         });
 }

+bool llama_model::has_tensor_overrides() const {
+    return pimpl->has_tensor_overrides;
+}
+
 const ggml_tensor * llama_model::get_tensor(const char * name) const {
     auto it = std::find_if(tensors_by_name.begin(), tensors_by_name.end(),
             [name](const std::pair<std::string, ggml_tensor *> & it) {
@@ -12319,6 +12344,7 @@ llm_graph_result_ptr llama_model::build_graph(
 llama_model_params llama_model_default_params() {
     llama_model_params result = {
         /*.devices                     =*/ nullptr,
+        /*.tensor_buft_overrides       =*/ nullptr,
         /*.n_gpu_layers                =*/ 0,
         /*.split_mode                  =*/ LLAMA_SPLIT_MODE_LAYER,
         /*.main_gpu                    =*/ 0,
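One detail worth calling out: the overrides are applied with std::regex_search, so a pattern matches anywhere inside the tensor name rather than having to match it in full. A small standalone illustration (the tensor name is a typical llama.cpp weight name, assumed here for the example):

#include <cstdio>
#include <regex>
#include <string>

int main() {
    const std::string name = "blk.0.ffn_down.weight";

    std::printf("%d\n", (int) std::regex_search(name, std::regex("ffn_down")));        // 1: a substring match is enough
    std::printf("%d\n", (int) std::regex_search(name, std::regex("blk\\.[0-3]\\.")));  // 1: restrict the override to layers 0-3
    std::printf("%d\n", (int) std::regex_search(name, std::regex("^output$")));        // 0: anchors demand the full name
}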

src/llama-model.h  +2

@@ -382,6 +382,8 @@ struct llama_model {

     ggml_backend_buffer_type_t select_buft(int il) const;

+    bool has_tensor_overrides() const;
+
     const struct ggml_tensor * get_tensor(const char * name) const;

     // TODO: move this to new llm_arch_model_i interface

src/llama-quant.cpp  +1 −1

@@ -527,7 +527,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
     }

     std::vector<std::string> splits = {};
-    llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides);
+    llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides, nullptr);
     ml.init_mappings(false); // no prefetching

     llama_model model(llama_model_default_params());

src/llama.cpp  +1 −1

@@ -92,7 +92,7 @@ static int llama_model_load(const std::string & fname, std::vector<std::string>
     model.t_start_us = tm.t_start_us;

     try {
-        llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides);
+        llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides, params.tensor_buft_overrides);

         ml.print_info();
