 #include <cmath>
 #include <functional>
 #include <map>
+#include <regex>
 #include <sstream>
 #include <stdexcept>
 
@@ -378,9 +379,12 @@ struct llama_model::impl {
     layer_dev dev_input = {};
     layer_dev dev_output = {};
     std::vector<layer_dev> dev_layer;
+
+    bool has_tensor_overrides;
 };
 
 llama_model::llama_model(const llama_model_params & params) : params(params), pimpl(std::make_unique<impl>()) {
+    pimpl->has_tensor_overrides = params.tensor_buft_overrides && params.tensor_buft_overrides[0].pattern;
 }
 
 llama_model::~llama_model() {}
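
The constructor caches whether any overrides were supplied: the array must be non-null and its first entry must carry a non-null pattern, so an array containing only the null-pattern terminator counts as "no overrides". A standalone sketch of that check follows (not part of the diff; the {pattern, buft} struct is a stand-in for the public override entry type):

// Standalone sketch, not part of the diff: the same "are overrides present" test,
// with a minimal {pattern, buft} stand-in for the public override entry type.
#include <cassert>

struct override_entry_sketch {
    const char * pattern; // regex pattern, nullptr terminates the array
    const void * buft;    // stand-in for ggml_backend_buffer_type_t
};

static bool has_overrides(const override_entry_sketch * o) {
    // non-null array whose first entry is not already the terminator
    return o && o[0].pattern;
}

int main() {
    const override_entry_sketch none[] = { { nullptr, nullptr } };
    const override_entry_sketch some[] = { { "ffn_up", nullptr }, { nullptr, nullptr } };

    assert(!has_overrides(nullptr)); // no array passed at all
    assert(!has_overrides(none));    // terminator only
    assert( has_overrides(some));    // at least one real pattern
    return 0;
}
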
@@ -1571,9 +1575,26 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str());
         }
 
-        ggml_backend_buffer_type_t buft = select_weight_buft(hparams, t_meta, op, *buft_list);
+        ggml_backend_buffer_type_t buft = nullptr;
+
+        // check overrides
+        if (ml.tensor_buft_overrides) {
+            std::string tensor_name = tn.str();
+            for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) {
+                std::regex pattern(overrides->pattern);
+                if (std::regex_search(tensor_name, pattern)) {
+                    LLAMA_LOG_DEBUG("tensor %s buffer type overridden to %s\n", tensor_name.c_str(), ggml_backend_buft_name(overrides->buft));
+                    buft = overrides->buft;
+                    break;
+                }
+            }
+        }
+
         if (!buft) {
-            throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str()));
+            buft = select_weight_buft(hparams, t_meta, op, *buft_list);
+            if (!buft) {
+                throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str()));
+            }
         }
 
         // avoid using a host buffer when using mmap
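
The override lookup above walks a null-pattern-terminated array and applies the first regex that matches the tensor name; tensors with no match fall through to select_weight_buft() exactly as before. A standalone sketch of that first-match-wins lookup follows (not part of the diff; a plain string stands in for ggml_backend_buffer_type_t, and the pattern is hypothetical):

// Standalone sketch, not part of the diff: first-match-wins regex lookup over a
// null-pattern-terminated override array, mirroring the loop in load_tensors().
#include <cstdio>
#include <regex>
#include <string>

struct buft_override_sketch {
    const char * pattern; // ECMAScript regex, matched with std::regex_search
    const char * buft;    // stand-in for ggml_backend_buffer_type_t
};

static const char * pick_buft(const std::string & tensor_name, const buft_override_sketch * overrides) {
    for (const auto * o = overrides; o && o->pattern != nullptr; ++o) {
        if (std::regex_search(tensor_name, std::regex(o->pattern))) {
            return o->buft; // first matching pattern wins
        }
    }
    return nullptr; // caller falls back to select_weight_buft()
}

int main() {
    const buft_override_sketch overrides[] = {
        { "ffn_(up|down|gate)", "CPU" }, // hypothetical: keep FFN weights on the host
        { nullptr,              nullptr },
    };

    const char * names[] = { "blk.0.ffn_up.weight", "blk.0.attn_q.weight" };
    for (const char * name : names) {
        const char * buft = pick_buft(name, overrides);
        printf("%-22s -> %s\n", name, buft ? buft : "(default selection)");
    }
    return 0;
}
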
@@ -4151,6 +4172,10 @@ ggml_backend_buffer_type_t llama_model::select_buft(int il) const {
     });
 }
 
+bool llama_model::has_tensor_overrides() const {
+    return pimpl->has_tensor_overrides;
+}
+
 const ggml_tensor * llama_model::get_tensor(const char * name) const {
     auto it = std::find_if(tensors_by_name.begin(), tensors_by_name.end(),
             [name](const std::pair<std::string, ggml_tensor *> & it) {
@@ -12319,6 +12344,7 @@ llm_graph_result_ptr llama_model::build_graph(
 llama_model_params llama_model_default_params() {
     llama_model_params result = {
         /*.devices               =*/ nullptr,
+        /*.tensor_buft_overrides =*/ nullptr,
         /*.n_gpu_layers          =*/ 0,
         /*.split_mode            =*/ LLAMA_SPLIT_MODE_LAYER,
         /*.main_gpu              =*/ 0,
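
On the caller side, the new field is filled with an array of {regex pattern, buffer type} entries terminated by a null pattern. A minimal usage sketch follows, assuming the entry type exposed in llama.h is a {pattern, buft} pair named llama_model_tensor_buft_override and that ggml_backend_cpu_buffer_type() is available from the ggml CPU backend header; the pattern itself is hypothetical:

// Caller-side sketch, not part of the diff. Assumption: the public override entry
// is a {const char * pattern; ggml_backend_buffer_type_t buft;} struct named
// llama_model_tensor_buft_override, and the array ends with a nullptr pattern.
#include "llama.h"
#include "ggml-cpu.h" // assumed location of ggml_backend_cpu_buffer_type()

int main(int argc, char ** argv) {
    if (argc < 2) {
        return 1;
    }

    // hypothetical pattern: keep expert FFN tensors in host (CPU) buffers,
    // let every other tensor use the default buffer type selection
    const llama_model_tensor_buft_override overrides[] = {
        { "ffn_.*_exps", ggml_backend_cpu_buffer_type() },
        { nullptr,       nullptr                        },
    };

    llama_model_params mparams = llama_model_default_params();
    mparams.tensor_buft_overrides = overrides;

    llama_model * model = llama_model_load_from_file(argv[1], mparams);
    if (model == nullptr) {
        return 1;
    }

    llama_model_free(model);
    return 0;
}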