Skip to content

Commit 17c3c6d

Browse files
committed
mod: add gguf.PoolingType.LAST and, by default, detect the pooling type and write it to the GGUF file. Log the pooling type during conversion and load the model's default pooling type.
1 parent b8874fe commit 17c3c6d

File tree

3 files changed

+29
-24
lines changed

3 files changed

+29
-24
lines changed

convert_hf_to_gguf.py

+27-24
Original file line number | Diff line number | Diff line change
@@ -260,6 +260,33 @@ def set_gguf_parameters(self):
260260
self.gguf_writer.add_file_type(self.ftype)
261261
logger.info(f"gguf: file type = {self.ftype}")
262262

263+
# get pooling path
264+
pooling_path = None
265+
self.pooling_type = gguf.PoolingType.NONE
266+
module_path = self.dir_model / "modules.json"
267+
if module_path.is_file():
268+
with open(module_path, encoding="utf-8") as f:
269+
modules = json.load(f)
270+
for mod in modules:
271+
if mod["type"] == "sentence_transformers.models.Pooling":
272+
pooling_path = mod["path"]
273+
break
274+
275+
# get pooling type
276+
if pooling_path is not None:
277+
with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
278+
pooling = json.load(f)
279+
if pooling["pooling_mode_mean_tokens"]:
280+
self.pooling_type = gguf.PoolingType.MEAN
281+
elif pooling["pooling_mode_cls_token"]:
282+
self.pooling_type = gguf.PoolingType.CLS
283+
elif pooling["pooling_mode_lasttoken"]:
284+
self.pooling_type = gguf.PoolingType.LAST
285+
else:
286+
logger.warning("Only [MEAN|CLS|LAST] pooling types supported, default NONE")
287+
self.gguf_writer.add_pooling_type(self.pooling_type)
288+
logger.info(f"gguf: pooling type = {self.pooling_type}")
289+
263290
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
264291
del bid # unused
265292

@@ -2210,7 +2237,6 @@ def set_gguf_parameters(self):
22102237
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
22112238
self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])
22122239

2213-
22142240
@Model.register("Qwen2VLForConditionalGeneration")
22152241
class Qwen2VLModel(Model):
22162242
model_arch = gguf.MODEL_ARCH.QWEN2VL
@@ -2957,29 +2983,6 @@ def set_gguf_parameters(self):
29572983
super().set_gguf_parameters()
29582984
self.gguf_writer.add_causal_attention(False)
29592985

2960-
# get pooling path
2961-
pooling_path = None
2962-
module_path = self.dir_model / "modules.json"
2963-
if module_path.is_file():
2964-
with open(module_path, encoding="utf-8") as f:
2965-
modules = json.load(f)
2966-
for mod in modules:
2967-
if mod["type"] == "sentence_transformers.models.Pooling":
2968-
pooling_path = mod["path"]
2969-
break
2970-
2971-
# get pooling type
2972-
if pooling_path is not None:
2973-
with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
2974-
pooling = json.load(f)
2975-
if pooling["pooling_mode_mean_tokens"]:
2976-
pooling_type = gguf.PoolingType.MEAN
2977-
elif pooling["pooling_mode_cls_token"]:
2978-
pooling_type = gguf.PoolingType.CLS
2979-
else:
2980-
raise NotImplementedError("Only MEAN and CLS pooling types supported")
2981-
self.gguf_writer.add_pooling_type(pooling_type)
2982-
29832986
def set_vocab(self):
29842987
tokens, toktypes, tokpre = self.get_vocab_base()
29852988
self.vocab_size = len(tokens)

gguf-py/gguf/constants.py

+1
Original file line number | Diff line number | Diff line change
@@ -1616,6 +1616,7 @@ class PoolingType(IntEnum):
16161616
NONE = 0
16171617
MEAN = 1
16181618
CLS = 2
1619+
LAST = 3
16191620

16201621

16211622
class GGMLQuantizationType(IntEnum):

src/llama-model.cpp

+1
Original file line number | Diff line number | Diff line change
@@ -412,6 +412,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
412412
ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer);
413413
ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false);
414414
ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
415+
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
415416

416417
if (arch == LLM_ARCH_WAVTOKENIZER_DEC) {
417418
ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features);

0 commit comments

Comments
 (0)