Skip to content

Commit c663305

Browse files
committed
Implement TCP server mode.

This new mode works by first loading the model, then listening for TCP
connections on a port. When a connection is received, arguments will be
parsed using a simple protocol:

- First the number of arguments will be read, followed by a newline character.
- Then each argument will be read, separated by the 0 byte.

With this we build an argument vector, similar to what is passed to the
program entry point. We pass this to gpt_params_parse. Finally `llama_main`
will be executed with the input/output streams connected to the socket.

Signed-off-by: Thiago Padilha <[email protected]>
1 parent 4b64181 commit c663305

9 files changed

+305
-4
lines changed

CMakeLists.txt

+9-2
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,17 @@ endif()
106106
# set(LLAMA_EXTRA_FLAGS ${LLAMA_EXTRA_FLAGS} -DGGML_PERF)
107107
# endif()
108108

109-
add_executable(llama
109+
set(LLAMA_SRC
110110
main.cpp
111111
utils.cpp
112-
llama.cpp
112+
llama.cpp)
113+
114+
if(NOT WIN32)
115+
set(LLAMA_SRC ${LLAMA_SRC} tcp_server.cpp)
116+
endif()
117+
118+
add_executable(llama
119+
${LLAMA_SRC}
113120
utils.h)
114121

115122
add_executable(quantize

Makefile

+5-2
Original file line numberDiff line numberDiff line change
@@ -191,11 +191,14 @@ utils.o: utils.cpp utils.h
191191
llama.o: llama.cpp llama.h
192192
$(CXX) $(CXXFLAGS) -c llama.cpp -o llama.o
193193

194+
tcp_server.o: tcp_server.cpp tcp_server.h
195+
$(CXX) $(CXXFLAGS) -c tcp_server.cpp -o tcp_server.o
196+
194197
clean:
195198
rm -f *.o main quantize
196199

197-
main: main.cpp ggml.o utils.o llama.o
198-
$(CXX) $(CXXFLAGS) main.cpp ggml.o utils.o llama.o -o main $(LDFLAGS)
200+
main: main.cpp ggml.o utils.o llama.o tcp_server.o
201+
$(CXX) $(CXXFLAGS) main.cpp ggml.o utils.o llama.o tcp_server.o -o main $(LDFLAGS)
199202
./main -h
200203

201204
quantize: quantize.cpp ggml.o utils.o

chat_tcp_client.sh

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#!/usr/bin/env bash

# Simple client for the llama TCP server mode: connects to the server,
# sends an argument vector using the wire protocol (argument count plus a
# newline, then each argument terminated by a 0 byte), then bridges the
# socket with the terminal for interactive chat.

PORT=${PORT:-8080}
PROMPT="${PROMPT:-"Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.

User: Hello, Bob.
Bob: Hello. How may I help you today?
User: Please tell me the largest city in Europe.
Bob: Sure. The largest city in Europe is Moscow, the capital of Russia.
User:"}"
RPROMPT="${RPROMPT:-"User:"}"
N_PREDICT="${N_PREDICT:-"4096"}"
REPEAT_PENALTY="${REPEAT_PENALTY:-"1.0"}"

# Open connection to the chat server
exec 3<>/dev/tcp/127.0.0.1/${PORT}

# Arguments forwarded to the server. Keeping them in an array means the
# transmitted argument count can never get out of sync with the list
# (the previous version hard-coded "10").
ARGS=(
  -n "$N_PREDICT"
  --repeat_penalty "$REPEAT_PENALTY"
  --color
  -i
  -r "$RPROMPT"
  -p "$PROMPT"
)

# Pass the arguments. The protocol is really simple:
# 1. Pass the number of arguments followed by a linefeed
# 2. Pass the arguments, with each being followed by a 0 byte
(
  printf '%d\n' "${#ARGS[@]}"
  printf '%s\0' "${ARGS[@]}"
) >&3

trap exit TERM

# When we have passed the arguments, start printing socket data to the screen.
# This is done in a background job because we also want to send data when
# running in interactive mode.
cat <&3 && echo "(disconnected, press \"enter\" twice to exit)" &
cat >&3
wait

chat_tcp_server.sh

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#!/usr/bin/env bash

# Start ./main in TCP server mode: the model is loaded once, then served to
# each connection on localhost:$PORT (see chat_tcp_client.sh for a client).

PORT=${PORT:-8080}
MODEL=${MODEL:-models/7B/ggml-model-q4_0.bin}

# Quote the expansions so a model path containing spaces works.
./main -l "${PORT}" -m "${MODEL}"

main.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "ggml.h"
22
#include "utils.h"
33
#include "llama.h"
4+
#include "tcp_server.h"
45

56
#include <iostream>
67

@@ -65,5 +66,11 @@ int main(int argc, char ** argv) {
6566
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
6667
}
6768

69+
#ifndef _WIN32
70+
if (params.listen_port != "") {
71+
return listen_tcp(params, vocab, model, t_main_start_us, t_load_us);
72+
}
73+
#endif
74+
6875
return llama_main(params, vocab, model, t_main_start_us, t_load_us, std::cin, stdout, stderr);
6976
}

tcp_server.cpp

+213
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
#include "tcp_server.h"
2+
3+
#include <ext/stdio_filebuf.h>
4+
#include <iostream>
5+
#include <fstream>
6+
7+
#include <stdarg.h>
8+
#include <stdio.h>
9+
#include <stdlib.h>
10+
#include <stdbool.h>
11+
#include <string.h>
12+
#include <errno.h>
13+
14+
#include <signal.h>
15+
#include <unistd.h>
16+
#include <sys/wait.h>
17+
18+
#include <sys/types.h>
19+
#include <sys/socket.h>
20+
#include <netdb.h>
21+
22+
// Print a printf-style message to stderr, append a newline, and terminate
// the process with exit status 1. Never returns.
void die(const char *msg, ...)
{
    va_list args;

    va_start(args, msg);
    vfprintf(stderr, msg, args);
    va_end(args);
    fprintf(stderr, "\n");
    exit(1);
}
32+
33+
static char *read_argument(uint8_t **param_buf, size_t *param_buf_size, FILE *instream) {
34+
bool done = false;
35+
uint8_t *buf = *param_buf;
36+
size_t bufsize = *param_buf_size;
37+
size_t bufpos = 0;
38+
while (!done) {
39+
if (bufpos == bufsize) {
40+
bufsize += 1024;
41+
buf = (uint8_t *)realloc(buf, bufsize);
42+
if (!buf) {
43+
die("failed to allocate memory");
44+
}
45+
}
46+
47+
int c = fgetc(instream);
48+
if (c == EOF) {
49+
die("unexpected EOF client socket");
50+
}
51+
buf[bufpos++] = (uint8_t)c;
52+
if (c == 0) {
53+
// done reading argument
54+
break;
55+
}
56+
}
57+
*param_buf = buf;
58+
*param_buf_size = bufsize;
59+
return strdup((char *)buf);
60+
}
61+
62+
static int read_arguments(int argc, char **argv, FILE *instream) {
63+
int i = 1;
64+
size_t param_buf_size = 0;
65+
uint8_t *param_buf = nullptr;
66+
67+
for (i = 1; i < argc; i++) {
68+
argv[i] = read_argument(&param_buf, &param_buf_size, instream);
69+
}
70+
71+
free(param_buf);
72+
return i;
73+
}
74+
75+
// Handle one client connection: read an argument vector from the socket
// using the wire protocol (argument count + '\n', then each argument
// terminated by a 0 byte), parse it with gpt_params_parse, and run
// llama_main with input/output redirected to the socket.
//
// Runs in a forked child; the return value becomes the child's exit code,
// so the fdopen'd streams are intentionally not fclosed here.
static int serve_model(
        gpt_params params,
        gpt_vocab vocab,
        llama_model model,
        int64_t t_main_start_us,
        int64_t t_load_us,
        int sock_fd)
{
    int argc;
    char **argv;
    FILE *instream = fdopen(sock_fd, "r");
    FILE *outstream = fdopen(sock_fd, "w");
    setvbuf(instream, NULL, _IONBF, 0);

    // start by reading the argument count
    if (fscanf(instream, "%d\n", &argc) != 1) {
        fprintf(outstream, "Error: First line must be argument count\n");
        fflush(outstream);
        return 1;
    }

    argc += 1; // add one extra argument to emulate the program command line
    argv = (char **)malloc(argc * sizeof *argv);
    if (!argv) {
        die("failed to allocate memory");
    }
    argv[0] = nullptr;
    if (read_arguments(argc, argv, instream) != argc) {
        fprintf(outstream, "Error: Failed to read arguments\n");
        fflush(outstream);
        // Must not fall through: argv would contain uninitialized slots.
        return 1;
    }

    if (gpt_params_parse(argc, argv, params) == false) {
        fprintf(outstream, "Error: Failed to parse parameters\n");
        fflush(outstream);
        return 1;
    }

    for (int i = 1; i < argc; i++) {
        free(argv[i]);
    }
    free(argv);

    // Wrap the socket fd in a std::istream for llama_main's input side.
    __gnu_cxx::stdio_filebuf<char> tcp_filebuf(sock_fd, std::ios::in);
    std::istream tcp_is(&tcp_filebuf);

    return llama_main(params, vocab, model, t_main_start_us, t_load_us, tcp_is, outstream, outstream);
}
121+
122+
// Run TCP server mode: bind to 127.0.0.1:params.listen_port and fork one
// child per accepted connection; each child runs serve_model with the
// already-loaded vocab/model, so the model is only loaded once.
// Returns the process exit status.
int listen_tcp(
        gpt_params params,
        gpt_vocab vocab,
        llama_model model,
        int64_t t_main_start_us,
        int64_t t_load_us) {
    int listen_fd;
    int status;
    pid_t child;
    struct addrinfo hints;
    struct addrinfo *servinfo, *p;
    int yes = 1;

    memset(&hints, 0, sizeof hints);
    hints.ai_family = AF_INET;
    hints.ai_socktype = SOCK_STREAM;
    hints.ai_flags = AI_PASSIVE;

    // This should only ever listen on a loopback address. Access from outside
    // should be proxied via nginx or similar software
    status = getaddrinfo("127.0.0.1", params.listen_port.c_str(), &hints, &servinfo);
    if (status) {
        die("getaddrinfo error: %s", gai_strerror(status));
    }

    // bind to the first addrinfo we can from the getaddrinfo results
    for (p = servinfo; p != NULL; p = p->ai_next) {
        listen_fd = socket(p->ai_family, p->ai_socktype, p->ai_protocol);
        if (listen_fd == -1) {
            perror("server: socket");
            continue;
        }

        if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof yes)) {
            // BUG FIX: the old call passed the port string to %s and left
            // strerror(errno) as an unused extra vararg.
            die("setsockopt error: %s", strerror(errno));
        }

        if (bind(listen_fd, p->ai_addr, p->ai_addrlen) == 0) {
            break;
        }

        close(listen_fd);
        perror("server: bind");
    }

    freeaddrinfo(servinfo);

    if (p == NULL) {
        die("failed to bind: %s", strerror(errno));
    }

    if (listen(listen_fd, 20)) {
        die("listen error: %s", strerror(errno));
    }
    // Don't track child processes, so ignore SIGCHLD to prevent zombies
    signal(SIGCHLD, SIG_IGN);

    for (;;) {
        // The peer address is never used, so let accept() discard it
        // (the old code passed a zero-length sockaddr buffer).
        int sock_fd = accept(listen_fd, NULL, NULL);
        if (sock_fd < 0) {
            fprintf(stderr, "accept error: %s\n", strerror(errno));
            break;
        }

        child = fork();
        if (child == 0) {
            // close the listen_fd since we won't use it in the child
            close(listen_fd);
            int ret = serve_model(params, vocab, model, t_main_start_us, t_load_us, sock_fd);
            close(sock_fd);
            return ret;
        } else {
            if (child == -1) {
                // Could not fork; drop this connection and keep serving.
                fprintf(stderr, "fork error: %s\n", strerror(errno));
            }
            // close the client socket since we won't use it in the server
            close(sock_fd);
        }
    }
    close(listen_fd);

    // ignore SIGTERM since we'll send it to the group
    signal(SIGTERM, SIG_IGN);
    // tell children to exit
    kill(0, SIGTERM);
    // wait for children to terminate
    wait(&status);
    return 0;
}

tcp_server.h

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#pragma once

#include "utils.h"
#include "llama.h"

// Run TCP server mode: listen on 127.0.0.1:params.listen_port and serve the
// preloaded vocab/model to each accepted connection (one forked child per
// client). Blocks until the listen loop ends; returns the process exit
// status. Not available on Windows.
// NOTE(review): params/vocab/model are passed by value — presumably so each
// forked child gets its own copy; confirm the copy cost is acceptable.
int listen_tcp(
    gpt_params params,
    gpt_vocab vocab,
    llama_model model,
    int64_t t_main_start_us,
    int64_t t_load_us);

utils.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
7373
params.antiprompt.push_back(argv[++i]);
7474
} else if (arg == "--ignore-eos") {
7575
params.ignore_eos = true;
76+
#ifndef _WIN32
77+
} else if (arg == "-l" || arg == "--listen") {
78+
params.listen_port = argv[++i];
79+
#endif
7680
} else if (arg == "-h" || arg == "--help") {
7781
gpt_print_usage(argc, argv, params);
7882
exit(0);
@@ -118,6 +122,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
118122
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
119123
fprintf(stderr, " -m FNAME, --model FNAME\n");
120124
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
125+
#ifndef _WIN32
126+
fprintf(stderr, " -l PORT, --listen PORT\n");
127+
fprintf(stderr, " Run in TCP mode, listening on PORT\n");
128+
#endif
121129
fprintf(stderr, "\n");
122130
}
123131

utils.h

+4
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ struct gpt_params {
4040
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
4141
bool instruct = false; // instruction mode (used for Alpaca models)
4242
bool ignore_eos = false; // do not stop generating after eos
43+
44+
#ifndef _WIN32
45+
std::string listen_port = ""; // TCP port for when running in server mode
46+
#endif
4347
};
4448

4549
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);

0 commit comments

Comments
 (0)