Skip to content

Commit fccb147

Browse files
committed
use safetensors for the latest plamo-13b repo
1 parent 074bd14 commit fccb147

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

convert-plamo-hf-to-gguf.py

+12 -5
Original file line number | Diff line number | Diff line change
@@ -8,11 +8,18 @@
88
import gguf
99
from sentencepiece import SentencePieceProcessor # type: ignore[import]
1010

11+
try:
12+
from safetensors import safe_open
13+
except ImportError:
14+
print("Please install `safetensors` python package")
15+
sys.exit(1)
16+
1117

1218
def count_model_parts(dir_model: Path) -> int:
19+
# get number of model parts
1320
num_parts = 0
1421
for filename in os.listdir(dir_model):
15-
if filename.startswith("pytorch_model-"):
22+
if filename.startswith("model-00"):
1623
num_parts += 1
1724

1825
if num_parts > 0:
@@ -161,22 +168,22 @@ def parse_args() -> argparse.Namespace:
161168
print("gguf: get tensor metadata")
162169

163170
if num_parts == 0:
164-
part_names = iter(("pytorch_model.bin",))
171+
part_names = iter(("model.safetensors",))
165172
else:
166173
part_names = (
167-
f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
174+
f"model-{n:05}-of-{num_parts:05}.safetensors" for n in range(1, num_parts + 1)
168175
)
169176

170177
for part_name in part_names:
171178
if args.vocab_only:
172179
break
173180
print("gguf: loading model part '" + part_name + "'")
174-
model_part = torch.load(dir_model / part_name, map_location="cpu")
181+
model_part = safe_open(dir_model / part_name, framework="pt")
175182

176183
for name in model_part.keys():
177184
if "self_attn.rotary_emb.inv_freq" in name:
178185
continue
179-
data = model_part[name]
186+
data = model_part.get_tensor(name)
180187

181188
old_dtype = data.dtype
182189

0 commit comments

Comments
 (0)