
Commit 0492ebd

Remove direct access to std streams from llama_main
The goal is to allow running llama_main while connected to other streams, such as TCP sockets.

Signed-off-by: Thiago Padilha <[email protected]>
1 parent 536df15 commit 0492ebd
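
Note on usage (not part of the diff): main.cpp below keeps the old behaviour by passing stdin, stdout and stderr, but with the streams now explicit a server could hand llama_main FILE* handles that wrap a TCP connection instead. A minimal sketch of such glue code, assuming a POSIX system (dup, fdopen) and an already-accepted connection descriptor sock_fd; the helper name wrap_socket and the stdio_streams struct are hypothetical:

// Hypothetical helper (not part of this commit): wrap an accepted TCP
// connection in the FILE* streams that llama_main now accepts.
// Assumes POSIX dup()/fdopen().
#include <cstdio>
#include <unistd.h>

struct stdio_streams {
    FILE *in;   // pass as instream
    FILE *out;  // pass as outstream (errstream can stay on the server's stderr)
};

// Returns {nullptr, nullptr} on failure; sock_fd itself is left open.
static stdio_streams wrap_socket(int sock_fd) {
    // Duplicate the descriptor so the read and write FILE*s own independent
    // stdio buffers while still referring to the same connection.
    FILE *in  = fdopen(dup(sock_fd), "r");
    FILE *out = fdopen(dup(sock_fd), "w");
    if (!in || !out) {
        if (in)  fclose(in);
        if (out) fclose(out);
        return {nullptr, nullptr};
    }
    // Line-buffer the output so generated tokens reach the client promptly.
    setvbuf(out, nullptr, _IOLBF, 0);
    return {in, out};
}

A caller would then invoke llama_main exactly as the new main.cpp does, substituting the two socket-backed streams for stdin and stdout, while diagnostics stay on the server's own stderr.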

3 files changed: +41 -35 lines

llama.cpp (+36 -33)
@@ -716,13 +716,16 @@ int llama_main(
         gpt_vocab vocab,
         llama_model model,
         int64_t t_load_us,
-        int64_t t_main_start_us) {
+        int64_t t_main_start_us,
+        FILE *instream,
+        FILE *outstream,
+        FILE *errstream) {
 
     if (params.seed < 0) {
         params.seed = time(NULL);
     }
 
-    fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
+    fprintf(errstream, "%s: seed = %d\n", __func__, params.seed);
 
     std::mt19937 rng(params.seed);
     if (params.random_prompt) {
@@ -764,13 +767,13 @@ int llama_main(
         params.interactive = true;
     }
 
-    fprintf(stderr, "\n");
-    fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
-    fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
+    fprintf(errstream, "\n");
+    fprintf(errstream, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
+    fprintf(errstream, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
     for (int i = 0; i < (int) embd_inp.size(); i++) {
-        fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
+        fprintf(errstream, "%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
     }
-    fprintf(stderr, "\n");
+    fprintf(errstream, "\n");
     if (params.interactive) {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
         struct sigaction sigint_action;
@@ -782,19 +785,19 @@ int llama_main(
         signal(SIGINT, sigint_handler);
 #endif
 
-        fprintf(stderr, "%s: interactive mode on.\n", __func__);
+        fprintf(errstream, "%s: interactive mode on.\n", __func__);
 
         if (antiprompt_inp.size()) {
-            fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.c_str());
-            fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
+            fprintf(errstream, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.c_str());
+            fprintf(errstream, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
             for (int i = 0; i < (int) antiprompt_inp.size(); i++) {
-                fprintf(stderr, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
+                fprintf(errstream, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
             }
-            fprintf(stderr, "\n");
+            fprintf(errstream, "\n");
         }
     }
-    fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
-    fprintf(stderr, "\n\n");
+    fprintf(errstream, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
+    fprintf(errstream, "\n\n");
 
     std::vector<gpt_vocab::id> embd;
 
@@ -807,7 +810,7 @@ int llama_main(
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
 
     if (params.interactive) {
-        fprintf(stderr, "== Running in interactive mode. ==\n"
+        fprintf(errstream, "== Running in interactive mode. ==\n"
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
                " - Press Ctrl+C to interject at any time.\n"
 #endif
@@ -823,7 +826,7 @@ int llama_main(
 
     // set the color for the prompt which will be output initially
     if (params.use_color) {
-        printf(ANSI_COLOR_YELLOW);
+        fprintf(outstream, ANSI_COLOR_YELLOW);
     }
 
     while (remaining_tokens > 0 || params.interactive) {
@@ -832,7 +835,7 @@ int llama_main(
             const int64_t t_start_us = ggml_time_us();
 
             if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
-                fprintf(stderr, "Failed to predict\n");
+                fprintf(errstream, "Failed to predict\n");
                 return 1;
             }
 
@@ -891,16 +894,16 @@ int llama_main(
 
             // reset color to default if we there is no pending user input
             if (!input_noecho && params.use_color && (int) embd_inp.size() == input_consumed) {
-                printf(ANSI_COLOR_RESET);
+                fprintf(outstream, ANSI_COLOR_RESET);
             }
         }
 
         // display text
         if (!input_noecho) {
             for (auto id : embd) {
-                printf("%s", vocab.id_to_token[id].c_str());
+                fprintf(outstream, "%s", vocab.id_to_token[id].c_str());
             }
-            fflush(stdout);
+            fflush(outstream);
         }
 
         // in interactive mode, and not currently processing queued inputs;
@@ -922,16 +925,16 @@ int llama_main(
                 // currently being interactive
                 bool another_line = true;
                 while (another_line) {
-                    fflush(stdout);
+                    fflush(outstream);
                     char buf[256] = {0};
                     int n_read;
-                    if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
-                    if (scanf("%255[^\n]%n%*c", buf, &n_read) <= 0) {
+                    if (params.use_color) fprintf(outstream, ANSI_BOLD ANSI_COLOR_GREEN);
+                    if (fscanf(instream, "%255[^\n]%n%*c", buf, &n_read) <= 0) {
                         // presumable empty line, consume the newline
-                        std::ignore = scanf("%*c");
+                        std::ignore = fscanf(instream, "%*c");
                         n_read=0;
                     }
-                    if (params.use_color) printf(ANSI_COLOR_RESET);
+                    if (params.use_color) fprintf(outstream, ANSI_COLOR_RESET);
 
                     if (n_read > 0 && buf[n_read-1]=='\\') {
                         another_line = true;
@@ -964,7 +967,7 @@ int llama_main(
             if (params.interactive) {
                 is_interacting = true;
             } else {
-                fprintf(stderr, " [end of text]\n");
+                fprintf(errstream, " [end of text]\n");
                 break;
             }
         }
@@ -984,18 +987,18 @@ int llama_main(
     {
         const int64_t t_main_end_us = ggml_time_us();
 
-        fprintf(stderr, "\n\n");
-        fprintf(stderr, "%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
-        fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
-        fprintf(stderr, "%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
-        fprintf(stderr, "%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
-        fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+        fprintf(errstream, "\n\n");
+        fprintf(errstream, "%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
+        fprintf(errstream, "%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
+        fprintf(errstream, "%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
+        fprintf(errstream, "%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
+        fprintf(errstream, "%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
     }
 
     ggml_free(model.ctx);
 
     if (params.use_color) {
-        printf(ANSI_COLOR_RESET);
+        fprintf(outstream, ANSI_COLOR_RESET);
     }
 
     return 0;

llama.h (+4 -1)
@@ -64,5 +64,8 @@ int llama_main(
         gpt_vocab vocab,
         llama_model model,
         int64_t t_load_us,
-        int64_t t_main_start_us);
+        int64_t t_main_start_us,
+        FILE *instream,
+        FILE *outstream,
+        FILE *errstream);
 bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx, ggml_type memory_type);

main.cpp (+1 -1)
@@ -63,5 +63,5 @@ int main(int argc, char ** argv) {
             params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }
 
-    return llama_main(params, vocab, model, t_main_start_us, t_load_us);
+    return llama_main(params, vocab, model, t_main_start_us, t_load_us, stdin, stdout, stderr);
 }
