convert : write tensors in parallel #12837
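Summary of the change set below: `convert_hf_to_gguf.py` gains a `-t`/`--threads` option (default 2); `GGUFWriter.write_tensors_to_file` dispatches tensor writes to a thread pool, with one file handle per worker so writers can seek and write at independent offsets; `LazyNumpyTensor` exposes a `data` property so lazy tensors can be written this way; and remote safetensors downloads retry transient failures with linear backoff.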

Status: Open. Wants to merge 4 commits into master.
13 changes: 10 additions & 3 deletions convert_hf_to_gguf.py
@@ -74,7 +74,8 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
use_temp_file: bool = False, eager: bool = False,
metadata_override: Path | None = None, model_name: str | None = None,
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None):
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
thread_count: int = 2):
if type(self) is Model:
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")

@@ -123,7 +124,8 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:

# Configure GGUF Writer
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard,
thread_count=thread_count)

@classmethod
def __init_subclass__(cls):
@@ -5525,6 +5527,10 @@ def parse_args() -> argparse.Namespace:
"--remote", action="store_true",
help="(Experimental) Read safetensors file remotely without downloading to disk. Config and tokenizer files will still be downloaded. To use this feature, you need to specify Hugging Face model repo name instead of a local directory. For example: 'HuggingFaceTB/SmolLM2-1.7B-Instruct'. Note: To access gated repo, set HF_TOKEN environment variable to your Hugging Face token.",
)
parser.add_argument(
"-t", "--threads", type=int, default=2,
help="Number of threads to use when writing the tensors. Make sure you have enough RAM for at least THREADS of the biggest tensors in the model when setting this. Defaults to 2.",
)

args = parser.parse_args()
if not args.print_supported_models and args.model is None:
@@ -5620,7 +5626,8 @@ def main() -> None:
split_max_tensors=args.split_max_tensors,
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split,
remote_hf_model_id=str(args.model) if args.remote else None)
remote_hf_model_id=str(args.model) if args.remote else None,
thread_count=args.threads)

if args.vocab_only:
logger.info("Exporting model vocab...")
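The `--threads` help text above encodes the trade-off: each worker materializes one tensor at a time, so peak extra RAM scales with the thread count and the sizes of the largest tensors. A back-of-the-envelope sketch of that bound (the helper and the sizes are illustrative assumptions, not part of the PR):

# Rough upper bound: with N writer threads, up to the N largest tensors
# may be held in memory at once (illustrative assumption).
def estimate_peak_write_ram(tensor_nbytes: list[int], threads: int = 2) -> int:
    return sum(sorted(tensor_nbytes, reverse=True)[:threads])

# A model with four ~1 GiB tensors and a hundred small ones:
sizes = [2**30] * 4 + [2**20] * 100
print(estimate_peak_write_ram(sizes, threads=2) / 2**30)  # 2.0 (GiB)

An invocation would then look something like `python convert_hf_to_gguf.py ./my-model -t 4`, budgeting RAM for roughly the four biggest tensors at once.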
138 changes: 117 additions & 21 deletions gguf-py/gguf/gguf_writer.py
@@ -5,13 +5,15 @@
import shutil
import struct
import tempfile
import threading
from dataclasses import dataclass
from enum import Enum, auto
from math import prod
from pathlib import Path
from io import BufferedWriter
from typing import IO, Any, Sequence, Mapping
from string import ascii_letters, digits
from concurrent.futures import FIRST_EXCEPTION, Future, ThreadPoolExecutor, wait

import numpy as np

@@ -60,8 +62,63 @@ class WriterState(Enum):
WEIGHTS = auto()


# To close files which were opened in thread-local context
# Necessary because ThreadPoolExecutor doesn't allow setting a custom finalizer
# ref: https://github.com/python/cpython/issues/89502
class _ThreadedOpenFiles:
files: dict[Path, BufferedWriter]

def __init__(self):
self.files = {}

def __del__(self):
for file in self.files.values():
file.close()

def __getitem__(self, key: Path, /) -> BufferedWriter:
if key not in self.files:
self.files[key] = open(key, "r+b")
return self.files[key]

@classmethod
def init_thread_local(cls, local_data):
local_data.open_files = _ThreadedOpenFiles()


# Exit quickly instead of waiting
class _InterruptibleThreadPoolExecutor(ThreadPoolExecutor):
def __exit__(self, exc_type, exc_val, exc_tb) -> bool | None:
del exc_type, exc_val, exc_tb
self.shutdown(wait=False, cancel_futures=True)
return False


@dataclass
class _ThreadedTensorWriteInfo:
filename: Path
offset: int
post_pad: int
tensor: np.ndarray
bar: Any | None # optional tqdm progress bar

def write_chunk(self, open_files: _ThreadedOpenFiles):
# This is called from a thread pool,
# and each thread should have its own file handle per output file
# so that they can have different seek locations.
f = open_files[self.filename]

f.seek(self.offset)
f.write(self.tensor.data)
if self.post_pad > 0:
f.write(bytes([0] * self.post_pad))
if self.bar is not None:
self.bar.update(self.tensor.nbytes)


class GGUFWriter:
fout: list[BufferedWriter] | None
filenames: list[Path] | None
thread_count: int
path: Path | None
temp_file: tempfile.SpooledTemporaryFile[bytes] | None
tensors: list[dict[str, TensorInfo]]
@@ -83,7 +140,8 @@ class GGUFWriter:

def __init__(
self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE,
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False,
thread_count: int = 2,
):
self.fout = None
self.path = Path(path) if path else None
@@ -98,6 +156,7 @@ def __init__(
self.split_max_size = split_max_size
self.dry_run = dry_run
self.small_first_shard = small_first_shard
self.thread_count = thread_count
logger.info("gguf: This GGUF file is for {0} Endian only".format(
"Big" if self.endianess == GGUFEndian.BIG else "Little",
))
@@ -173,6 +232,7 @@ def open_output_file(self, path: Path | None = None) -> None:

if self.path is not None:
filenames = self.print_plan()
self.filenames = filenames
self.fout = [open(filename, "wb") for filename in filenames]
self.state = WriterState.EMPTY

@@ -424,40 +484,76 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
self.write_ti_data_to_file()

assert self.fout is not None
assert self.filenames is not None

for fout in self.fout:
self.write_padding(fout, fout.tell())

if self.temp_file is None:
shard_bar = None
bar = None
# Initial file offsets before writing the tensor data
offsets: list[int] = [fout.tell() for fout in self.fout]

if progress:
# TODO: add back the shard bar to show which shard is being written when single-threaded
from tqdm import tqdm

total_bytes = sum(ti.nbytes for t in self.tensors for ti in t.values())

if len(self.fout) > 1:
shard_bar = tqdm(desc=f"Shard (0/{len(self.fout)})", total=None, unit="byte", unit_scale=True)
bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)

for i, (fout, tensors) in enumerate(zip(self.fout, self.tensors)):
if shard_bar is not None:
shard_bar.set_description(f"Shard ({i + 1}/{len(self.fout)})")
total = sum(ti.nbytes for ti in tensors.values())
shard_bar.reset(total=(total if total > 0 else None))

# relying on the fact that Python dicts preserve insertion order (since 3.7)
for ti in tensors.values():
assert ti.tensor is not None # can only iterate once over the tensors
assert ti.tensor.nbytes == ti.nbytes
ti.tensor.tofile(fout)
if shard_bar is not None:
shard_bar.update(ti.nbytes)
if bar is not None:
bar.update(ti.nbytes)
self.write_padding(fout, ti.nbytes)
ti.tensor = None
# Allow opening the files only once per worker
local_data = threading.local()

# Unit of work
def thread_write_tensor(tensor: _ThreadedTensorWriteInfo):
tensor.write_chunk(local_data.open_files)

with _InterruptibleThreadPoolExecutor(
max_workers=self.thread_count,
initializer=_ThreadedOpenFiles.init_thread_local,
initargs=(local_data,),
) as executor:

futures: list[Future] = []

# Fill the tensor queue with all the pending tensor writes
for i, (filename, tensors) in enumerate(zip(self.filenames, self.tensors)):
offset = offsets[i]

# relying on the fact that Python dicts preserve insertion order (since 3.7)
for ti in tensors.values():
assert ti.tensor is not None # can only iterate once over the tensors
assert ti.tensor.nbytes == ti.nbytes
start_offset = offset
nbytes = ti.tensor.nbytes
offset = self.ggml_pad(start_offset + nbytes, self.data_alignment)
padding = offset - (start_offset + nbytes)
futures.append(
executor.submit(
thread_write_tensor,
_ThreadedTensorWriteInfo(
filename=filename,
offset=start_offset,
post_pad=padding,
tensor=ti.tensor,
bar=bar,
),
)
)
ti.tensor = None # avoid keeping a reference to written tensors

# FIXME: there's still some weird behavior with KeyboardInterrupt
# not being able to interrupt a future mid-execution
done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
exc = None
if any(f for f in done
if not f.cancelled() and (exc := f.exception()) is not None):
raise RuntimeError("Error writing tensors") from exc
elif len(not_done) != 0:
raise RuntimeError("Not all tensors were written")

del local_data
else:
self.temp_file.seek(0)

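A pattern worth isolating from `write_tensors_to_file` above: each worker thread keeps its own handle per output file (a `threading.local()` populated by the pool initializer), so concurrent writers can seek to different offsets of the same file without racing on a shared file position. A standalone sketch of just that mechanism, with made-up paths and payloads:

import threading
from concurrent.futures import ThreadPoolExecutor

local = threading.local()

def init_worker(local_data):
    # Runs once in each pool thread: a private Path -> file-object cache.
    local_data.files = {}

def write_at(path, offset, payload):
    f = local.files.get(path)
    if f is None:
        # Each thread opens its own handle so seek positions don't collide.
        f = local.files[path] = open(path, "r+b")
    f.seek(offset)
    f.write(payload)

# Pre-size the file so the "r+b" writes land at fixed offsets.
with open("demo.bin", "wb") as f:
    f.truncate(16)

with ThreadPoolExecutor(max_workers=2, initializer=init_worker, initargs=(local,)) as ex:
    list(ex.map(lambda job: write_at("demo.bin", *job), [(0, b"A" * 8), (8, b"B" * 8)]))
# Unlike the PR's _ThreadedOpenFiles, this sketch leaves closing the handles to interpreter exit.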
5 changes: 5 additions & 0 deletions gguf-py/gguf/lazy.py
@@ -220,4 +220,9 @@ def tofile(self, *args, **kwargs):
eager = LazyNumpyTensor.to_eager(self)
return eager.tofile(*args, **kwargs)

@property
def data(self):
eager = LazyNumpyTensor.to_eager(self)
return eager.data

# TODO: __array_function__
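The new `data` property mirrors `numpy.ndarray.data`: touching it forces the lazy tensor to evaluate and exposes the raw buffer, which is what lets the threaded writer call `f.write(tensor.data)` on lazy tensors. A toy version of the same materialize-on-access idea (the class below is an illustration, not `LazyNumpyTensor`):

import numpy as np

class LazyArray:
    def __init__(self, make):
        self._make = make  # deferred constructor; only called on access

    @property
    def data(self) -> memoryview:
        # Touching .data forces evaluation, then hands back the underlying buffer.
        return self._make().data

arr = LazyArray(lambda: np.arange(4, dtype=np.float32))
print(bytes(arr.data).hex())  # raw bytes of [0., 1., 2., 3.]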
50 changes: 41 additions & 9 deletions gguf-py/gguf/utility.py
@@ -5,6 +5,14 @@

import os
import json
import time
import logging

import requests
from urllib.parse import urlparse


logger = logging.getLogger(__name__)


def fill_templated_filename(filename: str, output_type: str | None) -> str:
@@ -75,16 +83,38 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st

@dataclass
class RemoteTensor:
name: str
dtype: str
shape: tuple[int, ...]
offset_start: int
size: int
url: str

def data(self) -> bytearray:
# TODO: handle request errors (maybe with limited retries?)
# NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
data = bytearray(SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size))
data = None
MAX_RETRIES = 8
for i in range(MAX_RETRIES):
try:
# NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
data = bytearray(
SafetensorRemote.get_data_by_range(
url=self.url, start=self.offset_start, size=self.size
)
)
except (
requests.exceptions.ChunkedEncodingError,
requests.exceptions.ContentDecodingError,
requests.exceptions.ConnectionError,
) as e:
if i == MAX_RETRIES - 1:
raise RuntimeError(f"Failed to download tensor {self.name}") from e
logger.warning(f"Retry ({i + 1}/{MAX_RETRIES}) downloading tensor {self.name} because of {e}")
time.sleep(2 * i + 1) # 1 3 5 7 9 11 13
continue

if data is None:
raise RuntimeError(f"Failed to download tensor {self.name}")

return data


@@ -169,7 +199,14 @@ def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]:
offset_start_relative, offset_end_relative = meta["data_offsets"]
size = offset_end_relative - offset_start_relative
offset_start = data_start_offset + offset_start_relative
res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
res[name] = RemoteTensor(
name=name,
dtype=dtype,
shape=tuple(shape),
offset_start=offset_start,
size=size,
url=url,
)
except KeyError as e:
raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")

@@ -217,8 +254,6 @@ def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
Get raw byte data from a remote file by range.
If size is not specified, it will read the entire file.
"""
import requests
from urllib.parse import urlparse

parsed_url = urlparse(url)
if not parsed_url.scheme or not parsed_url.netloc:
@@ -239,9 +274,6 @@ def check_file_exist(cls, url: str) -> bool:
Check if a file exists at the given URL.
Returns True if the file exists, False otherwise.
"""
import requests
from urllib.parse import urlparse

parsed_url = urlparse(url)
if not parsed_url.scheme or not parsed_url.netloc:
raise ValueError(f"Invalid URL: {url}")
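The retry loop in `RemoteTensor.data` follows a common linear-backoff shape: catch transient network errors, sleep 1, 3, 5, ... seconds between attempts, and re-raise after the last one. The same shape reduced to a generic helper (the `fetch` callable and the exception type are assumptions for illustration):

import time

def fetch_with_retries(fetch, max_retries: int = 8):
    for i in range(max_retries):
        try:
            return fetch()
        except ConnectionError as e:
            if i == max_retries - 1:
                raise RuntimeError("download failed after retries") from e
            time.sleep(2 * i + 1)  # sleeps 1, 3, 5, ... seconds, like the loop above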