Skip to content

Commit 95078cc

Browse files
authored
convert: add ability to convert safetensors files (#1276)
* When loading a safetensors file, ignore the metadata header.
* Check for safetensors files first, and only use PyTorch versions when safetensors aren't available.
1 parent 1f48b0a commit 95078cc

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

convert.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -766,7 +766,7 @@ def load() -> UnquantizedTensor:
766766
return UnquantizedTensor(np.frombuffer(buf, dtype=numpy_dtype).reshape(shape))
767767
description = f'safetensors begin={begin} end={end} type={data_type} path={path}'
768768
return LazyTensor(load, shape, data_type, description)
769-
model = {name: convert(info) for (name, info) in header.items()}
769+
model = {name: convert(info) for (name, info) in header.items() if name != '__metadata__'}
770770
return ModelPlus(model=model, paths=[path], format='safetensors', vocab=None)
771771

772772

@@ -1051,8 +1051,12 @@ def load_some_model(path: Path) -> ModelPlus:
10511051
'''Load a model of any supported format.'''
10521052
# Be extra-friendly and accept either a file or a directory:
10531053
if path.is_dir():
1054-
globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt"]
1055-
files = [file for glob in globs for file in path.glob(glob)]
1054+
# Check if it's a set of safetensors files first
1055+
files = list(path.glob("model-00001-of-*.safetensors"))
1056+
if not files:
1057+
# Try the PyTorch patterns too, with lower priority
1058+
globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt"]
1059+
files = [file for glob in globs for file in path.glob(glob)]
10561060
if not files:
10571061
# Try GGML too, but with lower priority, since if both a non-GGML
10581062
# model and a GGML model exist in the same directory, we assume the

0 commit comments

Comments (0)