@@ -11,10 +11,11 @@ int main(int argc, char ** argv) {
11
11
gpt_params params;
12
12
13
13
if (argc == 1 || argv[1 ][0 ] == ' -' ) {
14
- printf (" usage: %s MODEL_PATH [IS_PP_SHARED] [NGL]\n " , argv[0 ]);
14
+ printf (" usage: %s MODEL_PATH [N_KV_MAX] [ IS_PP_SHARED] [NGL]\n " , argv[0 ]);
15
15
return 1 ;
16
16
}
17
17
18
+ int n_kv_max = 2048 ;
18
19
int is_pp_shared = 0 ;
19
20
int n_gpu_layers = 0 ;
20
21
@@ -23,18 +24,20 @@ int main(int argc, char ** argv) {
23
24
std::vector<int > n_pl = { 1 , 2 , 4 , 8 , 16 , 32 , };
24
25
// std::vector<int> n_pl = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 32, };
25
26
26
- const int32_t n_ctx_max = 16 *1024 ;
27
-
28
27
if (argc >= 2 ) {
29
28
params.model = argv[1 ];
30
29
}
31
30
32
31
if (argc >= 3 ) {
33
- is_pp_shared = std::atoi (argv[2 ]);
32
+ n_kv_max = std::atoi (argv[2 ]);
34
33
}
35
34
36
35
if (argc >= 4 ) {
37
- n_gpu_layers = std::atoi (argv[3 ]);
36
+ is_pp_shared = std::atoi (argv[3 ]);
37
+ }
38
+
39
+ if (argc >= 5 ) {
40
+ n_gpu_layers = std::atoi (argv[4 ]);
38
41
}
39
42
40
43
// init LLM
@@ -56,8 +59,8 @@ int main(int argc, char ** argv) {
56
59
57
60
llama_context_params ctx_params = llama_context_default_params ();
58
61
59
- ctx_params.seed = 1234 ;
60
- ctx_params.n_ctx = n_ctx_max ;
62
+ ctx_params.seed = 1234 ;
63
+ ctx_params.n_ctx = n_kv_max ;
61
64
ctx_params.n_batch = 512 ;
62
65
ctx_params.n_threads = params.n_threads ;
63
66
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch ;
@@ -69,7 +72,7 @@ int main(int argc, char ** argv) {
69
72
return 1 ;
70
73
}
71
74
72
- llama_batch batch = llama_batch_init (n_ctx_max , 0 );
75
+ llama_batch batch = llama_batch_init (n_kv_max , 0 );
73
76
74
77
// decode in batches of ctx_params.n_batch tokens
75
78
auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
@@ -88,7 +91,7 @@ int main(int argc, char ** argv) {
88
91
89
92
const int ret = llama_decode (ctx, batch_view);
90
93
if (ret != 0 ) {
91
- LOG_TEE (" %s : failed to decode the batch, n_batch = %d, ret = %d\n " , __func__ , n_batch, ret);
94
+ LOG_TEE (" failed to decode the batch, n_batch = %d, ret = %d\n " , n_batch, ret);
92
95
return false ;
93
96
}
94
97
}
@@ -117,7 +120,7 @@ int main(int argc, char ** argv) {
117
120
118
121
const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg);
119
122
120
- if (n_ctx_req > n_ctx_max ) {
123
+ if (n_ctx_req > n_kv_max ) {
121
124
continue ;
122
125
}
123
126
0 commit comments