@@ -717,13 +717,16 @@ int llama_main(
     gpt_vocab vocab,
     llama_model model,
     int64_t t_load_us,
-    int64_t t_main_start_us) {
+    int64_t t_main_start_us,
+    std::istream & instream,
+    FILE *outstream,
+    FILE *errstream) {

     if (params.seed < 0) {
         params.seed = time(NULL);
     }

-    fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
+    fprintf(errstream, "%s: seed = %d\n", __func__, params.seed);

     std::mt19937 rng(params.seed);
     if (params.random_prompt) {
@@ -769,13 +772,13 @@ int llama_main(
         params.interactive = true;
     }

-    fprintf(stderr, "\n");
-    fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
-    fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
+    fprintf(errstream, "\n");
+    fprintf(errstream, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
+    fprintf(errstream, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
     for (int i = 0; i < (int) embd_inp.size(); i++) {
-        fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
+        fprintf(errstream, "%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
     }
-    fprintf(stderr, "\n");
+    fprintf(errstream, "\n");
     if (params.interactive) {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
         struct sigaction sigint_action;
@@ -787,22 +790,22 @@ int llama_main(
         signal(SIGINT, sigint_handler);
 #endif

-        fprintf(stderr, "%s: interactive mode on.\n", __func__);
+        fprintf(errstream, "%s: interactive mode on.\n", __func__);

         if (antipromptv_inp.size()) {
             for (size_t apindex = 0; apindex < antipromptv_inp.size(); ++apindex) {
                 auto antiprompt_inp = antipromptv_inp.at(apindex);
-                fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.at(apindex).c_str());
-                fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
+                fprintf(errstream, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.at(apindex).c_str());
+                fprintf(errstream, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
                 for (int i = 0; i < (int) antiprompt_inp.size(); i++) {
-                    fprintf(stderr, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
+                    fprintf(errstream, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
                 }
-                fprintf(stderr, "\n");
+                fprintf(errstream, "\n");
             }
         }
     }
-    fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
-    fprintf(stderr, "\n\n");
+    fprintf(errstream, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
+    fprintf(errstream, "\n\n");

     std::vector<gpt_vocab::id> embd;

@@ -815,7 +818,7 @@ int llama_main(
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);

     if (params.interactive) {
-        fprintf(stderr, "== Running in interactive mode. ==\n"
+        fprintf(errstream, "== Running in interactive mode. ==\n"
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
                " - Press Ctrl+C to interject at any time.\n"
 #endif
@@ -831,7 +834,7 @@ int llama_main(

     // set the color for the prompt which will be output initially
     if (params.use_color) {
-        printf(ANSI_COLOR_YELLOW);
+        fprintf(outstream, ANSI_COLOR_YELLOW);
     }

     while (remaining_tokens > 0 || params.interactive) {
@@ -840,7 +843,7 @@ int llama_main(
             const int64_t t_start_us = ggml_time_us();

             if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
-                fprintf(stderr, "Failed to predict\n");
+                fprintf(errstream, "Failed to predict\n");
                 return 1;
             }

@@ -901,9 +904,9 @@ int llama_main(
         // display text
         if (!input_noecho) {
             for (auto id : embd) {
-                printf("%s", vocab.id_to_token[id].c_str());
+                fprintf(outstream, "%s", vocab.id_to_token[id].c_str());
             }
-            fflush(stdout);
+            fflush(outstream);
         }
         // reset color to default if we there is no pending user input
         if (!input_noecho && params.use_color && (int)embd_inp.size() == input_consumed) {
@@ -935,7 +938,7 @@ int llama_main(
                 std::string line;
                 bool another_line = true;
                 do {
-                    std::getline(std::cin, line);
+                    std::getline(instream, line);
                     if (line.empty() || line.back() != '\\') {
                         another_line = false;
                     } else {
@@ -964,7 +967,7 @@ int llama_main(
             if (params.interactive) {
                 is_interacting = true;
             } else {
-                fprintf(stderr, " [end of text]\n");
+                fprintf(errstream, " [end of text]\n");
                 break;
             }
         }
@@ -984,18 +987,18 @@ int llama_main(
     {
         const int64_t t_main_end_us = ggml_time_us();

-        fprintf(stderr, "\n\n");
-        fprintf(stderr, "%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
-        fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
-        fprintf(stderr, "%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
-        fprintf(stderr, "%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
-        fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+        fprintf(errstream, "\n\n");
+        fprintf(errstream, "%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
+        fprintf(errstream, "%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
+        fprintf(errstream, "%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
+        fprintf(errstream, "%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
+        fprintf(errstream, "%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
     }

     ggml_free(model.ctx);

     if (params.use_color) {
-        printf(ANSI_COLOR_RESET);
+        fprintf(outstream, ANSI_COLOR_RESET);
     }

     return 0;
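The patch threads the I/O endpoints through llama_main as parameters (std::istream & instream, FILE *outstream, FILE *errstream) instead of hard-coding std::cin, stdout, and stderr, so an embedding caller can substitute its own streams. Below is a small standalone sketch of that same injection pattern; the echo_once helper and its calls are illustrative only and not part of the patch or of llama.cpp.

// Standalone illustration (hypothetical helper, not from the patch): a routine that
// reads from an injected std::istream and writes to injected FILE* handles can be
// driven either by the real std::cin/stdout/stderr or by in-memory substitutes.
#include <cstdio>
#include <iostream>
#include <sstream>
#include <string>

static int echo_once(std::istream & instream, FILE *outstream, FILE *errstream) {
    std::string line;
    if (!std::getline(instream, line)) {
        fprintf(errstream, "no input\n");          // diagnostics go to errstream
        return 1;
    }
    fprintf(outstream, "%s\n", line.c_str());      // generated text goes to outstream
    fflush(outstream);
    return 0;
}

int main() {
    // CLI-style call: same behaviour the original hard-coded streams gave.
    echo_once(std::cin, stdout, stderr);

    // Embedded/test-style call: feed a canned prompt and capture output in a temp file.
    std::istringstream canned("hello from a test");
    FILE *captured = tmpfile();
    if (captured) {
        echo_once(canned, captured, stderr);
        fclose(captured);
    }
    return 0;
}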