Skip to content

Commit 14442d8

Browse files
committed
split: support in llama_model_loader
1 parent d0d5de4 commit 14442d8

File tree

3 files changed

+136
-67
lines changed

3 files changed

+136
-67
lines changed

examples/gguf-split/gguf-split.cpp

+42-67
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,8 @@ enum split_operation : uint8_t {
2222
SPLIT_OP_MERGE,
2323
};
2424

25-
static const char * const LLM_KV_GENERAL_SPLIT_I_SPLIT = "general.split";
26-
static const char * const LLM_KV_GENERAL_SPLIT_N_SPLIT = "general.split_count";
27-
28-
static const int SPLIT_FILENAME_MAX = 256;
29-
30-
static const char * const SPLIT_FILENAME_FORMAT = "%s-%05d-of-%05d.gguf";
25+
static const char * const LLM_KV_GENERAL_SPLIT_I_SPLIT = "split.no"; // @ggerganov: should we make this accessible from outside ?
26+
static const char * const LLM_KV_GENERAL_SPLIT_N_SPLIT = "split.count";
3127

3228
struct split_params {
3329
split_operation operation = SPLIT_OP_SPLIT;
@@ -136,12 +132,6 @@ static void zeros(std::ofstream & file, size_t n) {
136132
}
137133
}
138134

139-
static std::string split_file_name(const std::string & path, int i_split, int n_split) {
140-
char f_split[SPLIT_FILENAME_MAX] = {0};
141-
snprintf(f_split, sizeof(f_split), SPLIT_FILENAME_FORMAT, path.c_str(), i_split + 1, n_split);
142-
return std::string(f_split);
143-
}
144-
145135
struct split_strategy {
146136
const split_params params;
147137
std::ifstream & f_input;
@@ -182,19 +172,20 @@ struct split_strategy {
182172
if (i_split == 0) {
183173
gguf_set_kv(ctx_out, ctx_gguf);
184174
}
185-
gguf_set_val_u8(ctx_out, LLM_KV_GENERAL_SPLIT_I_SPLIT, i_split);
186-
gguf_set_val_u8(ctx_out, LLM_KV_GENERAL_SPLIT_N_SPLIT, n_split);
175+
gguf_set_val_u16(ctx_out, LLM_KV_GENERAL_SPLIT_I_SPLIT, i_split);
176+
gguf_set_val_u16(ctx_out, LLM_KV_GENERAL_SPLIT_N_SPLIT, n_split);
187177

188178
// populate the original tensors, so we get an initial metadata
189179
for (int i = i_split * params.n_split_tensors; i < n_tensors && i < (i_split + 1) * params.n_split_tensors; ++i) {
190180
struct ggml_tensor * meta = ggml_get_tensor(ctx_meta, gguf_get_tensor_name(ctx_gguf, i));
191181
gguf_add_tensor(ctx_out, meta);
192182
}
193183

194-
auto split_name = split_file_name(params.output, i_split, n_split);
184+
char split_path[4096] = {0};
185+
llama_split_path(split_path, sizeof(split_path), params.output.c_str(), i_split, n_split);
195186

196-
fprintf(stderr, "%s: %s ...", __func__, split_name.c_str());
197-
fout = std::ofstream(split_name, std::ios::binary);
187+
fprintf(stderr, "%s: %s ...", __func__, split_path);
188+
fout = std::ofstream(split_path, std::ios::binary);
198189
fout.exceptions(std::ofstream::failbit); // fail fast on write errors
199190

200191
auto meta_size = gguf_get_meta_size(ctx_out);
@@ -262,9 +253,13 @@ static void gguf_split(const split_params & split_params) {
262253
}
263254

264255
split_strategy strategy(split_params, f_input, ctx_gguf, ctx_meta);
256+
257+
char first_split_path[4096] = {0};
258+
llama_split_path(first_split_path, sizeof(first_split_path),
259+
split_params.output.c_str(), strategy.i_split, strategy.n_split);
265260
fprintf(stderr, "%s: %s -> %s (%d tensors per file)\n",
266261
__func__, split_params.input.c_str(),
267-
split_file_name(split_params.output, strategy.i_split, strategy.n_split).c_str(),
262+
first_split_path,
268263
split_params.n_split_tensors);
269264

270265
strategy.split_start();
@@ -300,7 +295,9 @@ static void gguf_merge(const split_params & split_params) {
300295
std::vector<ggml_context *> ctx_metas;
301296
std::vector<gguf_context *> ctx_ggufs;
302297

303-
std::string split_prefix;
298+
char split_path[4096] = {0};
299+
strncpy(split_path, split_params.input.c_str(), sizeof(split_path));
300+
char split_prefix[4096] = {0};
304301

305302
// First pass to find KV and tensors metadata
306303
for (int i_split = 0; i_split < n_split; i_split++) {
@@ -311,13 +308,12 @@ static void gguf_merge(const split_params & split_params) {
311308
/*.ctx = */ &ctx_meta,
312309
};
313310

314-
auto split_name = split_params.input;
315311
if (i_split > 0) {
316-
split_name = split_file_name(split_prefix, i_split, n_split);
312+
llama_split_path(split_path, sizeof(split_path), split_prefix, i_split, n_split);
317313
}
318-
fprintf(stderr, "%s: reading metadata %s ...", __func__, split_name.c_str());
314+
fprintf(stderr, "%s: reading metadata %s ...", __func__, split_path);
319315

320-
auto * ctx_gguf = gguf_init_from_file(split_name.c_str(), params);
316+
auto * ctx_gguf = gguf_init_from_file(split_path, params);
321317
if (!ctx_gguf) {
322318
fprintf(stderr, "\n%s: failed to load input GGUF from %s\n", __func__, split_params.input.c_str());
323319
exit(1);
@@ -333,65 +329,43 @@ static void gguf_merge(const split_params & split_params) {
333329
__func__,
334330
LLM_KV_GENERAL_SPLIT_N_SPLIT);
335331
gguf_free(ctx_gguf);
332+
ggml_free(ctx_meta);
336333
gguf_free(ctx_out);
337334
fout.close();
338335
exit(1);
339336
}
340337

341-
n_split = gguf_get_val_u8(ctx_gguf, key_n_split);
338+
n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
342339
if (n_split < 1) {
343340
fprintf(stderr,
344341
"\n%s: input file does not contain a valid split count %d\n",
345342
__func__,
346343
n_split);
347344
gguf_free(ctx_gguf);
345+
ggml_free(ctx_meta);
348346
gguf_free(ctx_out);
349347
fout.close();
350348
exit(1);
351349
}
352350

353-
// Do not trigger merge if we try to merge again the output
354-
gguf_set_val_u8(ctx_out, LLM_KV_GENERAL_SPLIT_N_SPLIT, 0);
355-
356-
// Set metadata from the first split
357-
gguf_set_kv(ctx_out, ctx_gguf);
358-
}
359-
360-
// Verify the file naming
361-
{
362-
int i_split_file = 0;
363-
int n_split_file = 0;
364-
const char * i_split_format = "-00000-of-00000.gguf";
365-
366-
if (split_name.size() < strlen(i_split_format)) {
367-
fprintf(stderr, "\n%s: unexpected input file name: %s\n", __func__, split_params.input.c_str());
368-
for (auto * _ctx_gguf : ctx_ggufs) {
369-
gguf_free(_ctx_gguf);
370-
}
371-
gguf_free(ctx_out);
372-
fout.close();
373-
exit(1);
374-
}
375-
376-
split_prefix = split_name.substr(0, split_name.size() - strlen(i_split_format));
377-
378-
const char * split_name_c_str = split_name.c_str();
379-
int n_part = sscanf(&split_name_c_str[0] + split_prefix.size(), "-%d-of-%d", &i_split_file, &n_split_file);
380-
381-
if (n_part != 2 || i_split_file - 1 != i_split || n_split_file != n_split) {
351+
// Verify the file naming and extract split_prefix
352+
if (!llama_split_prefix(split_prefix, split_path, i_split, n_split)) {
382353
fprintf(stderr, "\n%s: unexpected input file name: %s"
383-
" i_split=%d i_split_file=%d"
384-
" n_split=%d n_split_file=%d\n", __func__,
385-
split_params.input.c_str(),
386-
i_split, i_split_file,
387-
n_split, n_split_file);
388-
for (auto * _ctx_gguf : ctx_ggufs) {
389-
gguf_free(_ctx_gguf);
390-
}
354+
" i_split=%d"
355+
" n_split=%d\n", __func__,
356+
split_path, i_split, n_split);
357+
gguf_free(ctx_gguf);
358+
ggml_free(ctx_meta);
391359
gguf_free(ctx_out);
392360
fout.close();
393361
exit(1);
394362
}
363+
364+
// Do not trigger merge if we try to merge again the output
365+
gguf_set_val_u16(ctx_gguf, LLM_KV_GENERAL_SPLIT_N_SPLIT, 0);
366+
+
367+
// Set metadata from the first split
368+
gguf_set_kv(ctx_out, ctx_gguf);
395369
}
396370

397371
auto n_tensors = gguf_get_n_tensors(ctx_gguf);
@@ -413,18 +387,19 @@ static void gguf_merge(const split_params & split_params) {
413387

414388
// Write tensors data
415389
for (int i_split = 0; i_split < n_split; i_split++) {
416-
auto split_name = split_file_name(split_prefix, i_split, n_split);
417-
std::ifstream f_input(split_name.c_str(), std::ios::binary);
390+
llama_split_path(split_path, sizeof(split_path), split_prefix, i_split, n_split);
391+
std::ifstream f_input(split_path, std::ios::binary);
418392
if (!f_input.is_open()) {
419-
fprintf(stderr, "%s: failed to open input GGUF from %s\n", __func__, split_name.c_str());
420-
for (auto * _ctx_gguf : ctx_ggufs) {
421-
gguf_free(_ctx_gguf);
393+
fprintf(stderr, "%s: failed to open input GGUF from %s\n", __func__, split_path);
394+
for (uint32_t i = 0; i < ctx_ggufs.size(); i++) {
395+
gguf_free(ctx_ggufs[i]);
396+
ggml_free(ctx_metas[i]);
422397
}
423398
gguf_free(ctx_out);
424399
fout.close();
425400
exit(1);
426401
}
427-
fprintf(stderr, "%s: writing tensors %s ...", __func__, split_name.c_str());
402+
fprintf(stderr, "%s: writing tensors %s ...", __func__, split_path);
428403

429404
auto * ctx_gguf = ctx_ggufs[i_split];
430405
auto * ctx_meta = ctx_metas[i_split];

llama.cpp

+84
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,9 @@ enum llm_kv {
290290
LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
291291
LLM_KV_ROPE_SCALING_FINETUNED,
292292

293+
LLM_KV_SPLIT_NO,
294+
LLM_KV_SPLIT_COUNT,
295+
293296
LLM_KV_SSM_INNER_SIZE,
294297
LLM_KV_SSM_CONV_KERNEL,
295298
LLM_KV_SSM_STATE_SIZE,
@@ -355,6 +358,9 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
355358
{ LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
356359
{ LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" },
357360

361+
{ LLM_KV_SPLIT_NO, "split.no" },
362+
{ LLM_KV_SPLIT_COUNT, "split.count" },
363+
358364
{ LLM_KV_SSM_CONV_KERNEL, "%s.ssm.conv_kernel" },
359365
{ LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" },
360366
{ LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" },
@@ -2797,6 +2803,9 @@ struct llama_model_loader {
27972803
int n_tensors = 0;
27982804
int n_created = 0;
27992805

2806+
uint16_t n_split = 0;
2807+
std::vector<uint16_t> split_tensor_offsets = {0};
2808+
28002809
int64_t n_elements = 0;
28012810
size_t n_bytes = 0;
28022811

@@ -2840,6 +2849,55 @@ struct llama_model_loader {
28402849
get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
28412850
llm_kv = LLM_KV(llm_arch_from_string(arch_name));
28422851

2852+
get_key(llm_kv(LLM_KV_SPLIT_COUNT), n_split, false);
2853+
if (n_split > 0) {
2854+
uint16_t i_split = 0;
2855+
get_key(llm_kv(LLM_KV_SPLIT_NO), i_split);
2856+
if (i_split != 0) {
2857+
throw std::runtime_error(format("illegal split file: %d, model must be loaded with the first split", i_split));
2858+
}
2859+
char split_prefix[4096] = {0};
2860+
int n_split_prefix = llama_split_prefix(split_prefix, fname.c_str(), i_split, n_split);
2861+
if (!n_split_prefix) {
2862+
throw std::runtime_error(format("invalid split file: %s", fname.c_str()));
2863+
}
2864+
2865+
if (trace > 0) {
2866+
LLAMA_LOG_INFO("%s: loading additional %d GGUFs split\n",
2867+
__func__, n_split);
2868+
}
2869+
2870+
auto split_n_tensors = gguf_get_n_tensors(ctx_gguf);
2871+
for (i_split = 1; i_split < n_split; i_split++) {
2872+
char split_path[4096] = {0};
2873+
llama_split_path(split_path, sizeof(split_path), split_prefix, i_split, n_split);
2874+
2875+
struct ggml_context * split_ctx_meta = NULL;
2876+
struct gguf_init_params split_params = {
2877+
/*.no_alloc = */ true,
2878+
/*.ctx = */ &split_ctx_meta,
2879+
};
2880+
auto * split_ctx_gguf = gguf_init_from_file(split_path, split_params);
2881+
if (!split_ctx_gguf) {
2882+
throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, fname.c_str()));
2883+
}
2884+
2885+
split_tensor_offsets.push_back(split_n_tensors);
2886+
split_n_tensors = gguf_get_n_tensors(split_ctx_gguf);
2887+
for (int i_tensor = 0; i_tensor < split_n_tensors; i_tensor++) {
2888+
const char * t_name = gguf_get_tensor_name(split_ctx_gguf, i_tensor);
2889+
struct ggml_tensor * t = ggml_get_tensor(split_ctx_meta, t_name);
2890+
gguf_add_tensor(ctx_gguf, t);
2891+
}
2892+
2893+
gguf_free(split_ctx_gguf);
2894+
ggml_free(split_ctx_meta);
2895+
}
2896+
2897+
LLAMA_LOG_INFO("%s: additional %d GGUFs split metadata loaded.\n",
2898+
__func__, n_split);
2899+
}
2900+
28432901
n_kv = gguf_get_n_kv(ctx_gguf);
28442902
n_tensors = gguf_get_n_tensors(ctx_gguf);
28452903

@@ -14648,6 +14706,32 @@ LLAMA_API int32_t llama_chat_apply_template(
1464814706
return res;
1464914707
}
1465014708

14709+
// Build the on-disk path of one chunk of a split GGUF model.
// Writes "<path_prefix>-%05d-of-%05d.gguf" into split_path (at most maxlen
// bytes, NUL-terminated). split_no is 0-based; the index embedded in the
// file name is 1-based (split_no + 1), matching the "-00001-of-000NN" scheme.
// Returns the length of the generated path, or 0 on failure.
LLAMA_API int llama_split_path(char * split_path, int maxlen, const char * path_prefix, int split_no, int split_count) {
    static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf";
    // snprintf returns < 0 on encoding error and >= maxlen on truncation;
    // the original `if (snprintf(...))` accepted both as success. Only a
    // fully written path is a valid result.
    const int n = snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count);
    if (n > 0 && n < maxlen) {
        return n;
    }
    return 0;
}
14716+
14717+
// Extract the path prefix from split_path if and only if the embedded
// "-%05d-of-%05d.gguf" suffix matches the expected (0-based) split_no and
// split_count. On success copies the prefix into dest and returns its
// length; returns 0 when the name does not match.
// NOTE: dest must be at least as large as split_path.
LLAMA_API int llama_split_prefix(char * dest, const char * split_path, int split_no, int split_count) {
    static const char * const SPLIT_FORMAT = "-00000-of-00000.gguf";

    const size_t path_len   = strlen(split_path);
    const size_t suffix_len = strlen(SPLIT_FORMAT);

    if (path_len <= suffix_len + 1) {
        return 0; // too short to contain both a prefix and the split suffix
    }

    const size_t prefix_len = path_len - suffix_len;

    char split_prefix[PATH_MAX] = {0};
    if (prefix_len >= sizeof(split_prefix)) {
        // the original strncpy copied prefix_len bytes unconditionally,
        // overflowing this buffer for paths longer than PATH_MAX + suffix
        return 0;
    }
    memcpy(split_prefix, split_path, prefix_len);

    int split_no_file    = 0;
    int split_count_file = 0;
    const int n = sscanf(split_path + prefix_len, "-%d-of-%d", &split_no_file, &split_count_file);
    // the file index is 1-based while split_no is 0-based, hence the -1
    if (n == 2 && split_no_file - 1 == split_no && split_count_file == split_count) {
        strcpy(dest, split_prefix);
        return (int) prefix_len;
    }

    return 0;
}
14734+
1465114735
struct llama_timings llama_get_timings(struct llama_context * ctx) {
1465214736
struct llama_timings result = {
1465314737
/*.t_start_ms =*/ 1e-3 * ctx->t_start_us,

llama.h

+10
Original file line numberDiff line numberDiff line change
@@ -960,6 +960,16 @@ extern "C" {
960960
int32_t n_past,
961961
int32_t n_predict);
962962

963+
/// @details Build a split GGUF final path for this chunk.
964+
/// llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 1, 4) => split_path = "/models/ggml-model-q4_0-00002-of-00004.gguf"
965+
// Returns the split_path length.
966+
LLAMA_API int llama_split_path(char * split_path, int maxlen, const char * path_prefix, int split_no, int split_count);
967+
968+
/// @details Extract the path prefix from the split_path if and only if the split_no and split_count match.
969+
/// llama_split_prefix(split_prefix, "/models/ggml-model-q4_0-00002-of-00004.gguf", 1, 4) => split_prefix = "/models/ggml-model-q4_0"
970+
// Returns the split_prefix length.
971+
LLAMA_API int llama_split_prefix(char * split_prefix, const char * split_path, int split_no, int split_count);
972+
963973
// Performance information
964974
LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
965975

0 commit comments

Comments
 (0)