diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 2bf97475f78dd..8b015926dc1b2 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -74,7 +74,8 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
-                 small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None):
+                 small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
+                 thread_count: int = 2):
         if type(self) is Model:
             raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
@@ -123,7 +124,8 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
         # Configure GGUF Writer
         self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
-                                           split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
+                                           split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard,
+                                           thread_count=thread_count)

     @classmethod
     def __init_subclass__(cls):
@@ -5525,6 +5527,10 @@ def parse_args() -> argparse.Namespace:
         "--remote", action="store_true",
         help="(Experimental) Read safetensors file remotely without downloading to disk. Config and tokenizer files will still be downloaded. To use this feature, you need to specify Hugging Face model repo name instead of a local directory. For example: 'HuggingFaceTB/SmolLM2-1.7B-Instruct'. Note: To access gated repo, set HF_TOKEN environment variable to your Hugging Face token.",
     )
+    parser.add_argument(
+        "-t", "--threads", type=int, default=2,
+        help="Number of threads to use when writing the tensors. Make sure you have enough RAM to hold at least THREADS of the model's biggest tensors when setting this. Defaults to 2.",
+    )

     args = parser.parse_args()
     if not args.print_supported_models and args.model is None:
@@ -5620,7 +5626,8 @@ def main() -> None:
                                      split_max_tensors=args.split_max_tensors,
                                      split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
                                      small_first_shard=args.no_tensor_first_split,
-                                     remote_hf_model_id=str(args.model) if args.remote else None)
+                                     remote_hf_model_id=str(args.model) if args.remote else None,
+                                     thread_count=args.threads)

         if args.vocab_only:
             logger.info("Exporting model vocab...")
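The RAM warning in the new --threads help text exists because every worker may hold one fully materialized tensor while writing it. A rough sizing sketch of that constraint (hypothetical helper using psutil, not part of the patch):

# Hypothetical sizing helper, not part of the patch; assumes psutil is installed
# and that each writer thread holds at most one eagerly loaded tensor at a time.
import psutil

def safe_thread_count(max_tensor_nbytes: int, requested: int) -> int:
    available = psutil.virtual_memory().available
    # Keep 2x headroom per thread for the largest tensor in the model
    fits = available // (2 * max_tensor_nbytes)
    return max(1, min(requested, fits))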
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 485550aad6da4..ea283c57fabcd 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -5,6 +5,7 @@
 import shutil
 import struct
 import tempfile
+import threading
 from dataclasses import dataclass
 from enum import Enum, auto
 from math import prod
@@ -12,6 +13,7 @@
 from io import BufferedWriter
 from typing import IO, Any, Sequence, Mapping
 from string import ascii_letters, digits
+from concurrent.futures import FIRST_EXCEPTION, Future, ThreadPoolExecutor, wait

 import numpy as np

@@ -60,8 +62,63 @@ class WriterState(Enum):
     WEIGHTS = auto()


+# To close files which were opened in thread-local context
+# Necessary because ThreadPoolExecutor doesn't allow setting a custom finalizer
+# ref: https://github.com/python/cpython/issues/89502
+class _ThreadedOpenFiles:
+    files: dict[Path, BufferedWriter]
+
+    def __init__(self):
+        self.files = {}
+
+    def __del__(self):
+        for file in self.files.values():
+            file.close()
+
+    def __getitem__(self, key: Path, /) -> BufferedWriter:
+        if key not in self.files:
+            self.files[key] = open(key, "r+b")
+        return self.files[key]
+
+    @classmethod
+    def init_thread_local(cls, local_data):
+        local_data.open_files = _ThreadedOpenFiles()
+
+
+# Exit quickly on errors instead of waiting for queued futures
+class _InterruptibleThreadPoolExecutor(ThreadPoolExecutor):
+    def __exit__(self, exc_type, exc_val, exc_tb) -> bool | None:
+        del exc_type, exc_val, exc_tb
+        self.shutdown(wait=False, cancel_futures=True)
+        return False
+
+
+@dataclass
+class _ThreadedTensorWriteInfo:
+    filename: Path
+    offset: int
+    post_pad: int
+    tensor: np.ndarray
+    bar: Any | None  # optional tqdm progress bar
+
+    def write_chunk(self, open_files: _ThreadedOpenFiles):
+        # This is called from a thread pool,
+        # and each thread should have its own file handle per output file
+        # so that they can have different seek locations.
+        f = open_files[self.filename]
+
+        f.seek(self.offset)
+        f.write(self.tensor.data)
+        if self.post_pad > 0:
+            f.write(bytes([0] * self.post_pad))
+        if self.bar is not None:
+            self.bar.update(self.tensor.nbytes)
+
+
 class GGUFWriter:
     fout: list[BufferedWriter] | None
+    filenames: list[Path] | None
+    thread_count: int
     path: Path | None
     temp_file: tempfile.SpooledTemporaryFile[bytes] | None
     tensors: list[dict[str, TensorInfo]]
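The handle-per-thread design above is the crux: a BufferedWriter's seek position is shared state, so two workers writing through the same handle would race. A minimal standalone sketch of the same pattern (file name and payloads are illustrative, not part of the patch):

# Minimal sketch of per-thread file handles via an executor initializer.
# Mirrors the _ThreadedOpenFiles mechanism; names here are illustrative.
import threading
from concurrent.futures import ThreadPoolExecutor

local = threading.local()

def init_worker(local_data):
    # One handle per worker thread, so seek positions never race
    local_data.handle = open("out.bin", "r+b")

def write_at(offset: int, payload: bytes):
    f = local.handle
    f.seek(offset)
    f.write(payload)

with open("out.bin", "wb") as f:
    f.truncate(32)  # pre-size the file so workers can seek anywhere in it

with ThreadPoolExecutor(max_workers=2, initializer=init_worker, initargs=(local,)) as ex:
    futures = [ex.submit(write_at, i * 4, i.to_bytes(4, "little")) for i in range(8)]
    for fut in futures:
        fut.result()  # propagate any write errors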
@@ -83,7 +140,8 @@ class GGUFWriter:
     def __init__(
         self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE,
-        split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False
+        split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False,
+        thread_count: int = 2,
     ):
         self.fout = None
         self.path = Path(path) if path else None
@@ -98,6 +156,7 @@ def __init__(
         self.split_max_size = split_max_size
         self.dry_run = dry_run
         self.small_first_shard = small_first_shard
+        self.thread_count = thread_count
         logger.info("gguf: This GGUF file is for {0} Endian only".format(
             "Big" if self.endianess == GGUFEndian.BIG else "Little",
         ))
@@ -173,6 +232,7 @@ def open_output_file(self, path: Path | None = None) -> None:

         if self.path is not None:
             filenames = self.print_plan()
+            self.filenames = filenames
             self.fout = [open(filename, "wb") for filename in filenames]
             self.state = WriterState.EMPTY

@@ -424,40 +484,76 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
         self.write_ti_data_to_file()

         assert self.fout is not None
+        assert self.filenames is not None

         for fout in self.fout:
             self.write_padding(fout, fout.tell())

         if self.temp_file is None:
-            shard_bar = None
             bar = None
+            # Initial file offsets before writing the tensor data
+            offsets: list[int] = [fout.tell() for fout in self.fout]

             if progress:
+                # TODO: add back the shard bar to show which shard is being written when single-threaded
                 from tqdm import tqdm

                 total_bytes = sum(ti.nbytes for t in self.tensors for ti in t.values())

-                if len(self.fout) > 1:
-                    shard_bar = tqdm(desc=f"Shard (0/{len(self.fout)})", total=None, unit="byte", unit_scale=True)
                 bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)

-            for i, (fout, tensors) in enumerate(zip(self.fout, self.tensors)):
-                if shard_bar is not None:
-                    shard_bar.set_description(f"Shard ({i + 1}/{len(self.fout)})")
-                    total = sum(ti.nbytes for ti in tensors.values())
-                    shard_bar.reset(total=(total if total > 0 else None))
-
-                # relying on the fact that Python dicts preserve insertion order (since 3.7)
-                for ti in tensors.values():
-                    assert ti.tensor is not None  # can only iterate once over the tensors
-                    assert ti.tensor.nbytes == ti.nbytes
-                    ti.tensor.tofile(fout)
-                    if shard_bar is not None:
-                        shard_bar.update(ti.nbytes)
-                    if bar is not None:
-                        bar.update(ti.nbytes)
-                    self.write_padding(fout, ti.nbytes)
-                    ti.tensor = None
+            # Allow opening the files only once per worker
+            local_data = threading.local()
+
+            # Unit of work
+            def thread_write_tensor(tensor: _ThreadedTensorWriteInfo):
+                tensor.write_chunk(local_data.open_files)
+
+            with _InterruptibleThreadPoolExecutor(
+                max_workers=self.thread_count,
+                initializer=_ThreadedOpenFiles.init_thread_local,
+                initargs=(local_data,),
+            ) as executor:
+
+                futures: list[Future] = []
+
+                # Fill the tensor queue with all the pending tensor writes
+                for i, (filename, tensors) in enumerate(zip(self.filenames, self.tensors)):
+                    offset = offsets[i]
+
+                    # relying on the fact that Python dicts preserve insertion order (since 3.7)
+                    for ti in tensors.values():
+                        assert ti.tensor is not None  # can only iterate once over the tensors
+                        assert ti.tensor.nbytes == ti.nbytes
+                        start_offset = offset
+                        nbytes = ti.tensor.nbytes
+                        offset = self.ggml_pad(start_offset + nbytes, self.data_alignment)
+                        padding = offset - (start_offset + nbytes)
+                        futures.append(
+                            executor.submit(
+                                thread_write_tensor,
+                                _ThreadedTensorWriteInfo(
+                                    filename=filename,
+                                    offset=start_offset,
+                                    post_pad=padding,
+                                    tensor=ti.tensor,
+                                    bar=bar,
+                                ),
+                            )
+                        )
+                        ti.tensor = None  # avoid keeping a reference to written tensors
+
+                # FIXME: there's still some weird behavior with KeyboardInterrupt
+                #        not being able to interrupt a future mid-execution
+                done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
+                exc = None
+                if any(f for f in done
+                       if not f.cancelled() and (exc := f.exception()) is not None):
+                    raise RuntimeError("Error writing tensors") from exc
+                elif len(not_done) != 0:
+                    raise RuntimeError("Not all tensors were written")
+
+            del local_data
         else:
             self.temp_file.seek(0)
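Note how offsets are precomputed serially in the submission loop, so every write is independent of the others; only the padding arithmetic must stay consistent with ggml_pad. A worked example of that bookkeeping, assuming GGUF's default 32-byte data alignment:

# Standalone illustration of the offset/padding bookkeeping used above,
# assuming the default 32-byte GGUF data alignment.
def ggml_pad(x: int, n: int) -> int:
    return ((x + n - 1) // n) * n

offset = 128  # where tensor data happens to start in this example
for nbytes in (1000, 4096, 7):
    start_offset = offset
    offset = ggml_pad(start_offset + nbytes, 32)
    post_pad = offset - (start_offset + nbytes)
    print(f"{nbytes} bytes at {start_offset}, post_pad {post_pad}, next tensor at {offset}")
# 1000 bytes at 128, post_pad 24, next tensor at 1152
# 4096 bytes at 1152, post_pad 0, next tensor at 5248
# 7 bytes at 5248, post_pad 25, next tensor at 5280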
diff --git a/gguf-py/gguf/lazy.py b/gguf-py/gguf/lazy.py
index f9bcadae0224b..e01b5b050b788 100644
--- a/gguf-py/gguf/lazy.py
+++ b/gguf-py/gguf/lazy.py
@@ -220,4 +220,9 @@ def tofile(self, *args, **kwargs):
         eager = LazyNumpyTensor.to_eager(self)
         return eager.tofile(*args, **kwargs)

+    @property
+    def data(self):
+        eager = LazyNumpyTensor.to_eager(self)
+        return eager.data
+
     # TODO: __array_function__
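The new data property is needed because the threaded writer calls f.write(tensor.data) rather than tensor.tofile(f), and a lazy tensor must materialize itself when its raw buffer is requested. A sketch of the same lazy-forcing idea with a made-up class (not the library's LazyNumpyTensor):

# Made-up stand-in for the lazy-buffer pattern; not gguf's LazyNumpyTensor.
import numpy as np

class LazyDemo:
    def __init__(self, thunk):
        self._thunk = thunk  # deferred computation that yields an ndarray

    @property
    def data(self) -> memoryview:
        # Evaluation is delayed until the raw buffer is actually needed
        return self._thunk().data

t = LazyDemo(lambda: np.arange(4, dtype=np.int32))
with open("demo.bin", "wb") as f:
    f.write(t.data)  # 16 bytes, no intermediate tofile() call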
diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py
index e5251aef8c832..0734b9f25d2ac 100644
--- a/gguf-py/gguf/utility.py
+++ b/gguf-py/gguf/utility.py
@@ -5,6 +5,14 @@

 import os
 import json
+import time
+import logging
+
+import requests
+from urllib.parse import urlparse
+
+
+logger = logging.getLogger(__name__)


 def fill_templated_filename(filename: str, output_type: str | None) -> str:
@@ -75,6 +83,7 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st

 @dataclass
 class RemoteTensor:
+    name: str
     dtype: str
     shape: tuple[int, ...]
     offset_start: int
@@ -82,9 +91,31 @@ class RemoteTensor:
     size: int
     url: str

     def data(self) -> bytearray:
-        # TODO: handle request errors (maybe with limited retries?)
-        # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
-        data = bytearray(SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size))
+        data = None
+        MAX_RETRIES = 8
+        for i in range(MAX_RETRIES):
+            try:
+                # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
+                data = bytearray(
+                    SafetensorRemote.get_data_by_range(
+                        url=self.url, start=self.offset_start, size=self.size
+                    )
+                )
+            except (
+                requests.exceptions.ChunkedEncodingError,
+                requests.exceptions.ContentDecodingError,
+                requests.exceptions.ConnectionError,
+            ) as e:
+                if i == MAX_RETRIES - 1:
+                    raise RuntimeError(f"Failed to download tensor {self.name}") from e
+                logger.warning(f"Retry ({i + 1}/{MAX_RETRIES}) downloading tensor {self.name} because of {e}")
+                time.sleep(2 * i + 1)  # 1 3 5 7 9 11 13
+                continue
+            break  # success, stop retrying
+
+        if data is None:
+            raise RuntimeError(f"Failed to download tensor {self.name}")

         return data
@@ -169,7 +200,14 @@ def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]:
                     offset_start_relative, offset_end_relative = meta["data_offsets"]
                     size = offset_end_relative - offset_start_relative
                     offset_start = data_start_offset + offset_start_relative
-                    res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
+                    res[name] = RemoteTensor(
+                        name=name,
+                        dtype=dtype,
+                        shape=tuple(shape),
+                        offset_start=offset_start,
+                        size=size,
+                        url=url,
+                    )
                 except KeyError as e:
                     raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
@@ -217,8 +255,6 @@ def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
         Get raw byte data from a remote file by range.
         If size is not specified, it will read the entire file.
         """
-        import requests
-        from urllib.parse import urlparse

         parsed_url = urlparse(url)
         if not parsed_url.scheme or not parsed_url.netloc:
@@ -239,9 +275,6 @@ def check_file_exist(cls, url: str) -> bool:
         Check if a file exists at the given URL.
        Returns True if the file exists, False otherwise.
         """
-        import requests
-        from urllib.parse import urlparse
-
         parsed_url = urlparse(url)
         if not parsed_url.scheme or not parsed_url.netloc:
             raise ValueError(f"Invalid URL: {url}")
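The retry loop backs off linearly (2 * i + 1 seconds: 1, 3, 5, ..., 13), bounding the total wait at 49 seconds across the eight attempts. A hedged end-to-end usage sketch of the remote path (the repo name is only an example, and get_list_tensors_hf_model is assumed to exist alongside get_list_tensors in the same module):

# Example use of the remote-tensor API touched above; repo name is illustrative,
# and get_list_tensors_hf_model is assumed from gguf.utility's SafetensorRemote.
from gguf.utility import SafetensorRemote

tensors = SafetensorRemote.get_list_tensors_hf_model("HuggingFaceTB/SmolLM2-1.7B-Instruct")
name, tensor = next(iter(tensors.items()))
blob = tensor.data()  # downloads with up to 8 retries on transient network errors
print(name, tensor.dtype, tensor.shape, len(blob))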