Commit 34d40a9

rgerganov authored and mglambda committed
rpc : early register backend devices (ggml-org#11262)
Early register RPC devices and do not propagate RPC specifics in the llama model structures. ref: ggml-org#10609
1 parent fc33494 commit 34d40a9
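After this change RPC endpoints are turned into backend devices while the command line is being parsed, so they show up in the global ggml backend registry next to local CPU/GPU devices and no RPC-specific state has to travel through the llama model structures. A minimal sketch of what a frontend can rely on once --rpc has been handled (a hypothetical helper, not code from this commit; the enumeration calls are the public ones from ggml-backend.h):

    // Sketch (not part of this commit): after the RPC devices have been
    // registered, they can be enumerated exactly like local devices.
    #include "ggml-backend.h"
    #include <cstdio>

    static void print_registered_devices(void) {
        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
            // RPC devices appear here alongside CPU/GPU devices
            printf("device %zu: %s - %s\n", i,
                   ggml_backend_dev_name(dev),
                   ggml_backend_dev_description(dev));
        }
    }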

10 files changed: +61 -55 lines changed

common/arg.cpp

+26 -1
@@ -376,6 +376,30 @@ static std::vector<ggml_backend_dev_t> parse_device_list(const std::string & val
     return devices;
 }
 
+static void add_rpc_devices(std::string servers) {
+    auto rpc_servers = string_split<std::string>(servers, ',');
+    if (rpc_servers.empty()) {
+        throw std::invalid_argument("no RPC servers specified");
+    }
+    ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
+    if (!rpc_reg) {
+        throw std::invalid_argument("failed to find RPC backend");
+    }
+    typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
+    ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
+    if (!ggml_backend_rpc_add_device_fn) {
+        throw std::invalid_argument("failed to find RPC device add function");
+    }
+    for (const auto & server : rpc_servers) {
+        ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
+        if (dev) {
+            ggml_backend_device_register(dev);
+        } else {
+            throw std::invalid_argument("failed to register RPC device");
+        }
+    }
+}
+
 bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
     auto ctx_arg = common_params_parser_init(params, ex, print_usage);
     const common_params params_org = ctx_arg.params; // the example can modify the default params
@@ -1385,7 +1409,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             {"--rpc"}, "SERVERS",
             "comma separated list of RPC servers",
             [](common_params & params, const std::string & value) {
-                params.rpc_servers = value;
+                add_rpc_devices(value);
+                GGML_UNUSED(params);
             }
         ).set_env("LLAMA_ARG_RPC"));
     }
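The handler resolves ggml_backend_rpc_add_device through ggml_backend_reg_get_proc_address instead of calling it directly, which keeps common/ free of a hard compile-time dependency on the RPC backend. Applications that embed llama.cpp can follow the same pattern; the sketch below is illustrative only, and the endpoint string is a placeholder:

    // Sketch: register one RPC endpoint without linking against the RPC
    // backend headers; "localhost:50052" would be a placeholder endpoint.
    #include "ggml-backend.h"
    #include <stdexcept>

    static void register_rpc_endpoint(const char * endpoint) {
        ggml_backend_reg_t reg = ggml_backend_reg_by_name("RPC");
        if (!reg) {
            throw std::runtime_error("RPC backend not available");
        }
        typedef ggml_backend_dev_t (*add_device_fn_t)(const char * endpoint);
        auto add_device = (add_device_fn_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_rpc_add_device");
        if (!add_device) {
            throw std::runtime_error("ggml_backend_rpc_add_device not exported");
        }
        ggml_backend_dev_t dev = add_device(endpoint);
        if (!dev) {
            throw std::runtime_error("failed to create RPC device");
        }
        ggml_backend_device_register(dev); // now visible via ggml_backend_dev_count()/get()
    }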

common/common.cpp

-1
@@ -1043,7 +1043,6 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
     if (params.n_gpu_layers != -1) {
         mparams.n_gpu_layers = params.n_gpu_layers;
     }
-    mparams.rpc_servers = params.rpc_servers.c_str();
     mparams.main_gpu = params.main_gpu;
     mparams.split_mode = params.split_mode;
     mparams.tensor_split = params.tensor_split;

common/common.h

-1
@@ -246,7 +246,6 @@ struct common_params {
     std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
     std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
     std::string logits_file = ""; // file for saving *all* logits // NOLINT
-    std::string rpc_servers = ""; // comma separated list of RPC servers // NOLINT
 
     std::vector<std::string> in_files; // all input files
     std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)

examples/llama-bench/llama-bench.cpp

+33 -4
@@ -683,7 +683,7 @@ struct cmd_params_instance {
     bool cpu_strict;
     int poll;
     int n_gpu_layers;
-    std::string rpc_servers;
+    std::string rpc_servers_str;
     llama_split_mode split_mode;
     int main_gpu;
     bool no_kv_offload;
@@ -696,8 +696,37 @@ struct cmd_params_instance {
         llama_model_params mparams = llama_model_default_params();
 
         mparams.n_gpu_layers = n_gpu_layers;
-        if (!rpc_servers.empty()) {
-            mparams.rpc_servers = rpc_servers.c_str();
+        if (!rpc_servers_str.empty()) {
+            auto rpc_servers = string_split<std::string>(rpc_servers_str, ',');
+
+            // add RPC devices
+            if (!rpc_servers.empty()) {
+                ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
+                if (!rpc_reg) {
+                    fprintf(stderr, "%s: failed to find RPC backend\n", __func__);
+                    exit(1);
+                }
+
+                typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
+                ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
+                if (!ggml_backend_rpc_add_device_fn) {
+                    fprintf(stderr, "%s: failed to find RPC device add function\n", __func__);
+                    exit(1);
+                }
+                static std::vector<ggml_backend_dev_t> devices;
+                devices.clear();
+                for (const std::string & server : rpc_servers) {
+                    ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
+                    if (dev) {
+                        devices.push_back(dev);
+                    } else {
+                        fprintf(stderr, "%s: failed to add RPC device for server '%s'\n", __func__, server.c_str());
+                        exit(1);
+                    }
+                }
+                devices.push_back(nullptr);
+                mparams.devices = devices.data();
+            }
         }
         mparams.split_mode = split_mode;
         mparams.main_gpu = main_gpu;
@@ -708,7 +737,7 @@ struct cmd_params_instance {
     }
 
     bool equal_mparams(const cmd_params_instance & other) const {
-        return model == other.model && n_gpu_layers == other.n_gpu_layers && rpc_servers == other.rpc_servers &&
+        return model == other.model && n_gpu_layers == other.n_gpu_layers && rpc_servers_str == other.rpc_servers_str &&
               split_mode == other.split_mode && main_gpu == other.main_gpu && use_mmap == other.use_mmap &&
               tensor_split == other.tensor_split;
     }

ggml/include/ggml-backend.h

+2
@@ -203,6 +203,8 @@ extern "C" {
     // Backend registry
     //
 
+    GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);
+
     // Backend (reg) enumeration
     GGML_API size_t ggml_backend_reg_count(void);
     GGML_API ggml_backend_reg_t ggml_backend_reg_get(size_t index);
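ggml_backend_device_register used to live only in the internal ggml-backend-impl.h (removed in the next file); making it public is what lets frontends such as common/arg.cpp above add devices to the registry at runtime. A small illustrative sketch of the intended call sequence, assuming dev was produced by a backend-specific factory such as the RPC add-device function:

    // Sketch: make a dynamically created device visible to the global
    // registry, then start a backend on it.
    #include "ggml-backend.h"

    void use_device(ggml_backend_dev_t dev) {
        ggml_backend_device_register(dev);               // now enumerable by frontends
        ggml_backend_t backend = ggml_backend_dev_init(dev, /*params=*/nullptr);
        // ... schedule and run graphs on backend ...
        ggml_backend_free(backend);
    }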

ggml/src/ggml-backend-impl.h

-1
@@ -208,7 +208,6 @@ extern "C" {
 
     // Internal backend registry API
     GGML_API void ggml_backend_register(ggml_backend_reg_t reg);
-    GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);
 
     // Add backend dynamic loading support to the backend

include/llama.h

-3
@@ -288,9 +288,6 @@ extern "C" {
         // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
         const float * tensor_split;
 
-        // comma separated list of RPC servers to use for offloading
-        const char * rpc_servers;
-
         // Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
         // If the provided progress_callback returns true, model loading continues.
         // If it returns false, model loading is immediately aborted.
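With rpc_servers gone from llama_model_params, RPC offloading is no longer requested through the model parameters: the devices are registered beforehand and, if an explicit selection is wanted, passed through the existing null-terminated devices field, as llama-bench does above. A rough migration sketch (the helper name is made up):

    // Sketch: old code set mparams.rpc_servers = "host:port,...";
    // new code registers devices first and, optionally, pins the model to
    // an explicit device list through mparams.devices.
    #include "llama.h"
    #include <vector>

    static llama_model_params make_mparams_with_all_devices(std::vector<ggml_backend_dev_t> & devs) {
        devs.clear();
        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
            devs.push_back(ggml_backend_dev_get(i)); // includes any registered RPC devices
        }
        devs.push_back(nullptr); // the list must be null-terminated

        llama_model_params mparams = llama_model_default_params();
        mparams.devices = devs.data(); // or leave devices == NULL to use every registered device
        return mparams;
    }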

src/llama-model.cpp

-1
@@ -3717,7 +3717,6 @@ struct llama_model_params llama_model_default_params() {
         /*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER,
         /*.main_gpu =*/ 0,
         /*.tensor_split =*/ nullptr,
-        /*.rpc_servers =*/ nullptr,
         /*.progress_callback =*/ nullptr,
         /*.progress_callback_user_data =*/ nullptr,
         /*.kv_overrides =*/ nullptr,

src/llama-model.h

-2
@@ -323,8 +323,6 @@ struct llama_model {
     // gguf metadata
     std::unordered_map<std::string, std::string> gguf_kv;
 
-    std::vector<std::string> rpc_servers;
-
     // list of devices used in this model
     std::vector<ggml_backend_dev_t> devices;

src/llama.cpp

-41
@@ -9399,47 +9399,6 @@ static struct llama_model * llama_model_load_from_file_impl(
         };
     }
 
-    if (params.rpc_servers != nullptr && params.rpc_servers[0] != '\0') {
-        // split the servers set them into model->rpc_servers
-        std::string servers(params.rpc_servers);
-        size_t pos = 0;
-        while ((pos = servers.find(',')) != std::string::npos) {
-            std::string server = servers.substr(0, pos);
-            model->rpc_servers.push_back(server);
-            servers.erase(0, pos + 1);
-        }
-        model->rpc_servers.push_back(servers);
-    }
-
-    // add RPC devices
-    if (!model->rpc_servers.empty()) {
-        ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
-        if (!rpc_reg) {
-            LLAMA_LOG_ERROR("%s: failed to find RPC backend\n", __func__);
-            llama_model_free(model);
-            return nullptr;
-        }
-
-        typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
-        ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
-        if (!ggml_backend_rpc_add_device_fn) {
-            LLAMA_LOG_ERROR("%s: failed to find RPC device add function\n", __func__);
-            llama_model_free(model);
-            return nullptr;
-        }
-
-        for (const std::string & server : model->rpc_servers) {
-            ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
-            if (dev) {
-                model->devices.push_back(dev);
-            } else {
-                LLAMA_LOG_ERROR("%s: failed to add RPC device for server '%s'\n", __func__, server.c_str());
-                llama_model_free(model);
-                return nullptr;
-            }
-        }
-    }
-
     // create list of devices to use with this model
     if (params.devices) {
         for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) {
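With the loader-side block removed, the only remaining requirement is ordering: RPC devices must be registered before the model is loaded, after which the generic device-selection code shown in the trailing context picks them up like any other device. A rough end-to-end sketch (the endpoint, the model path, and register_rpc_endpoint from the earlier sketch are all placeholders):

    // Sketch: register RPC devices first, then load the model normally.
    #include "llama.h"

    int main() {
        register_rpc_endpoint("192.168.1.10:50052"); // hypothetical helper, see the arg.cpp sketch

        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_model_load_from_file("model.gguf", mparams);
        if (!model) {
            return 1;
        }
        // ... create a context and run inference ...
        llama_model_free(model);
        return 0;
    }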
