-
Notifications
You must be signed in to change notification settings - Fork 523
/
Copy pathload_gguf.py
103 lines (79 loc) · 2.85 KB
/
load_gguf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import gguf
from gguf import GGUFValueType, ReaderTensor
@dataclass
class AttentionArgs:
head_count: int
head_count_kv: int
layer_norm_rms_epsilon: float
@dataclass
class RopeArgs:
freq_base: float
@dataclass
class GGUFModelArgs:
arch: str
embedding_length: int
block_count: int
feed_forward_length: int
vocab_size: int
attention: AttentionArgs
rope: RopeArgs
@dataclass
class GGUFWeights:
tensors: list[ReaderTensor]
def _get_metadata(reader: gguf.GGUFReader) -> dict[str, Any]:
metadata: dict[str, Any] = {}
for idx, field in enumerate(reader.fields.values()):
val = None
if field.types[:1] == [GGUFValueType.ARRAY]:
itype = field.types[-1]
if itype == GGUFValueType.STRING:
val = [
str(bytes(field.parts[idx]), encoding="utf-8") for idx in field.data
]
else:
val = [pv for idx in field.data for pv in field.parts[idx].tolist()]
elif field.types[0] == GGUFValueType.STRING:
val = str(bytes(field.parts[-1]), encoding="utf-8")
else:
val = field.parts[-1].tolist()[0]
metadata[field.name] = val
return metadata
def _build_model_args(metadata: dict[str, Any]) -> GGUFModelArgs:
arch = metadata["general.architecture"]
return GGUFModelArgs(
arch=arch,
embedding_length=metadata[f"{arch}.embedding_length"],
block_count=metadata[f"{arch}.block_count"],
feed_forward_length=metadata[f"{arch}.feed_forward_length"],
vocab_size=len(metadata["tokenizer.ggml.tokens"]),
attention=AttentionArgs(
head_count=metadata[f"{arch}.attention.head_count"],
head_count_kv=metadata[f"{arch}.attention.head_count_kv"],
layer_norm_rms_epsilon=metadata[f"{arch}.attention.layer_norm_rms_epsilon"],
),
rope=RopeArgs(
# default value from llama2 model definition
freq_base=metadata.get(f"{arch}.rope.freq_base", 1e4),
),
)
def load_file(gguf_file: str) -> (GGUFModelArgs, GGUFWeights):
"""
Load a GGUF file and return the model arguments and weights.
"""
if not Path(gguf_file).is_file():
raise ValueError(f"Could not find file {gguf_file}")
reader = gguf.GGUFReader(gguf_file, "r")
# Step 1: Build GGUFModelArgs
metadata = _get_metadata(reader)
model_args = _build_model_args(metadata)
# Step 2: Build GGUFWeights
gguf_weights = GGUFWeights(tensors=reader.tensors)
return (model_args, gguf_weights)