Commit bde7cd3

rgerganov and slaren authored
llama : offload to RPC in addition to other backends (ggml-org#7640)
* llama : offload to RPC in addition to other backends

* - fix copy_tensor being called on the src buffer instead of the dst buffer
  - always initialize views in the view_src buffer
  - add RPC backend to Makefile build
  - add endpoint to all RPC object names

* add rpc-server to Makefile

* Update llama.cpp

Co-authored-by: slaren <[email protected]>

---------

Co-authored-by: slaren <[email protected]>
1 parent a5735e4 commit bde7cd3
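
The thrust of the change: RPC servers are no longer an exclusive alternative to the local backends but are appended after the locally detected GPU devices, so layers can be split across, for example, CUDA devices and one or more rpc-server instances. A minimal client-side sketch under that assumption; the model path and endpoints are placeholders, not taken from the commit:

```cpp
// Sketch only: combining local GPU offload with RPC servers after this commit.
// The model path and endpoints below are placeholders.
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    mparams.n_gpu_layers = 99;  // offload as many layers as the devices allow
    // comma-separated endpoints; these become extra devices after the local GPUs
    mparams.rpc_servers  = "192.168.1.10:50052,192.168.1.11:50052";

    llama_model * model = llama_load_model_from_file("model.gguf", mparams);
    if (model == NULL) {
        llama_backend_free();
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    llama_context * ctx = llama_new_context_with_model(model, cparams);
    // ... run inference as usual ...
    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}
```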

File tree

6 files changed: +86 -53 lines

Makefile (+24 -5)

@@ -69,6 +69,10 @@ ifeq ($(UNAME_S),Darwin)
 endif
 endif
 
+ifdef LLAMA_RPC
+	BUILD_TARGETS += rpc-server
+endif
+
 default: $(BUILD_TARGETS)
 
 test: $(TEST_TARGETS)
@@ -429,6 +433,11 @@ ifdef LLAMA_BLIS
 	MK_LDFLAGS += -lblis -L/usr/local/lib
 endif # LLAMA_BLIS
 
+ifdef LLAMA_RPC
+	MK_CPPFLAGS += -DGGML_USE_RPC
+	OBJS += ggml-rpc.o
+endif # LLAMA_RPC
+
 ifdef LLAMA_CUBLAS
 # LLAMA_CUBLAS is deprecated and will be removed in the future
 	LLAMA_CUDA := 1
@@ -654,11 +663,26 @@ ggml-metal-embed.o: ggml-metal.metal ggml-common.h
 endif
 endif # LLAMA_METAL
 
+OBJS += ggml-alloc.o ggml-backend.o ggml-quants.o unicode.o unicode-data.o
+COMMON_H_DEPS = common/common.h common/sampling.h common/log.h llama.h
+COMMON_DEPS = common.o sampling.o grammar-parser.o build-info.o json-schema-to-grammar.o
+
 ifndef LLAMA_NO_LLAMAFILE
 sgemm.o: sgemm.cpp sgemm.h ggml.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 endif
 
+ifdef LLAMA_RPC
+ggml-rpc.o: ggml-rpc.cpp ggml-rpc.h
+	$(CXX) $(CXXFLAGS) -c $< -o $@
+
+rpc-server.o: examples/rpc/rpc-server.cpp ggml-rpc.h
+	$(CXX) $(CXXFLAGS) -c $< -o $@
+
+rpc-server: rpc-server.o ggml.o llama.o $(COMMON_DEPS) $(OBJS)
+	$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
+endif # LLAMA_RPC
+
 GF_CC := $(CC)
 include scripts/get-flags.mk
 
@@ -738,14 +762,9 @@ unicode.o: unicode.cpp unicode.h
 unicode-data.o: unicode-data.cpp unicode-data.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 
-OBJS += ggml-alloc.o ggml-backend.o ggml-quants.o unicode.o unicode-data.o
-
 llama.o: llama.cpp unicode.h ggml.h ggml-alloc.h ggml-backend.h ggml-cuda.h ggml-metal.h llama.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 
-COMMON_H_DEPS = common/common.h common/sampling.h common/log.h llama.h
-COMMON_DEPS = common.o sampling.o grammar-parser.o build-info.o json-schema-to-grammar.o
-
 common.o: common/common.cpp $(COMMON_H_DEPS)
 	$(CXX) $(CXXFLAGS) -c $< -o $@
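
With these rules, `make LLAMA_RPC=1` adds `rpc-server` to the default build targets, compiles everything with `-DGGML_USE_RPC`, and links `ggml-rpc.o`. A rough sketch of the kind of client code the new define gates, using only the RPC API visible in this commit; the endpoint is a placeholder:

```cpp
// Sketch only: code that compiles to a no-op unless built with LLAMA_RPC=1
// (i.e. -DGGML_USE_RPC). The endpoint is a placeholder.
#include <cstdio>
#ifdef GGML_USE_RPC
#include "ggml-rpc.h"
#endif

int main() {
#ifdef GGML_USE_RPC
    size_t free_mem = 0, total_mem = 0;
    // query an rpc-server instance for its reported device memory
    ggml_backend_rpc_get_device_memory("127.0.0.1:50052", &free_mem, &total_mem);
    printf("RPC device memory: %zu free / %zu total\n", free_mem, total_mem);
#else
    printf("built without LLAMA_RPC\n");
#endif
    return 0;
}
```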

ggml-alloc.c (+3 -3)

@@ -750,7 +750,7 @@ static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor *
                 // this tensor was allocated without ggml-backend
                 return;
             }
-            ggml_backend_view_init(galloc->buffers[buffer_id], tensor);
+            ggml_backend_view_init(tensor);
         }
     } else {
         if (tensor->data == NULL) {
@@ -899,12 +899,12 @@ static bool alloc_tensor_range(struct ggml_context * ctx,
             if (t->view_src == NULL) {
                 ggml_tallocr_alloc(&tallocr, t);
             } else if (t->buffer == NULL) {
-                ggml_backend_view_init(buffer, t);
+                ggml_backend_view_init(t);
             }
         } else {
             if (t->view_src != NULL && t->buffer == NULL) {
                 // view of a pre-allocated tensor
-                ggml_backend_view_init(buffer, t);
+                ggml_backend_view_init(t);
             }
         }
     }
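
Both call sites change for the same reason: `ggml_backend_view_init` now always places a view in the buffer of its `view_src` tensor, so the allocator no longer chooses (or even needs to know) the buffer. A rough sketch of the new calling convention, assuming `t` was already allocated in a backend buffer (e.g. via `ggml_backend_alloc_ctx_tensors`):

```cpp
#include "ggml.h"
#include "ggml-backend.h"

// Sketch only (not from the commit): initializing a view after this change.
// Assumes t has already been allocated in a backend buffer.
static struct ggml_tensor * make_view(struct ggml_context * ctx, struct ggml_tensor * t) {
    struct ggml_tensor * v = ggml_view_1d(ctx, t, t->ne[0], 0); // v->view_src == t
    if (v->buffer == NULL) {
        // before this commit: ggml_backend_view_init(t->buffer, v);
        ggml_backend_view_init(v); // the buffer is now taken from v->view_src->buffer
    }
    return v;
}
```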

ggml-backend.c (+5 -5)

@@ -151,7 +151,7 @@ void ggml_backend_buffer_reset(ggml_backend_buffer_t buffer) {
 bool ggml_backend_buffer_copy_tensor(const struct ggml_tensor * src, struct ggml_tensor * dst) {
     ggml_backend_buffer_t dst_buf = dst->view_src ? dst->view_src->buffer : dst->buffer;
     if (dst_buf->iface.cpy_tensor) {
-        return src->buffer->iface.cpy_tensor(dst_buf, src, dst);
+        return dst_buf->iface.cpy_tensor(dst_buf, src, dst);
     }
     return false;
 }
@@ -1887,15 +1887,15 @@ ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched,
 
 // utils
 
-void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+void ggml_backend_view_init(struct ggml_tensor * tensor) {
     GGML_ASSERT(tensor->buffer == NULL);
     GGML_ASSERT(tensor->view_src != NULL);
     GGML_ASSERT(tensor->view_src->buffer != NULL);
     GGML_ASSERT(tensor->view_src->data != NULL);
 
-    tensor->buffer = buffer;
+    tensor->buffer = tensor->view_src->buffer;
     tensor->data = (char *)tensor->view_src->data + tensor->view_offs;
-    ggml_backend_buffer_init_tensor(buffer, tensor);
+    ggml_backend_buffer_init_tensor(tensor->buffer, tensor);
 }
 
 void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) {
@@ -1954,7 +1954,7 @@ static void graph_copy_init_tensor(struct ggml_hash_set hash_set, struct ggml_te
     struct ggml_tensor * dst = node_copies[id];
     if (dst->view_src != NULL) {
         graph_copy_init_tensor(hash_set, node_copies, node_init, src->view_src);
-        ggml_backend_view_init(dst->view_src->buffer, dst);
+        ggml_backend_view_init(dst);
     }
     else {
         ggml_backend_tensor_copy(src, dst);
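
Two fixes here: the optimized copy in `ggml_backend_buffer_copy_tensor` is now dispatched through the destination buffer's `cpy_tensor` callback (it was previously looked up on the source buffer, which breaks when only the destination, e.g. an RPC buffer, knows how to receive the data), and `ggml_backend_view_init` takes the buffer from `view_src`. For reference, a sketch using only the public API of the portable host-staging copy that callers fall back to when no per-buffer fast path is available:

```cpp
#include "ggml.h"
#include "ggml-backend.h"
#include <cstdint>
#include <vector>

// Sketch only: the generic copy between two backends, staged through host
// memory. After this commit the optimized path inside ggml_backend_tensor_copy
// is driven by the destination buffer instead of the source buffer.
static void copy_tensor_via_host(const struct ggml_tensor * src, struct ggml_tensor * dst) {
    std::vector<uint8_t> staging(ggml_nbytes(src));
    ggml_backend_tensor_get(src, staging.data(), 0, staging.size());
    ggml_backend_tensor_set(dst, staging.data(), 0, staging.size());
}
```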

ggml-backend.h (+1 -1)

@@ -225,7 +225,7 @@ extern "C" {
 
     // Tensor initialization
     GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
-    GGML_API void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+    GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor);
 
 
 #ifdef __cplusplus

ggml-rpc.cpp (+2 -2)

@@ -491,7 +491,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
     if (remote_ptr != 0) {
         ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
             ggml_backend_rpc_buffer_interface,
-            new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC"},
+            new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
             remote_size);
         return buffer;
     } else {
@@ -692,7 +692,7 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const
 GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
     ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
         /* .endpoint = */ endpoint,
-        /* .name = */ "RPC",
+        /* .name = */ "RPC[" + std::string(endpoint) + "]",
     };
 
     ggml_backend_t backend = new ggml_backend {
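
Because several RPC backends can now live in one context, the endpoint is embedded in the backend and buffer names, so scheduler and debug logs identify which server a tensor lives on. A small sketch with placeholder endpoints showing what the names look like:

```cpp
#include "ggml-backend.h"
#include "ggml-rpc.h"
#include <cstdio>

// Sketch only: the endpoints are placeholders for running rpc-server instances.
int main() {
    ggml_backend_t a = ggml_backend_rpc_init("192.168.1.10:50052");
    ggml_backend_t b = ggml_backend_rpc_init("192.168.1.11:50052");
    if (a == NULL || b == NULL) {
        return 1;
    }
    printf("%s\n", ggml_backend_name(a)); // e.g. "RPC[192.168.1.10:50052]"
    printf("%s\n", ggml_backend_name(b)); // e.g. "RPC[192.168.1.11:50052]"
    ggml_backend_free(a);
    ggml_backend_free(b);
    return 0;
}
```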

llama.cpp (+51 -37)

@@ -2371,13 +2371,34 @@ struct llama_context {
     struct llama_control_vector cvec;
 };
 
+static size_t llama_get_device_count(const llama_model & model) {
+    size_t count = 1;
+#if defined(GGML_USE_CUDA)
+    count = ggml_backend_cuda_get_device_count();
+#elif defined(GGML_USE_SYCL)
+    count = ggml_backend_sycl_get_device_count();
+#elif defined(GGML_USE_VULKAN)
+    count = ggml_backend_vk_get_device_count();
+#endif
+#if defined(GGML_USE_RPC)
+    count += model.rpc_servers.size();
+#endif
+    return count;
+    GGML_UNUSED(model);
+}
+
 static ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_model & model, int gpu) {
     ggml_backend_buffer_type_t buft = nullptr;
 
-#ifdef GGML_USE_RPC
-    std::string endpoint = model.rpc_servers[gpu];
-    buft = ggml_backend_rpc_buffer_type(endpoint.c_str());
-#elif defined(GGML_USE_METAL)
+#if defined(GGML_USE_RPC)
+    int dev_count = (int)llama_get_device_count(model);
+    int rpc_count = (int)model.rpc_servers.size();
+    if (gpu >= dev_count - rpc_count) {
+        const char * endpoint = model.rpc_servers[gpu - dev_count + rpc_count].c_str();
+        return ggml_backend_rpc_buffer_type(endpoint);
+    }
+#endif
+#if defined(GGML_USE_METAL)
     buft = ggml_backend_metal_buffer_type();
 #elif defined(GGML_USE_CUDA)
     buft = ggml_backend_cuda_buffer_type(gpu);
@@ -2425,29 +2446,19 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_split(const llama_mo
     GGML_UNUSED(tensor_split);
 }
 
-static size_t llama_get_device_count(const llama_model & model) {
-#if defined(GGML_USE_RPC)
-    return model.rpc_servers.size();
-#elif defined(GGML_USE_CUDA)
-    return ggml_backend_cuda_get_device_count();
-#elif defined(GGML_USE_SYCL)
-    return ggml_backend_sycl_get_device_count();
-#elif defined(GGML_USE_VULKAN)
-    return ggml_backend_vk_get_device_count();
-#else
-    return 1;
-#endif
-    GGML_UNUSED(model);
-}
-
 static size_t llama_get_device_memory(const llama_model & model, int device) {
 #if defined(GGML_USE_RPC)
-    size_t total;
-    size_t free;
-    std::string endpoint = model.rpc_servers[device];
-    ggml_backend_rpc_get_device_memory(endpoint.c_str(), &free, &total);
-    return free;
-#elif defined(GGML_USE_CUDA)
+    int dev_count = (int)llama_get_device_count(model);
+    int rpc_count = (int)model.rpc_servers.size();
+    if (device >= dev_count - rpc_count) {
+        size_t total;
+        size_t free;
+        const char * endpoint = model.rpc_servers[device - dev_count + rpc_count].c_str();
+        ggml_backend_rpc_get_device_memory(endpoint, &free, &total);
+        return free;
+    }
+#endif
+#if defined(GGML_USE_CUDA)
     size_t total;
     size_t free;
     ggml_backend_cuda_get_device_memory(device, &free, &total);
@@ -16160,7 +16171,7 @@ struct llama_model * llama_load_model_from_file(
             return true;
         };
     }
-    if (params.rpc_servers != nullptr) {
+    if (params.rpc_servers != nullptr && params.rpc_servers[0] != '\0') {
        // split the servers set them into model->rpc_servers
        std::string servers(params.rpc_servers);
        size_t pos = 0;
@@ -16323,17 +16334,7 @@ struct llama_context * llama_new_context_with_model(
 
     if (!hparams.vocab_only) {
         // initialize backends
-#if defined(GGML_USE_RPC)
-        for (auto & server : model->rpc_servers) {
-            ggml_backend_t backend = ggml_backend_rpc_init(server.c_str());
-            if (backend == nullptr) {
-                LLAMA_LOG_ERROR("%s: failed to connect RPC backend to %s\n", __func__, server.c_str());
-                llama_free(ctx);
-                return nullptr;
-            }
-            ctx->backends.push_back(backend);
-        }
-#elif defined(GGML_USE_METAL)
+#if defined(GGML_USE_METAL)
         if (model->n_gpu_layers > 0) {
             ctx->backend_metal = ggml_backend_metal_init();
             if (ctx->backend_metal == nullptr) {
@@ -16425,6 +16426,19 @@ struct llama_context * llama_new_context_with_model(
             }
             ctx->backends.push_back(backend);
         }
+#endif
+#if defined(GGML_USE_RPC)
+        if (model->n_gpu_layers > 0) {
+            for (const auto & endpoint : model->rpc_servers) {
+                ggml_backend_t backend = ggml_backend_rpc_init(endpoint.c_str());
+                if (backend == nullptr) {
+                    LLAMA_LOG_ERROR("%s: failed to initialize RPC to '%s'\n", __func__, endpoint.c_str());
+                    llama_free(ctx);
+                    return nullptr;
+                }
+                ctx->backends.push_back(backend);
+            }
+        }
 #endif
         ctx->backend_cpu = ggml_backend_cpu_init();
         if (ctx->backend_cpu == nullptr) {
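
The core of the llama.cpp change is the device numbering: local GPU devices keep the leading indices, the RPC servers occupy the tail, and both `llama_default_buffer_type_offload` and `llama_get_device_memory` translate a tail index back into `model.rpc_servers` with `device - dev_count + rpc_count`. A worked example of that arithmetic, assuming two local CUDA devices and two servers (the hostnames are placeholders):

```cpp
#include <cassert>
#include <string>
#include <vector>

// Illustrative only: mirrors the index check used in
// llama_default_buffer_type_offload and llama_get_device_memory.
// Returns the index into rpc_servers, or -1 for a local device.
static int rpc_index(int device, int dev_count, int rpc_count) {
    return (device >= dev_count - rpc_count) ? device - dev_count + rpc_count : -1;
}

int main() {
    std::vector<std::string> rpc_servers = {"host-a:50052", "host-b:50052"}; // hypothetical
    int rpc_count = (int)rpc_servers.size();
    int dev_count = 2 + rpc_count; // assume 2 local CUDA devices -> dev_count == 4

    assert(rpc_index(0, dev_count, rpc_count) == -1); // local GPU 0
    assert(rpc_index(1, dev_count, rpc_count) == -1); // local GPU 1
    assert(rpc_index(2, dev_count, rpc_count) == 0);  // "host-a:50052"
    assert(rpc_index(3, dev_count, rpc_count) == 1);  // "host-b:50052"
    return 0;
}
```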
