Commit 26ebf83

support splits in convert.py
1 parent 928e0b7 commit 26ebf83

File tree

1 file changed: +72, -3 lines changed

convert.py (+72, -3)

@@ -44,9 +44,16 @@
 
 DEFAULT_CONCURRENCY = 8
 
+DEFAULT_SPLIT_TENSORS = 128
+
 ADDED_TOKENS_FILE = 'added_tokens.json'
 FAST_TOKENIZER_FILE = 'tokenizer.json'
 
+LLM_KV_SPLIT_NO = "split.no"
+LLM_KV_SPLIT_COUNT = "split.count"
+LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"
+SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"
+
 #
 # data types
 #
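The split.* keys appear to mirror the split metadata used elsewhere in llama.cpp (e.g. by the gguf-split tool), and SHARD_NAME_FORMAT yields zero-padded shard names. A quick illustration of the naming; the "model" stem here is a hypothetical placeholder:

    # SHARD_NAME_FORMAT as defined above; "model" is a hypothetical output stem
    SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"
    print(SHARD_NAME_FORMAT.format("model", 2, 5))  # -> model-00002-of-00005.gguf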
@@ -1235,7 +1242,49 @@ def write_all(
 
         of.close()
 
+    @staticmethod
+    def write_split(
+        fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
+        total_tensors: int, concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
+        pad_vocab: bool = False, tensors_per_shard: int = DEFAULT_SPLIT_TENSORS, small_first_shard: bool = True,
+    ) -> None:
+        check_vocab_size(params, vocab, pad_vocab=pad_vocab)
 
+        model_list = list(model.items())
+        total_shards = math.ceil(total_tensors / tensors_per_shard) + small_first_shard
+        shard_files = [fname_out.with_name(SHARD_NAME_FORMAT.format(fname_out.stem, i+1, total_shards)) for i in range(total_shards)]
+
+        for i, shard in enumerate(shard_files):
+            of = OutputFile(shard, endianess=endianess)
+
+            if i == 0:
+                of.add_meta_arch(params)
+                if isinstance(vocab, Vocab):
+                    of.add_meta_vocab(vocab)
+                    of.add_meta_special_vocab(svocab)
+                else:  # NoVocab
+                    of.gguf.add_tokenizer_model(vocab.tokenizer_model)
+
+            of.gguf.add_uint16(LLM_KV_SPLIT_NO, i)
+            of.gguf.add_uint16(LLM_KV_SPLIT_COUNT, total_shards)
+            of.gguf.add_int32(LLM_KV_SPLIT_TENSORS_COUNT, total_tensors)
+
+            # have the option to write a first shard with only the metadata
+            if small_first_shard and i == 0:
+                of.write_meta()
+                of.close()
+                continue
+
+            stop = min((i + 1 - small_first_shard) * tensors_per_shard, total_tensors)
+            shard_models = model_list[(i - small_first_shard) * tensors_per_shard:stop]
+            for name, lazy_tensor in shard_models:
+                of.add_tensor_info(name, lazy_tensor)
+
+            of.write_meta()
+            of.write_tensor_info()
+            of.write_tensor_data(ftype, dict(shard_models), concurrency)
+            of.close()
+
 
 def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType:
     wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) + ".weight"].data_type

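The index arithmetic in write_split leans on Python treating bools as ints: when small_first_shard is true it adds 1 to total_shards and shifts every tensor slice down by one shard, so shard 0 carries only metadata. A minimal, self-contained sketch of that partitioning (the shard_ranges helper is illustrative, not part of convert.py):

    import math

    def shard_ranges(total_tensors: int, tensors_per_shard: int, small_first_shard: bool):
        # bool-as-int arithmetic: True contributes an extra, metadata-only shard
        total_shards = math.ceil(total_tensors / tensors_per_shard) + small_first_shard
        for i in range(total_shards):
            if small_first_shard and i == 0:
                yield (0, 0)  # metadata-only shard, no tensors
                continue
            start = (i - small_first_shard) * tensors_per_shard
            stop = min((i + 1 - small_first_shard) * tensors_per_shard, total_tensors)
            yield (start, stop)

    # e.g. 291 tensors, 128 per shard, metadata-only first shard:
    # [(0, 0), (0, 128), (128, 256), (256, 291)]
    print(list(shard_ranges(291, 128, True)))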
@@ -1473,6 +1522,9 @@ def main(args_in: list[str] | None = None) -> None:
     parser.add_argument("--big-endian", action="store_true", help="model is executed on big endian machine")
     parser.add_argument("--pad-vocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides")
     parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing")
+    parser.add_argument("--split", action="store_true", help="split the converted model into multiple files")
+    parser.add_argument("--split-max-tensors", type=int, help=f"maximum number of tensors per file when splitting (default: {DEFAULT_SPLIT_TENSORS})", default=DEFAULT_SPLIT_TENSORS)
+    parser.add_argument("--large-first-shard", action="store_true", help="include tensors in the first shard when splitting (default is to only include metadata)")
 
     args = parser.parse_args(args_in)
     if args.no_vocab and args.vocab_only:
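A hypothetical invocation with the new flags might look like this (the model directory path is a placeholder):

    python convert.py ./my-model-dir --split --split-max-tensors 256

Adding --large-first-shard would pack tensors into the first file as well, instead of reserving it for metadata.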
@@ -1544,11 +1596,28 @@ def main(args_in: list[str] | None = None) -> None:
     outfile = args.outfile or default_outfile(model_plus.paths, ftype)
 
     params.ftype = ftype
-    print(f"Writing {outfile}, format {ftype}")
 
-    OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,
+    if args.split:
+        total_tensors = len(model)
+        if total_tensors < args.split_max_tensors:
+
+            print("Model has fewer tensors than the split threshold, not splitting")
+            print(f"Writing {outfile}, format {ftype}")
+            OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,
                              concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab)
-    print(f"Wrote {outfile}")
+        else:
+            print(f"Writing {outfile} as shards, format {ftype}")
+            OutputFile.write_split(outfile, ftype, params, model, vocab, special_vocab, total_tensors,
+                                   concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab,
+                                   tensors_per_shard=args.split_max_tensors, small_first_shard=not args.large_first_shard)
+        print(f"Wrote {outfile}")
+
+    else:
+        print(f"Writing {outfile}, format {ftype}")
+
+        OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,
+                             concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab)
+        print(f"Wrote {outfile}")
 
 
 if __name__ == '__main__':
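As a worked example of this dispatch: with the default --split-max-tensors of 128, a model with 291 tensors passes the threshold check, and write_split produces math.ceil(291 / 128) + 1 = 4 files, the first holding only metadata; assuming the output file is model.gguf, they would be named model-00001-of-00004.gguf through model-00004-of-00004.gguf. A model with, say, 100 tensors falls below the threshold and is written as a single unsplit file.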
