@@ -35,12 +35,22 @@ int main(int argc, char ** argv) {
35
35
auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n , 0 );
36
36
37
37
// init
38
- auto ctx = llama_init_from_file (params.model .c_str (), lparams);
38
+ auto model = llama_load_model_from_file (params.model .c_str (), lparams);
39
+ if (model == nullptr ) {
40
+ return 1 ;
41
+ }
42
+ auto ctx = llama_new_context_with_model (model, lparams);
43
+ if (ctx == nullptr ) {
44
+ llama_free_model (model);
45
+ return 1 ;
46
+ }
39
47
auto tokens = std::vector<llama_token>(params.n_ctx );
40
48
auto n_prompt_tokens = llama_tokenize (ctx, params.prompt .c_str (), tokens.data (), int (tokens.size ()), true );
41
49
42
50
if (n_prompt_tokens < 1 ) {
43
51
fprintf (stderr, " %s : failed to tokenize prompt\n " , __func__);
52
+ llama_free (ctx);
53
+ llama_free_model (model);
44
54
return 1 ;
45
55
}
46
56
@@ -84,30 +94,36 @@ int main(int argc, char ** argv) {
84
94
printf (" %s" , next_token_str);
85
95
if (llama_eval (ctx, &next_token, 1 , n_past, params.n_threads )) {
86
96
fprintf (stderr, " \n %s : failed to evaluate\n " , __func__);
97
+ llama_free (ctx);
98
+ llama_free_model (model);
87
99
return 1 ;
88
100
}
89
101
n_past += 1 ;
90
102
}
91
103
92
104
printf (" \n\n " );
93
105
94
- // free old model
106
+ // free old context
95
107
llama_free (ctx);
96
108
97
- // load new model
98
- auto ctx2 = llama_init_from_file (params. model . c_str () , lparams);
109
+ // make new context
110
+ auto ctx2 = llama_new_context_with_model ( model, lparams);
99
111
100
112
// Load state (rng, logits, embedding and kv_cache) from file
101
113
{
102
114
FILE *fp_read = fopen (" dump_state.bin" , " rb" );
103
115
if (state_size != llama_get_state_size (ctx2)) {
104
116
fprintf (stderr, " \n %s : failed to validate state size\n " , __func__);
117
+ llama_free (ctx2);
118
+ llama_free_model (model);
105
119
return 1 ;
106
120
}
107
121
108
122
const size_t ret = fread (state_mem, 1 , state_size, fp_read);
109
123
if (ret != state_size) {
110
124
fprintf (stderr, " \n %s : failed to read state\n " , __func__);
125
+ llama_free (ctx2);
126
+ llama_free_model (model);
111
127
return 1 ;
112
128
}
113
129
@@ -138,12 +154,17 @@ int main(int argc, char ** argv) {
138
154
printf (" %s" , next_token_str);
139
155
if (llama_eval (ctx2, &next_token, 1 , n_past, params.n_threads )) {
140
156
fprintf (stderr, " \n %s : failed to evaluate\n " , __func__);
157
+ llama_free (ctx2);
158
+ llama_free_model (model);
141
159
return 1 ;
142
160
}
143
161
n_past += 1 ;
144
162
}
145
163
146
164
printf (" \n\n " );
147
165
166
+ llama_free (ctx2);
167
+ llama_free_model (model);
168
+
148
169
return 0 ;
149
170
}
0 commit comments