 import importlib
 import inspect
 import os
+from array import array
 from collections import OrderedDict
 from pathlib import Path
 from typing import List, Optional, Union
@@ -26,13 +27,16 @@
 from huggingface_hub.utils import EntryNotFoundError

 from ..utils import (
+    GGUF_FILE_EXTENSION,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
     WEIGHTS_INDEX_NAME,
     _add_variant,
     _get_model_file,
     deprecate,
     is_accelerate_available,
+    is_gguf_available,
+    is_torch_available,
     is_torch_version,
     logging,
 )
@@ -139,6 +143,8 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
         file_extension = os.path.basename(checkpoint_file).split(".")[-1]
         if file_extension == SAFETENSORS_FILE_EXTENSION:
             return safetensors.torch.load_file(checkpoint_file, device="cpu")
+        elif file_extension == GGUF_FILE_EXTENSION:
+            return load_gguf_checkpoint(checkpoint_file)
         else:
             weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
             return torch.load(
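The dispatch above keys purely off the file extension: `.gguf` checkpoints are now routed through `load_gguf_checkpoint`, while safetensors and pickle checkpoints keep their existing paths. A minimal sketch of the resulting behavior, with hypothetical file names and assuming this module path:

```python
# Sketch only: exercising the extension-based dispatch in load_state_dict.
from diffusers.models.model_loading_utils import load_state_dict

sd = load_state_dict("diffusion_pytorch_model.safetensors")  # safetensors path
sd = load_state_dict("flux1-dev-Q4_0.gguf")                  # new GGUF path
sd = load_state_dict("diffusion_pytorch_model.bin")          # torch.load fallback
```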
@@ -211,13 +217,14 @@ def load_model_dict_into_meta(
             set_module_kwargs["dtype"] = dtype

         # bnb params are flattened.
+        # gguf quants have a different shape based on the type of quantization applied
         if empty_state_dict[param_name].shape != param.shape:
             if (
                 is_quantized
                 and hf_quantizer.pre_quantized
                 and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
             ):
-                hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
+                hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
             else:
                 model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
                 raise ValueError(
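Note that `check_quantized_param_shape` now receives the tensors themselves rather than bare shapes, so a GGUF-aware quantizer can read `quant_type` off a `GGUFParameter` and infer the unquantized shape from the packed byte layout. A rough sketch of such a check, assuming gguf's `GGML_QUANT_SIZES` table (illustrative, not the actual diffusers quantizer):

```python
import gguf

# Sketch: validate a packed GGUF tensor against the expected parameter shape.
def check_quantized_param_shape(param_name, current_param, loaded_param):
    # GGML packs `block_size` weights into `type_size` bytes per block.
    block_size, type_size = gguf.GGML_QUANT_SIZES[loaded_param.quant_type]

    # Only the last dimension is packed, so unpack it before comparing.
    inferred = (*loaded_param.shape[:-1], loaded_param.shape[-1] // type_size * block_size)
    if tuple(inferred) != tuple(current_param.shape):
        raise ValueError(
            f"{param_name}: expected shape {tuple(current_param.shape)}, "
            f"but the GGUF payload implies {tuple(inferred)}."
        )
    return True
```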
@@ -396,3 +403,78 @@ def _fetch_index_file_legacy(
         index_file = None

     return index_file
+
+
+def _gguf_parse_value(_value, data_type):
+    if not isinstance(data_type, list):
+        data_type = [data_type]
+    if len(data_type) == 1:
+        data_type = data_type[0]
+        array_data_type = None
+    else:
+        if data_type[0] != 9:
+            raise ValueError("Received multiple types, therefore expected the first type to indicate an array.")
+        data_type, array_data_type = data_type
+
+    if data_type in [0, 1, 2, 3, 4, 5, 10, 11]:
+        _value = int(_value[0])
+    elif data_type in [6, 12]:
+        _value = float(_value[0])
+    elif data_type in [7]:
+        _value = bool(_value[0])
+    elif data_type in [8]:
+        _value = array("B", list(_value)).tobytes().decode()
+    elif data_type in [9]:
+        _value = _gguf_parse_value(_value, array_data_type)
+    return _value
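The integer codes here follow gguf's `GGUFValueType` enum: 0-5 and 10-11 are the integer types, 6 and 12 the float types, 7 is bool, 8 is string (stored as raw UTF-8 bytes), and 9 is array, whose element type arrives as the second entry of the type list. Two illustrative calls:

```python
# Type code 4 (UINT32): scalar values arrive as one-element sequences.
assert _gguf_parse_value([1024], 4) == 1024

# Type code 8 (STRING): raw bytes are decoded; 72, 105 is UTF-8 for "Hi".
assert _gguf_parse_value([72, 105], 8) == "Hi"
```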
+
+
+def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
+    """
+    Load a GGUF file and return a dictionary of parsed parameters containing tensors, the parsed tokenizer and config
+    attributes.
+
+    Args:
+        gguf_checkpoint_path (`str`):
+            The path to the GGUF file to load.
+        return_tensors (`bool`, defaults to `False`):
+            Whether to read the tensors from the file and return them. Not doing so is faster and only loads the
+            metadata in memory.
+    """
+
+    if is_gguf_available() and is_torch_available():
+        import gguf
+        from gguf import GGUFReader
+
+        from ..quantizers.gguf.utils import SUPPORTED_GGUF_QUANT_TYPES, GGUFParameter
+    else:
+        logger.error(
+            "Loading a GGUF checkpoint in PyTorch requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
+            "https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
+        )
+        raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")
+
+    reader = GGUFReader(gguf_checkpoint_path)
+
+    parsed_parameters = {}
+    for tensor in reader.tensors:
+        name = tensor.name
+        quant_type = tensor.tensor_type
+
+        # if the tensor is a torch supported dtype do not use GGUFParameter
+        is_gguf_quant = quant_type not in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16]
+        if is_gguf_quant and quant_type not in SUPPORTED_GGUF_QUANT_TYPES:
+            _supported_quants_str = "\n".join([str(type) for type in SUPPORTED_GGUF_QUANT_TYPES])
+            raise ValueError(
+                (
+                    f"{name} has a quantization type: {str(quant_type)} which is unsupported."
+                    "\n\nCurrently the following quantization types are supported: \n\n"
+                    f"{_supported_quants_str}"
+                    "\n\nTo request support for this quantization type please open an issue here: https://github.com/huggingface/diffusers"
+                )
+            )
+
+        weights = torch.from_numpy(tensor.data.copy())
+        parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights
+
+    return parsed_parameters
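End to end, a quick way to exercise the new loader (the checkpoint name is hypothetical; requires torch and gguf>=0.10.0):

```python
from diffusers.models.model_loading_utils import load_gguf_checkpoint

state_dict = load_gguf_checkpoint("flux1-dev-Q4_0.gguf")
for name, param in list(state_dict.items())[:5]:
    # F16/F32 tensors come back as plain torch tensors; quantized tensors
    # are GGUFParameter wrappers carrying their quant_type.
    print(name, type(param).__name__, tuple(param.shape))
```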