diff --git a/ggml-rpc.cpp b/ggml-rpc.cpp
index b01ad267446fb..18ea3dcbb7bb6 100644
--- a/ggml-rpc.cpp
+++ b/ggml-rpc.cpp
@@ -5,8 +5,11 @@
 #include <cinttypes>
 #include <string>
 #include <vector>
+#include <condition_variable>
 #include <memory>
 #include <mutex>
+#include <queue>
+#include <thread>
 #include <unordered_map>
 #include <unordered_set>
 #ifdef _WIN32
@@ -17,6 +20,7 @@
 #  include <windows.h>
 #  include <winsock2.h>
 #else
+#  include <signal.h>
 #  include <arpa/inet.h>
 #  include <sys/socket.h>
 #  include <sys/types.h>
@@ -82,7 +86,8 @@ static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of
 
 // RPC commands
 enum rpc_cmd {
-    ALLOC_BUFFER = 0,
+    HELLO = 0,
+    ALLOC_BUFFER,
     GET_ALIGNMENT,
     GET_MAX_SIZE,
     BUFFER_GET_BASE,
@@ -91,8 +96,15 @@ enum rpc_cmd {
     SET_TENSOR,
     GET_TENSOR,
     COPY_TENSOR,
+    REMOTE_COPY_TENSOR,
     GRAPH_COMPUTE,
     GET_DEVICE_MEMORY,
+    FREE_ALL_BUFFERS,
+};
+
+enum rpc_actor {
+    CLIENT = 0,
+    SERVER,
 };
 
 // RPC data structures
@@ -115,6 +127,7 @@ struct ggml_backend_rpc_context {
 };
 
 struct ggml_backend_rpc_buffer_context {
+    std::string endpoint;
     std::shared_ptr<socket_t> sock;
     std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
     uint64_t remote_ptr;
@@ -205,7 +218,7 @@ static std::shared_ptr<socket_t> create_server_socket(const char * host, int por
     if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) {
         return nullptr;
     }
-    if (listen(sockfd, 1) < 0) {
+    if (listen(sockfd, 2) < 0) {
         return nullptr;
     }
     return sock;
@@ -276,6 +289,15 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
 
 // RPC client-side implementation
 
+static void send_hello(std::shared_ptr<socket_t> sock, rpc_actor actor) {
+    // input serialization format: | actor (1 byte) |
+    std::vector<uint8_t> input(1, actor);
+    std::vector<uint8_t> output;
+    bool status = send_rpc_cmd(sock, HELLO, input, output);
+    GGML_ASSERT(status);
+    GGML_ASSERT(output.empty());
+}
+
 static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
     static std::mutex mutex;
     std::lock_guard<std::mutex> lock(mutex);
@@ -309,6 +331,7 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
     if (sock == nullptr) {
        return nullptr;
     }
+    send_hello(sock, CLIENT);
     GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
     sockets[endpoint] = sock;
     return sock;
@@ -422,6 +445,29 @@ GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t b
     memcpy(data, output.data(), size);
 }
 
+static bool remote_copy_tensor(const ggml_tensor * src, ggml_tensor * dst) {
+    ggml_backend_buffer_t src_buffer = src->buffer;
+    ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
+    ggml_backend_buffer_t dst_buffer = dst->buffer;
+    ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
+    // input serialization format: | rpc_tensor src | rpc_tensor dst | dst_endpoint_size (4 bytes) | dst_endpoint (dst_endpoint_size bytes) |
+    size_t input_size = 2*sizeof(rpc_tensor) + sizeof(uint32_t) + dst_ctx->endpoint.size();
+    std::vector<uint8_t> input(input_size, 0);
+    rpc_tensor rpc_src = serialize_tensor(src);
+    rpc_tensor rpc_dst = serialize_tensor(dst);
+    memcpy(input.data(), &rpc_src, sizeof(rpc_src));
+    memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
+    uint32_t dst_endpoint_size = dst_ctx->endpoint.size();
+    memcpy(input.data() + 2*sizeof(rpc_tensor), &dst_endpoint_size, sizeof(dst_endpoint_size));
+    memcpy(input.data() + 2*sizeof(rpc_tensor) + sizeof(dst_endpoint_size), dst_ctx->endpoint.c_str(), dst_endpoint_size);
+    std::vector<uint8_t> output;
+    bool status = send_rpc_cmd(src_ctx->sock, REMOTE_COPY_TENSOR, input, output);
+    GGML_ASSERT(status);
+    // output serialization format: | result (1 byte) |
+    GGML_ASSERT(output.size() == 1);
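+    // single status byte: non-zero when the source server reached the destination server and delivered the tensor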
+    return output[0];
+}
+
 GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
     // check if src and dst are on the same server
     ggml_backend_buffer_t src_buffer = src->buffer;
@@ -429,7 +475,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b
     ggml_backend_buffer_t dst_buffer = dst->buffer;
     ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
     if (src_ctx->sock != dst_ctx->sock) {
-        return false;
+        return remote_copy_tensor(src, dst);
     }
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
     // input serialization format: | rpc_tensor src | rpc_tensor dst |
@@ -495,7 +541,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
     if (remote_ptr != 0) {
         ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
             ggml_backend_rpc_buffer_interface,
-            new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
+            new ggml_backend_rpc_buffer_context{buft_ctx->endpoint, sock, {}, remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
             remote_size);
         return buffer;
     } else {
@@ -739,6 +785,48 @@ GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint
 
 // RPC server-side implementation
 
+template <typename T>
+class message_queue {
+    std::queue<T> queue;
+    std::mutex mutex;
+    std::condition_variable cvar;
+
+public:
+    message_queue() {}
+
+    void push(const T &value) {
+        std::unique_lock<std::mutex> lock(mutex);
+        queue.push(value);
+        lock.unlock();
+        cvar.notify_all();
+    }
+
+    void pop(T* out) {
+        std::unique_lock<std::mutex> lock(mutex);
+        cvar.wait(lock, [this] { return queue.size() > 0; });
+        *out = queue.front();
+        queue.pop();
+    }
+};
+
+struct rpc_response {
+    std::vector<uint8_t> output;
+    bool status;
+};
+
+using rpc_response_ptr = std::shared_ptr<rpc_response>;
+using response_queue = message_queue<rpc_response_ptr>;
+using response_queue_ptr = std::shared_ptr<response_queue>;
+
+struct rpc_request {
+    rpc_cmd cmd;
+    std::vector<uint8_t> input;
+    response_queue_ptr response_queue;
+};
+using rpc_request_ptr = std::shared_ptr<rpc_request>;
+using request_queue = message_queue<rpc_request_ptr>;
+using request_queue_ptr = std::shared_ptr<request_queue>;
+
 class rpc_server {
 public:
     rpc_server(ggml_backend_t backend) : backend(backend) {}
@@ -751,10 +839,13 @@ class rpc_server {
     bool free_buffer(const std::vector<uint8_t> & input);
     bool buffer_clear(const std::vector<uint8_t> & input);
     bool set_tensor(const std::vector<uint8_t> & input);
+    void remote_set_tensor(std::shared_ptr<socket_t> sock, const rpc_tensor * rpc_src, const rpc_tensor * rpc_dst);
     bool get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
     bool copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
+    bool remote_copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
     bool graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
+    void free_all_buffers();
 
 private:
     ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
     ggml_tensor * create_node(uint64_t id,
@@ -980,6 +1071,65 @@ bool rpc_server::copy_tensor(const std::vector<uint8_t> & input, std::vector<ui
     return true;
 }
 
+void rpc_server::remote_set_tensor(std::shared_ptr<socket_t> sock, const rpc_tensor * rpc_src, const rpc_tensor * rpc_dst) {
+    struct ggml_init_params params {
+        /*.mem_size   =*/ ggml_tensor_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
+    };
+    struct ggml_context * ctx = ggml_init(params);
+    ggml_tensor * src = deserialize_tensor(ctx, rpc_src);
+    size_t src_size = ggml_nbytes(src);
+
+    // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
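+    // the whole source tensor is read from the local backend and pushed to the destination server in one SET_TENSOR request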
+    uint64_t offset = 0;
+    size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + src_size;
+    std::vector<uint8_t> input(input_size, 0);
+    memcpy(input.data(), rpc_dst, sizeof(rpc_tensor));
+    memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
+    ggml_backend_tensor_get(src, input.data() + sizeof(rpc_tensor) + sizeof(offset), offset, src_size);
+    std::vector<uint8_t> output;
+    bool status = send_rpc_cmd(sock, SET_TENSOR, input, output);
+    GGML_ASSERT(status);
+    ggml_free(ctx);
+}
+
+bool rpc_server::remote_copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+    // serialization format: | rpc_tensor src | rpc_tensor dst | dst_endpoint_size (4 bytes) | dst_endpoint (dst_endpoint_size bytes) |
+    if (input.size() < 2*sizeof(rpc_tensor) + sizeof(uint32_t)) {
+        return false;
+    }
+    const rpc_tensor * rpc_src = (const rpc_tensor *)input.data();
+    const rpc_tensor * rpc_dst = (const rpc_tensor *)(input.data() + sizeof(rpc_tensor));
+    uint32_t dst_endpoint_size;
+    memcpy(&dst_endpoint_size, input.data() + 2*sizeof(rpc_tensor), sizeof(dst_endpoint_size));
+    if (input.size() != 2*sizeof(rpc_tensor) + sizeof(uint32_t) + dst_endpoint_size) {
+        return false;
+    }
+    // output serialization format: | result (1 byte) |
+    output.resize(1, 0);
+
+    const char * dst_endpoint_ptr = (const char *)(input.data() + 2*sizeof(rpc_tensor) + sizeof(uint32_t));
+    std::string dst_endpoint(dst_endpoint_ptr, dst_endpoint_size);
+
+    std::string host;
+    int port;
+    if (!parse_endpoint(dst_endpoint, host, port)) {
+        output[0] = false;
+        return true;
+    }
+    auto sock = socket_connect(host.c_str(), port);
+    if (sock == nullptr) {
+        output[0] = false;
+        return true;
+    }
+    send_hello(sock, SERVER);
+    remote_set_tensor(sock, rpc_src, rpc_dst);
+    output.resize(1, 0);
+    output[0] = true;
+    return true;
+}
+
 ggml_tensor * rpc_server::create_node(uint64_t id,
                                       struct ggml_context * ctx,
                                       const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
@@ -1051,97 +1201,211 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<u
     return true;
 }
 
-static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
+static void process_requests(ggml_backend_t backend, request_queue_ptr requestq) {
     rpc_server server(backend);
     while (true) {
-        uint8_t cmd;
-        if (!recv_data(sockfd, &cmd, 1)) {
-            break;
-        }
-        std::vector<uint8_t> input;
-        std::vector<uint8_t> output;
-        uint64_t input_size;
-        if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
-            break;
-        }
-        input.resize(input_size);
-        if (!recv_data(sockfd, input.data(), input_size)) {
-            break;
-        }
+        rpc_request_ptr request;
+        requestq->pop(&request);
+        rpc_response_ptr response = std::make_shared<rpc_response>();
         bool ok = true;
-        switch (cmd) {
+        switch (request->cmd) {
             case ALLOC_BUFFER: {
-                ok = server.alloc_buffer(input, output);
+                ok = server.alloc_buffer(request->input, response->output);
                 break;
             }
             case GET_ALIGNMENT: {
-                server.get_alignment(output);
+                server.get_alignment(response->output);
                 break;
             }
             case GET_MAX_SIZE: {
-                server.get_max_size(output);
+                server.get_max_size(response->output);
                 break;
             }
             case BUFFER_GET_BASE: {
-                ok = server.buffer_get_base(input, output);
+                ok = server.buffer_get_base(request->input, response->output);
                 break;
             }
             case FREE_BUFFER: {
-                ok = server.free_buffer(input);
+                ok = server.free_buffer(request->input);
                 break;
             }
             case BUFFER_CLEAR: {
-                ok = server.buffer_clear(input);
+                ok = server.buffer_clear(request->input);
                 break;
             }
             case SET_TENSOR: {
-                ok = server.set_tensor(input);
+                ok = server.set_tensor(request->input);
                 break;
             }
             case GET_TENSOR: {
-                ok = server.get_tensor(input, output);
+                ok = server.get_tensor(request->input, response->output);
                 break;
             }
             case COPY_TENSOR: {
-                ok = server.copy_tensor(input, output);
+                ok = server.copy_tensor(request->input, response->output);
+                break;
+            }
+            case REMOTE_COPY_TENSOR: {
+                ok = server.remote_copy_tensor(request->input, response->output);
                 break;
             }
             case GRAPH_COMPUTE: {
-                ok = server.graph_compute(input, output);
+                ok = server.graph_compute(request->input, response->output);
+                break;
+            }
+            case GET_DEVICE_MEMORY: {
+                break;
+            }
+            case FREE_ALL_BUFFERS: {
+                server.free_all_buffers();
+                continue;
+            }
+            default: {
+                fprintf(stderr, "Unknown command: %d\n", request->cmd);
+                ok = false;
+            }
+        }
+        response->status = ok;
+        request->response_queue->push(response);
+    }
+}
+
+static bool recv_rpc_cmd(sockfd_t sockfd, rpc_cmd & cmd, std::vector<uint8_t> & input) {
+    uint8_t cmd_u8;
+    if (!recv_data(sockfd, &cmd_u8, 1)) {
+        return false;
+    }
+    cmd = (rpc_cmd)cmd_u8;
+    uint64_t input_size;
+    if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
+        return false;
+    }
+    input.resize(input_size);
+    if (!recv_data(sockfd, input.data(), input_size)) {
+        return false;
+    }
+    return true;
+}
+
+static void rpc_serve_client(request_queue_ptr requestq, std::shared_ptr<socket_t> sock, size_t free_mem, size_t total_mem) {
+    auto responseq = std::make_shared<response_queue>();
+    while (true) {
+        auto request = std::make_shared<rpc_request>();
+        if (!recv_rpc_cmd(sock->fd, request->cmd, request->input)) {
+            break;
+        }
+        request->response_queue = responseq;
+        bool ok = true;
+        auto response = std::make_shared<rpc_response>();
+        switch (request->cmd) {
+            case ALLOC_BUFFER:
+            case GET_ALIGNMENT:
+            case GET_MAX_SIZE:
+            case BUFFER_GET_BASE:
+            case FREE_BUFFER:
+            case BUFFER_CLEAR:
+            case SET_TENSOR:
+            case GET_TENSOR:
+            case COPY_TENSOR:
+            case REMOTE_COPY_TENSOR:
+            case GRAPH_COMPUTE: {
+                requestq->push(request);
+                responseq->pop(&response);
                 break;
             }
             case GET_DEVICE_MEMORY: {
                 // output serialization format: | free (8 bytes) | total (8 bytes) |
-                output.resize(2*sizeof(uint64_t), 0);
-                memcpy(output.data(), &free_mem, sizeof(free_mem));
-                memcpy(output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem));
+                response->output.resize(2*sizeof(uint64_t), 0);
+                memcpy(response->output.data(), &free_mem, sizeof(free_mem));
+                memcpy(response->output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem));
                 break;
             }
             default: {
-                fprintf(stderr, "Unknown command: %d\n", cmd);
+                fprintf(stderr, "Unexpected command: %d\n", request->cmd);
                 ok = false;
             }
         }
         if (!ok) {
             break;
         }
-        uint64_t output_size = output.size();
-        if (!send_data(sockfd, &output_size, sizeof(output_size))) {
+        uint64_t output_size = response->output.size();
+        if (!send_data(sock->fd, &output_size, sizeof(output_size))) {
            break;
         }
-        if (!send_data(sockfd, output.data(), output_size)) {
+        if (!send_data(sock->fd, response->output.data(), output_size)) {
            break;
         }
     }
+    auto request = std::make_shared<rpc_request>();
+    request->cmd = FREE_ALL_BUFFERS;
+    requestq->push(request);
+}
+
+static void rpc_serve_server(request_queue_ptr requestq, std::shared_ptr<socket_t> sock) {
+    auto responseq = std::make_shared<response_queue>();
+    auto request = std::make_shared<rpc_request>();
+    if (!recv_rpc_cmd(sock->fd, request->cmd, request->input)) {
+        return;
+    }
+    if (request->cmd != SET_TENSOR) {
+        fprintf(stderr, "Unexpected command: %d\n", request->cmd);
+        return;
+    }
+    request->response_queue = responseq;
+    auto response = std::make_shared<rpc_response>();
+    requestq->push(request);
+    responseq->pop(&response);
+    uint64_t output_size = response->output.size();
+    if (!send_data(sock->fd, &output_size, sizeof(output_size))) {
+        return;
+    }
+    send_data(sock->fd, response->output.data(), output_size);
 }
 
+static bool recv_hello(std::shared_ptr<socket_t> sock, rpc_actor & actor) {
+    rpc_cmd cmd;
+    std::vector<uint8_t> input;
+    if (!recv_rpc_cmd(sock->fd, cmd, input)) {
+        return false;
+    }
+    if (cmd != HELLO || input.size() != 1) {
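+        // the first message on a new connection must be a 1-byte HELLO payload identifying the peer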
return false; + } + if (input[0] != CLIENT && input[0] != SERVER) { + return false; + } + actor = (rpc_actor)input[0]; + uint64_t output_size = 0; + if (!send_data(sock->fd, &output_size, sizeof(output_size))) { + return false; + } + return true; +} + +static std::mutex client_mutex; +static std::mutex server_mutex; + void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) { +#ifndef _WIN32 + // prevent SIGPIPE when writing to closed socket + signal(SIGPIPE, SIG_IGN); +#endif + auto requestq = std::make_shared(); + std::thread backend_thread = std::thread([=] { + process_requests(backend, requestq); + }); + std::string host; int port; if (!parse_endpoint(endpoint, host, port)) { @@ -1168,9 +1432,27 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free fprintf(stderr, "Failed to accept client connection\n"); return; } - printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem); - rpc_serve_client(backend, client_socket->fd, free_mem, total_mem); - printf("Client connection closed\n"); + rpc_actor actor; + if (!recv_hello(client_socket, actor)) { + continue; + } + if (actor == CLIENT) { + std::thread client_thread = std::thread([=] { + std::lock_guard lock(client_mutex); + printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem); + rpc_serve_client(requestq, client_socket, free_mem, total_mem); + printf("Client connection closed\n"); + }); + client_thread.detach(); + } else { + std::thread server_thread = std::thread([=] { + std::lock_guard lock(server_mutex); + printf("Accepted connection from another server\n"); + rpc_serve_server(requestq, client_socket); + printf("Server connection closed\n"); + }); + server_thread.detach(); + } } #ifdef _WIN32 WSACleanup();