Skip to content

Commit ee52225

Browse files
authored
convert-hf : support direct Q8_0 conversion (#7234)
* convert-hf : support q8_0 conversion * convert-hf : add missing ftype This was messing with the checksums otherwise. * convert-hf : add missing ftype to Baichuan and Xverse I didn't notice these on my first pass.
1 parent 614d3b9 commit ee52225

File tree

5 files changed

+169
-58
lines changed

5 files changed

+169
-58
lines changed

convert-hf-to-gguf.py

+28-44
Original file line numberDiff line numberDiff line change
@@ -240,23 +240,6 @@ def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: i
240240
return False
241241

242242
def write_tensors(self):
243-
# same as ggml_compute_fp32_to_bf16 in ggml-impl.h
244-
def np_fp32_to_bf16(n: np.ndarray):
245-
# force nan to quiet
246-
n = np.where((n & 0x7fffffff) > 0x7f800000, (n & 0xffff0000) | (64 << 16), n)
247-
# flush subnormals to zero
248-
n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n)
249-
# round to nearest even
250-
n = (n + (0x7fff + ((n >> 16) & 1))) >> 16
251-
return n.astype(np.int16)
252-
253-
# Doing this row-wise is much, much faster than element-wise, hence the signature
254-
v_fp32_to_bf16 = np.vectorize(np_fp32_to_bf16, otypes=[np.int16], signature="(n)->(n)")
255-
if self.lazy:
256-
# TODO: find a way to implicitly wrap np.vectorize functions
257-
# NOTE: the type is changed to reflect otypes passed to np.vectorize above
258-
v_fp32_to_bf16 = gguf.LazyNumpyTensor._wrap_fn(v_fp32_to_bf16, meta_noop=np.int16)
259-
260243
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
261244

262245
for name, data_torch in self.get_tensors():
@@ -309,27 +292,31 @@ def np_fp32_to_bf16(n: np.ndarray):
309292
))
310293

311294
if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
312-
if self.ftype == gguf.LlamaFileType.MOSTLY_F16:
295+
if self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
296+
data = gguf.quantize_bf16(data)
297+
assert data.dtype == np.int16
298+
data_qtype = gguf.GGMLQuantizationType.BF16
299+
300+
elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data):
301+
data = gguf.quantize_q8_0(data)
302+
assert data.dtype == np.uint8
303+
data_qtype = gguf.GGMLQuantizationType.Q8_0
304+
305+
else: # default to float16 for quantized tensors
313306
if data_dtype != np.float16:
314307
data = data.astype(np.float16)
315308
data_qtype = gguf.GGMLQuantizationType.F16
316309

317-
elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
318-
if data_dtype != np.float32:
319-
data = data.astype(np.float32)
320-
data = v_fp32_to_bf16(data.view(np.int32))
321-
assert data.dtype == np.int16
322-
data_qtype = gguf.GGMLQuantizationType.BF16
323-
324-
else: # by default, convert to float32
310+
if data_qtype is None: # by default, convert to float32
325311
if data_dtype != np.float32:
326312
data = data.astype(np.float32)
327313
data_qtype = gguf.GGMLQuantizationType.F32
328314

329-
assert data_qtype is not None
330-
315+
block_size, type_size = gguf.GGML_QUANT_SIZES[data_qtype]
331316
# reverse shape to make it similar to the internal ggml dimension order
332-
shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
317+
shape_str = f"""{{{', '.join(str(n) for n in reversed(
318+
(*data.shape[:-1], data.shape[-1] * data.dtype.itemsize // type_size * block_size))
319+
)}}}"""
333320

334321
# n_dims is implicit in the shape
335322
logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
@@ -859,6 +846,7 @@ def set_gguf_parameters(self):
859846
self.gguf_writer.add_head_count(head_count)
860847
self.gguf_writer.add_head_count_kv(head_count_kv)
861848
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
849+
self.gguf_writer.add_file_type(self.ftype)
862850

863851
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
864852
if self.hparams["rope_scaling"].get("type") == "linear":
@@ -981,6 +969,7 @@ def set_gguf_parameters(self):
981969
self.gguf_writer.add_head_count(head_count)
982970
self.gguf_writer.add_head_count_kv(head_count_kv)
983971
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
972+
self.gguf_writer.add_file_type(self.ftype)
984973

985974
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
986975
if self.hparams["rope_scaling"].get("type") == "linear":
@@ -1215,6 +1204,7 @@ def set_gguf_parameters(self):
12151204
self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"])
12161205
self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
12171206
self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_eps", "norm_eps"]))
1207+
self.gguf_writer.add_file_type(self.ftype)
12181208

12191209
_q_norms: list[dict[str, Tensor]] | None = None
12201210
_k_norms: list[dict[str, Tensor]] | None = None
@@ -1591,6 +1581,7 @@ def set_gguf_parameters(self):
15911581
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
15921582
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
15931583
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
1584+
self.gguf_writer.add_file_type(self.ftype)
15941585

15951586

15961587
@Model.register("Qwen2ForCausalLM")
@@ -1828,6 +1819,7 @@ def set_gguf_parameters(self):
18281819
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
18291820
self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong
18301821
self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
1822+
self.gguf_writer.add_file_type(self.ftype)
18311823

18321824
def shuffle_attn_q_weight(self, data_torch):
18331825
assert data_torch.size() == (5120, 5120)
@@ -2007,6 +1999,7 @@ def set_gguf_parameters(self):
20071999
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
20082000
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
20092001
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
2002+
self.gguf_writer.add_file_type(self.ftype)
20102003

20112004
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
20122005
num_heads = self.hparams["num_attention_heads"]
@@ -2415,25 +2408,15 @@ class LazyTorchTensor(gguf.LazyBase):
24152408
def numpy(self) -> gguf.LazyNumpyTensor:
24162409
dtype = self._dtype_map[self.dtype]
24172410
return gguf.LazyNumpyTensor(
2418-
meta=np.lib.stride_tricks.as_strided(np.zeros(1, dtype), self.shape, (0 for _ in self.shape)),
2411+
meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
24192412
lazy=self._lazy,
24202413
args=(self,),
24212414
func=(lambda s: s[0].numpy())
24222415
)
24232416

24242417
@classmethod
2425-
def eager_to_meta(cls, t: Tensor) -> Tensor:
2426-
if t.is_meta:
2427-
return t
2428-
return t.detach().to("meta")
2429-
2430-
@classmethod
2431-
def meta_with_dtype(cls, m: Tensor, dtype: torch.dtype) -> Tensor:
2432-
m = m.detach()
2433-
if not m.is_meta:
2434-
m = m.to("meta")
2435-
m.dtype = dtype
2436-
return m
2418+
def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: torch.Size) -> Tensor:
2419+
return torch.empty(size=shape, dtype=dtype, device="meta")
24372420

24382421
@classmethod
24392422
def __torch_function__(cls, func, types, args=(), kwargs=None):
@@ -2464,8 +2447,8 @@ def parse_args() -> argparse.Namespace:
24642447
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
24652448
)
24662449
parser.add_argument(
2467-
"--outtype", type=str, choices=["f32", "f16", "bf16", "auto"], default="f16",
2468-
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
2450+
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
2451+
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
24692452
)
24702453
parser.add_argument(
24712454
"--bigendian", action="store_true",
@@ -2523,6 +2506,7 @@ def main() -> None:
25232506
"f32": gguf.LlamaFileType.ALL_F32,
25242507
"f16": gguf.LlamaFileType.MOSTLY_F16,
25252508
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
2509+
"q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
25262510
"auto": gguf.LlamaFileType.GUESSED,
25272511
}
25282512

gguf-py/gguf/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22
from .lazy import *
33
from .gguf_reader import *
44
from .gguf_writer import *
5+
from .quants import *
56
from .tensor_mapping import *
67
from .vocab import *

gguf-py/gguf/gguf_writer.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import numpy as np
1414

1515
from .constants import (
16+
GGML_QUANT_SIZES,
1617
GGUF_DEFAULT_ALIGNMENT,
1718
GGUF_MAGIC,
1819
GGUF_VERSION,
@@ -195,7 +196,7 @@ def ggml_pad(x: int, n: int) -> int:
195196
return ((x + n - 1) // n) * n
196197

197198
def add_tensor_info(
198-
self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype[np.float16] | np.dtype[np.float32],
199+
self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype,
199200
tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None,
200201
) -> None:
201202
if self.state is not WriterState.EMPTY:
@@ -208,10 +209,6 @@ def add_tensor_info(
208209
encoded_name = name.encode("utf-8")
209210
self.ti_data += self._pack("Q", len(encoded_name))
210211
self.ti_data += encoded_name
211-
n_dims = len(tensor_shape)
212-
self.ti_data += self._pack("I", n_dims)
213-
for i in range(n_dims):
214-
self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
215212
if raw_dtype is None:
216213
if tensor_dtype == np.float16:
217214
dtype = GGMLQuantizationType.F16
@@ -231,6 +228,15 @@ def add_tensor_info(
231228
raise ValueError("Only F16, F32, F64, I8, I16, I32, I64 tensors are supported for now")
232229
else:
233230
dtype = raw_dtype
231+
if tensor_dtype == np.uint8:
232+
block_size, type_size = GGML_QUANT_SIZES[raw_dtype]
233+
if tensor_shape[-1] % type_size != 0:
234+
raise ValueError(f"Quantized tensor row size ({tensor_shape[-1]}) is not a multiple of {dtype.name} type size ({type_size})")
235+
tensor_shape = tuple(tensor_shape[:-1]) + (tensor_shape[-1] // type_size * block_size,)
236+
n_dims = len(tensor_shape)
237+
self.ti_data += self._pack("I", n_dims)
238+
for i in range(n_dims):
239+
self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
234240
self.ti_data += self._pack("I", dtype)
235241
self.ti_data += self._pack("Q", self.offset_tensor)
236242
self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)

gguf-py/gguf/lazy.py

+20-9
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from collections import deque
77

88
import numpy as np
9+
from numpy._typing import _Shape
910
from numpy.typing import DTypeLike
1011

1112

@@ -110,7 +111,7 @@ def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
110111
return o
111112

112113
@classmethod
113-
def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike = False) -> Callable[[Any], Any]:
114+
def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] = False) -> Callable[[Any], Any]:
114115
def wrapped_fn(*args, **kwargs):
115116
if kwargs is None:
116117
kwargs = {}
@@ -130,9 +131,14 @@ def wrapped_fn(*args, **kwargs):
130131
res = args[0]
131132
assert isinstance(res, cls)
132133
res = res._meta
133-
# allow operations to override the dtype
134+
# allow operations to override the dtype and shape
134135
if meta_noop is not True:
135-
res = cls.meta_with_dtype(res, meta_noop)
136+
if isinstance(meta_noop, tuple):
137+
dtype, shape = meta_noop
138+
assert callable(shape)
139+
res = cls.meta_with_dtype_and_shape(dtype, shape(res.shape))
140+
else:
141+
res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
136142

137143
if isinstance(res, cls._tensor_type):
138144
def collect_replace(t: LazyBase):
@@ -168,7 +174,12 @@ def already_eager_to_eager(_t: LazyBase) -> Any:
168174
while _t._data is None:
169175
lt = _t._lazy.popleft()
170176
if lt._data is not None:
171-
raise ValueError(f"{lt} did not belong in the lazy queue")
177+
# Lazy tensor did not belong in the lazy queue.
178+
# Weirdly only happens with Bloom models...
179+
# likely because tensors aren't unique in the queue.
180+
# The final output is still the same as in eager mode,
181+
# so it's safe to ignore this.
182+
continue
172183
assert lt._func is not None
173184
lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
174185
lt._data = lt._func(lt._args)
@@ -183,12 +194,12 @@ def already_eager_to_eager(_t: LazyBase) -> Any:
183194

184195
@classmethod
185196
def eager_to_meta(cls, t: Any) -> Any:
186-
return cls.meta_with_dtype(t, t.dtype)
197+
return cls.meta_with_dtype_and_shape(t.dtype, t.shape)
187198

188199
# must be overridden, meta tensor init is backend-specific
189200
@classmethod
190201
@abstractmethod
191-
def meta_with_dtype(cls, m: Any, dtype: Any) -> Any: pass
202+
def meta_with_dtype_and_shape(cls, dtype: Any, shape: Any) -> Any: pass
192203

193204
@classmethod
194205
def from_eager(cls, t: Any) -> Any:
@@ -205,15 +216,15 @@ class LazyNumpyTensor(LazyBase):
205216
_tensor_type = np.ndarray
206217

207218
@classmethod
208-
def meta_with_dtype(cls, m: np.ndarray[Any, Any], dtype: DTypeLike) -> np.ndarray[Any, Any]:
219+
def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: _Shape) -> np.ndarray[Any, Any]:
209220
# The initial idea was to use np.nan as the fill value,
210221
# but non-float types like np.int16 can't use that.
211222
# So zero it is.
212223
cheat = np.zeros(1, dtype)
213-
return np.lib.stride_tricks.as_strided(cheat, m.shape, (0 for _ in m.shape))
224+
return np.lib.stride_tricks.as_strided(cheat, shape, (0 for _ in shape))
214225

215226
def astype(self, dtype, *args, **kwargs):
216-
meta = type(self).meta_with_dtype(self._meta, dtype)
227+
meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
217228
full_args = (self, dtype,) + args
218229
# very important to pass the shared _lazy deque, or else there's an infinite loop somewhere.
219230
return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs)))

0 commit comments

Comments
 (0)