Commit 900d280

add protocol mode, an irc-like network encoding

1 parent a81da09

2 files changed: +113, -22 lines

llama.cpp (+107, -19)
@@ -829,6 +829,29 @@ void sigint_handler(int signo) {
 }
 #endif
 
+
+std::string escapeString(std::string stdstr) {
+    const char* str = stdstr.c_str();
+    std::string escapedStr;
+    for (const char* c = str; *c != '\0'; ++c) {
+        switch (*c) {
+            case '\a': escapedStr += "\\a"; break;
+            case '\b': escapedStr += "\\b"; break;
+            case '\f': escapedStr += "\\f"; break;
+            case '\n': escapedStr += "\\n"; break;
+            case '\r': escapedStr += "\\r"; break;
+            case '\t': escapedStr += "\\t"; break;
+            case '\v': escapedStr += "\\v"; break;
+            case '\\': escapedStr += "\\\\"; break;
+            case '\"': escapedStr += "\\\""; break;
+            case '\'': escapedStr += "\\\'"; break;
+            default: escapedStr += *c; break;
+        }
+    }
+    //std::cout << "test string" << escapedStr << std::endl;
+    return escapedStr;
+}
+
 int llama_main(
     gpt_params params,
     llama_vocab vocab,
@@ -842,8 +865,12 @@ int llama_main(
     if (params.seed < 0) {
         params.seed = time(NULL);
     }
-
-    fprintf(errstream, "%s: seed = %d\n", __func__, params.seed);
+    if(params.protocol_mode) {
+        fprintf(outstream, "%s", "HELO\n");
+        fprintf(outstream, "KV seed=%d\n", params.seed);
+    } else {
+        fprintf(errstream, "%s: seed = %d\n", __func__, params.seed);
+    }
 
     std::mt19937 rng(params.seed);
     if (params.random_prompt) {
@@ -891,13 +918,24 @@ int llama_main(
         params.interactive = true;
     }
 
-    fprintf(errstream, "\n");
-    fprintf(errstream, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
-    fprintf(errstream, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
+    if(params.protocol_mode) {
+        fprintf(outstream, "PROMPT %s\n", escapeString(params.prompt).c_str());
+        fprintf(outstream, "KV prompt_tokens=%zu\n", embd_inp.size());
+    } else {
+        fprintf(errstream, "\n");
+        fprintf(errstream, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
+        fprintf(errstream, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
+    }
     for (int i = 0; i < (int) embd_inp.size(); i++) {
-        fprintf(errstream, "%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
+        if(params.protocol_mode) {
+            fprintf(outstream, "DEBUG %d -> '%s'\n", embd_inp[i], escapeString(vocab.id_to_token.at(embd_inp[i])).c_str());
+        } else {
+            fprintf(errstream, "%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
+        }
+    }
+    if(!params.protocol_mode) {
+        fprintf(errstream, "\n");
     }
-    fprintf(errstream, "\n");
     if (params.interactive) {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
         struct sigaction sigint_action;
@@ -909,16 +947,32 @@ int llama_main(
         signal(SIGINT, sigint_handler);
 #endif
 
-        fprintf(errstream, "%s: interactive mode on.\n", __func__);
+        if(params.protocol_mode) {
+            fprintf(outstream, "KV interactive_mode=true\n");
+        } else {
+            fprintf(errstream, "%s: interactive mode on.\n", __func__);
+        }
 
         if(params.antiprompt.size()) {
             for (auto antiprompt : params.antiprompt) {
-                fprintf(errstream, "Reverse prompt: '%s'\n", antiprompt.c_str());
+                if(params.protocol_mode) {
+                    fprintf(outstream, "KV reverse_prompt=\"%s\"\n", escapeString(antiprompt).c_str());
+                } else {
+                    fprintf(errstream, "Reverse prompt: '%s'\n", antiprompt.c_str());
+                }
             }
         }
     }
-    fprintf(errstream, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
-    fprintf(errstream, "\n\n");
+    if(params.protocol_mode) {
+        fprintf(errstream, "KV temp=%f\n", params.temp);
+        fprintf(errstream, "KV top_k=%d\n", params.top_k);
+        fprintf(errstream, "KV top_p=%f\n", params.top_p);
+        fprintf(errstream, "KV repeat_last_n=%i\n", params.repeat_last_n);
+        fprintf(errstream, "KV repeat_penalty=%f\n", params.repeat_penalty);
+    } else {
+        fprintf(errstream, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
+        fprintf(errstream, "\n\n");
+    }
 
     std::vector<llama_vocab::id> embd;
 
@@ -927,12 +981,14 @@ int llama_main(
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
 
     if (params.interactive) {
-        fprintf(errstream, "== Running in interactive mode. ==\n"
+        if(!params.protocol_mode) {
+            fprintf(errstream, "== Running in interactive mode. ==\n"
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
-               " - Press Ctrl+C to interject at any time.\n"
+                   " - Press Ctrl+C to interject at any time.\n"
 #endif
-               " - Press Return to return control to LLaMa.\n"
-               " - If you want to submit another line, end your input in '\\'.\n\n");
+                   " - Press Return to return control to LLaMa.\n"
+                   " - If you want to submit another line, end your input in '\\'.\n\n");
+        }
         is_interacting = true;
     }
 
@@ -955,12 +1011,19 @@ int llama_main(
     }
 
     while (remaining_tokens > 0 || params.interactive) {
+        if(params.protocol_mode && !params.interactive) {
+            fprintf(outstream, "KV remaining_tokens=%d\n", remaining_tokens);
+        }
         // predict
         if (embd.size() > 0) {
             const int64_t t_start_us = ggml_time_us();
 
             if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
-                fprintf(errstream, "Failed to predict\n");
+                if(params.protocol_mode) {
+                    fprintf(outstream, "FATAL Error: Failed to predict\n");
+                } else {
+                    fprintf(errstream, "Failed to predict\n");
+                }
                 return 1;
             }
 
@@ -1020,8 +1083,16 @@ int llama_main(
 
         // display text
         if (!input_noecho) {
+            if(params.protocol_mode) {
+                fprintf(outstream, "OUTPUT ");
+            }
             for (auto id : embd) {
-                fprintf(outstream, "%s", vocab.id_to_token[id].c_str());
+                fprintf(outstream, "%s", params.protocol_mode ?
+                        escapeString(vocab.id_to_token[id]).c_str() :
+                        vocab.id_to_token[id].c_str());
+            }
+            if(params.protocol_mode) {
+                fprintf(outstream, "\n");
             }
             fflush(outstream);
         }
@@ -1047,11 +1118,17 @@ int llama_main(
             }
         }
         if (is_interacting) {
+            if(params.protocol_mode) {
+                fprintf(outstream, "KV awaiting_prompt=true\n");
+                fflush(outstream);
+            }
             if (params.instruct) {
                 input_consumed = embd_inp.size();
                 embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
 
-                fprintf(outstream, "\n> ");
+                if(!params.protocol_mode) {
+                    fprintf(outstream, "\n> ");
+                }
             }
 
             // currently being interactive
@@ -1068,6 +1145,7 @@ int llama_main(
                 }
                 buffer += line + '\n'; // Append the line to the result
             } while (another_line);
+            fprintf(outstream, "PROMPT %s\n", escapeString(line).c_str());
             if (params.use_color) fprintf(outstream, ANSI_COLOR_RESET);
 
             std::vector<llama_vocab::id> line_inp = ::llama_tokenize(vocab, buffer, false);
@@ -1080,6 +1158,10 @@ int llama_main(
             remaining_tokens -= line_inp.size();
 
             input_noecho = true; // do not echo this again
+            if(params.protocol_mode) {
+                fprintf(outstream, "KV awaiting_prompt=false\n");
+                fflush(outstream);
+            }
         }
         is_interacting = false;
     }
@@ -1089,7 +1171,13 @@ int llama_main(
         if (params.interactive) {
            is_interacting = true;
         } else {
-            fprintf(errstream, " [end of text]\n");
+            if(params.protocol_mode) {
+                fprintf(outstream, "END_OF_TEXT\n");
+                fflush(outstream);
+            } else {
+                fprintf(errstream, " [end of text]\n");
+                fflush(errstream);
+            }
             break;
         }
     }
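
Read end to end, these changes define a small line-oriented protocol: HELO opens the session, KV key=value lines carry metadata (seed, prompt_tokens, interactive_mode, reverse_prompt, sampling parameters, remaining_tokens, awaiting_prompt), PROMPT and DEBUG echo the escaped prompt and its tokenization, OUTPUT carries escaped generated text, and FATAL / END_OF_TEXT terminate. A non-interactive session might look roughly like this; the token ids, values, and per-line batching are illustrative, not from a real run:

    HELO
    KV seed=1679269022
    PROMPT Once upon a time
    KV prompt_tokens=5
    DEBUG 1 -> ''
    DEBUG 9038 -> ' Once'
    DEBUG 2501 -> ' upon'
    DEBUG 263 -> ' a'
    DEBUG 931 -> ' time'
    KV temp=0.800000
    KV top_k=40
    KV top_p=0.950000
    KV repeat_last_n=64
    KV repeat_penalty=1.300000
    KV remaining_tokens=128
    OUTPUT  Once upon a time
    KV remaining_tokens=127
    OUTPUT , there was a small village.\n
    END_OF_TEXT

One wrinkle worth noting: the sampling-parameter KV lines are written to errstream rather than outstream, so they reach a TCP client only because serve_model() below passes the same socket stream for both.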

tcp_server.cpp (+6, -3)
@@ -122,7 +122,7 @@ static int serve_model(
 
     // start by reading the parameter count
     if (fscanf(instream, "%d\n", &argc) != 1) {
-        fprintf(outstream, "Error: First line must be character count\n");
+        fprintf(outstream, "FATAL Error: First line must be character count\n");
         fflush(outstream);
         return 1;
     }
@@ -131,12 +131,12 @@ static int serve_model(
     argv = (char **)malloc(argc * sizeof *argv);
     argv[0] = nullptr;
     if (read_arguments(argc, argv, instream) != argc) {
-        fprintf(outstream, "Error: Failed to read arguments\n");
+        fprintf(outstream, "FATAL Error: Failed to read arguments\n");
         fflush(outstream);
     }
 
     if (gpt_params_parse(argc, argv, params) == false) {
-        fprintf(outstream, "Error: Failed to parse parameters\n");
+        fprintf(outstream, "FATAL Error: Failed to parse parameters\n");
         fflush(outstream);
         return 1;
     }
@@ -148,6 +148,9 @@ static int serve_model(
 
     PosixStream tcp_is(sock_fd);
 
+    params.protocol_mode = true;
+    params.use_color = false;
+
     return llama_main(params, vocab, model, t_load_us, t_main_start_us, tcp_is, outstream, outstream);
 }
 
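For testing, a client only needs to send the argument count and arguments that serve_model() reads, then parse the line-oriented replies. Below is a minimal sketch using POSIX sockets; the port (8080), the file name, and the one-argument-per-line framing after the count are assumptions, since the listener setup and read_arguments() are not part of this diff:

    // protocol_client.cpp -- minimal sketch of a client for protocol mode.
    // Assumptions (not shown in this diff): the server listens on
    // 127.0.0.1:8080, and read_arguments() expects one argument per line
    // after the count, with the count including the unused argv[0].
    #include <arpa/inet.h>
    #include <netinet/in.h>
    #include <sys/socket.h>
    #include <unistd.h>
    #include <cstdio>
    #include <cstring>

    int main() {
        int fd = socket(AF_INET, SOCK_STREAM, 0);
        if (fd < 0) { perror("socket"); return 1; }

        sockaddr_in addr{};
        addr.sin_family = AF_INET;
        addr.sin_port   = htons(8080);              // placeholder port
        inet_pton(AF_INET, "127.0.0.1", &addr.sin_addr);
        if (connect(fd, (sockaddr *)&addr, sizeof addr) != 0) {
            perror("connect");
            return 1;
        }

        FILE *io = fdopen(fd, "r+");

        // serve_model() reads the parameter count first (fscanf "%d\n"),
        // then the arguments; the exact framing is a guess here.
        fprintf(io, "3\n-p\nOnce upon a time\n");
        fflush(io);

        // Parse line-oriented replies until END_OF_TEXT or FATAL.
        char line[4096];
        while (fgets(line, sizeof line, io)) {
            if (strncmp(line, "OUTPUT ", 7) == 0) {
                fputs(line + 7, stdout);            // still escaped (\n, \t, ...)
            } else if (strncmp(line, "KV ", 3) == 0) {
                fprintf(stderr, "meta: %s", line + 3); // key=value metadata
            } else if (strncmp(line, "FATAL", 5) == 0 ||
                       strncmp(line, "END_OF_TEXT", 11) == 0) {
                break;                              // session over
            }
            // HELO, PROMPT and DEBUG lines are informational; ignored here.
        }

        fclose(io);
        return 0;
    }

END_OF_TEXT is only emitted for non-interactive runs; in interactive mode a client would instead watch for KV awaiting_prompt=true, send input lines exactly as it would on stdin, and see them echoed back as a PROMPT line.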