|
8 | 8 | import gguf
|
9 | 9 | from sentencepiece import SentencePieceProcessor # type: ignore[import]
|
10 | 10 |
|
| 11 | +try: |
| 12 | + from safetensors import safe_open |
| 13 | +except ImportError: |
| 14 | + print("Please install `safetensors` python package") |
| 15 | + sys.exit(1) |
| 16 | + |
11 | 17 |
|
12 | 18 | def count_model_parts(dir_model: Path) -> int:
|
| 19 | + # get number of model parts |
13 | 20 | num_parts = 0
|
14 | 21 | for filename in os.listdir(dir_model):
|
15 |
| - if filename.startswith("pytorch_model-"): |
| 22 | + if filename.startswith("model-00"): |
16 | 23 | num_parts += 1
|
17 | 24 |
|
18 | 25 | if num_parts > 0:
|
@@ -161,22 +168,22 @@ def parse_args() -> argparse.Namespace:
|
161 | 168 | print("gguf: get tensor metadata")
|
162 | 169 |
|
163 | 170 | if num_parts == 0:
|
164 |
| - part_names = iter(("pytorch_model.bin",)) |
| 171 | + part_names = iter(("model.safetensors",)) |
165 | 172 | else:
|
166 | 173 | part_names = (
|
167 |
| - f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1) |
| 174 | + f"model-{n:05}-of-{num_parts:05}.safetensors" for n in range(1, num_parts + 1) |
168 | 175 | )
|
169 | 176 |
|
170 | 177 | for part_name in part_names:
|
171 | 178 | if args.vocab_only:
|
172 | 179 | break
|
173 | 180 | print("gguf: loading model part '" + part_name + "'")
|
174 |
| - model_part = torch.load(dir_model / part_name, map_location="cpu") |
| 181 | + model_part = safe_open(dir_model / part_name, framework="pt") |
175 | 182 |
|
176 | 183 | for name in model_part.keys():
|
177 | 184 | if "self_attn.rotary_emb.inv_freq" in name:
|
178 | 185 | continue
|
179 |
| - data = model_part[name] |
| 186 | + data = model_part.get_tensor(name) |
180 | 187 |
|
181 | 188 | old_dtype = data.dtype
|
182 | 189 |
|
|
0 commit comments