Skip to content

rpc : enable async operations #7915

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 139 additions & 33 deletions ggml-rpc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
#include <cinttypes>
#include <string>
#include <vector>
#include <queue>
#include <memory>
#include <mutex>
#include <thread>
#include <condition_variable>
#include <unordered_map>
#include <unordered_set>
#ifdef _WIN32
Expand All @@ -17,6 +20,7 @@
# include <windows.h>
# include <winsock2.h>
#else
# include <signal.h>
# include <arpa/inet.h>
# include <sys/socket.h>
# include <sys/types.h>
Expand Down Expand Up @@ -89,6 +93,7 @@ enum rpc_cmd {
COPY_TENSOR,
GRAPH_COMPUTE,
GET_DEVICE_MEMORY,
FREE_ALL_BUFFERS,
};

// RPC data structures
Expand Down Expand Up @@ -736,6 +741,48 @@ GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint

// RPC server-side implementation

// Thread-safe blocking FIFO used to pass messages between the
// socket-facing client threads and the single backend worker thread.
// T must be copyable (push) and movable (pop); in this file T is a
// std::shared_ptr, so moving on pop avoids atomic refcount churn.
template <typename T>
class message_queue {
    std::queue<T> queue;
    std::mutex mutex;
    std::condition_variable cvar;

public:
    message_queue() {}

    // Enqueue a copy of `value` and wake any waiting consumers.
    void push(const T &value) {
        {
            // lock_guard: we only need a scoped critical section here
            std::lock_guard<std::mutex> lock(mutex);
            queue.push(value);
        }
        // notify after releasing the lock so a woken consumer does not
        // immediately block on the mutex
        cvar.notify_all();
    }

    // Block until an element is available, then move it into `*out`.
    // `out` must be non-null.
    void pop(T* out) {
        std::unique_lock<std::mutex> lock(mutex);
        // predicate form handles spurious wakeups
        cvar.wait(lock, [this] { return !queue.empty(); });
        // move instead of copy: front() is about to be destroyed anyway
        *out = std::move(queue.front());
        queue.pop();
    }
};

// Reply to one RPC command: the serialized output payload plus a
// success/failure flag that decides whether the connection stays open.
struct rpc_response {
std::vector<uint8_t> output;
bool status;
};

// shared_ptr aliases: requests and responses are handed between the
// client-connection thread and the backend worker thread, so shared
// ownership keeps them alive across the queue hand-off.
using rpc_response_ptr = std::shared_ptr<rpc_response>;
using response_queue = message_queue<rpc_response_ptr>;
using response_queue_ptr = std::shared_ptr<response_queue>;

// One command received from a client. The worker thread pushes its
// reply onto `response_queue`, which belongs to the originating
// connection (one response queue per client).
struct rpc_request {
rpc_cmd cmd;
std::vector<uint8_t> input;
response_queue_ptr response_queue;
};
using rpc_request_ptr = std::shared_ptr<rpc_request>;
using request_queue = message_queue<rpc_request_ptr>;
using request_queue_ptr = std::shared_ptr<request_queue>;

class rpc_server {
public:
rpc_server(ggml_backend_t backend) : backend(backend) {}
Expand All @@ -752,6 +799,7 @@ class rpc_server {
bool copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
bool graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);

void free_all_buffers();
private:
ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
ggml_tensor * create_node(uint64_t id,
Expand Down Expand Up @@ -1046,76 +1094,122 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<u
return true;
}

rpc_server::~rpc_server() {
// Release every backend buffer the server is still tracking and
// forget about them. Called on shutdown and when a client disconnects.
void rpc_server::free_all_buffers() {
    for (auto it = buffers.begin(); it != buffers.end(); ++it) {
        ggml_backend_buffer_free(*it);
    }
    buffers.clear();
}

// Destructor: release any buffers that were never explicitly freed
// (e.g. a client that disconnected without sending FREE_BUFFER).
rpc_server::~rpc_server() {
free_all_buffers();
}

static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
static void process_requests(ggml_backend_t backend, request_queue_ptr requestq) {
rpc_server server(backend);
while (true) {
uint8_t cmd;
if (!recv_data(sockfd, &cmd, 1)) {
break;
}
std::vector<uint8_t> input;
std::vector<uint8_t> output;
uint64_t input_size;
if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
break;
}
input.resize(input_size);
if (!recv_data(sockfd, input.data(), input_size)) {
break;
}
rpc_request_ptr request;
requestq->pop(&request);
rpc_response_ptr response = std::make_shared<rpc_response>();
bool ok = true;
switch (cmd) {
switch (request->cmd) {
case ALLOC_BUFFER: {
ok = server.alloc_buffer(input, output);
ok = server.alloc_buffer(request->input, response->output);
break;
}
case GET_ALIGNMENT: {
server.get_alignment(output);
server.get_alignment(response->output);
break;
}
case GET_MAX_SIZE: {
server.get_max_size(output);
server.get_max_size(response->output);
break;
}
case BUFFER_GET_BASE: {
ok = server.buffer_get_base(input, output);
ok = server.buffer_get_base(request->input, response->output);
break;
}
case FREE_BUFFER: {
ok = server.free_buffer(input);
ok = server.free_buffer(request->input);
break;
}
case BUFFER_CLEAR: {
ok = server.buffer_clear(input);
ok = server.buffer_clear(request->input);
break;
}
case SET_TENSOR: {
ok = server.set_tensor(input);
ok = server.set_tensor(request->input);
break;
}
case GET_TENSOR: {
ok = server.get_tensor(input, output);
ok = server.get_tensor(request->input, response->output);
break;
}
case COPY_TENSOR: {
ok = server.copy_tensor(input, output);
ok = server.copy_tensor(request->input, response->output);
break;
}
case GRAPH_COMPUTE: {
ok = server.graph_compute(input, output);
ok = server.graph_compute(request->input, response->output);
break;
}
case GET_DEVICE_MEMORY: {
break;
}
case FREE_ALL_BUFFERS: {
server.free_all_buffers();
continue;
}
default: {
fprintf(stderr, "Unknown command: %d\n", request->cmd);
ok = false;
}
}
response->status = ok;
request->response_queue->push(response);
}
}

static void rpc_serve_client(request_queue_ptr requestq, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
auto responseq = std::make_shared<response_queue>();
while (true) {
uint8_t cmd;
if (!recv_data(sockfd, &cmd, 1)) {
break;
}
auto request = std::make_shared<rpc_request>();
request->cmd = (rpc_cmd)cmd;
request->response_queue = responseq;
uint64_t input_size;
if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
break;
}
request->input.resize(input_size);
if (!recv_data(sockfd, request->input.data(), input_size)) {
break;
}
bool ok = true;
auto response = std::make_shared<rpc_response>();
switch (cmd) {
case ALLOC_BUFFER:
case GET_ALIGNMENT:
case GET_MAX_SIZE:
case BUFFER_GET_BASE:
case FREE_BUFFER:
case BUFFER_CLEAR:
case SET_TENSOR:
case GET_TENSOR:
case COPY_TENSOR:
case GRAPH_COMPUTE: {
requestq->push(request);
responseq->pop(&response);
break;
}
case GET_DEVICE_MEMORY: {
// output serialization format: | free (8 bytes) | total (8 bytes) |
output.resize(2*sizeof(uint64_t), 0);
memcpy(output.data(), &free_mem, sizeof(free_mem));
memcpy(output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem));
response->output.resize(2*sizeof(uint64_t), 0);
memcpy(response->output.data(), &free_mem, sizeof(free_mem));
memcpy(response->output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem));
break;
}
default: {
Expand All @@ -1126,17 +1220,29 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
if (!ok) {
break;
}
uint64_t output_size = output.size();
uint64_t output_size = response->output.size();
if (!send_data(sockfd, &output_size, sizeof(output_size))) {
break;
}
if (!send_data(sockfd, output.data(), output_size)) {
if (!send_data(sockfd, response->output.data(), output_size)) {
break;
}
}
auto request = std::make_shared<rpc_request>();
request->cmd = FREE_ALL_BUFFERS;
requestq->push(request);
}

void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
#ifndef _WIN32
// prevent SIGPIPE when writing to closed socket
signal(SIGPIPE, SIG_IGN);
#endif
auto requestq = std::make_shared<request_queue>();
std::thread backend_thread = std::thread([=] {
process_requests(backend, requestq);
});

std::string host;
int port;
if (!parse_endpoint(endpoint, host, port)) {
Expand Down Expand Up @@ -1164,7 +1270,7 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
return;
}
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
rpc_serve_client(requestq, client_socket->fd, free_mem, total_mem);
printf("Client connection closed\n");
}
#ifdef _WIN32
Expand Down
Loading