
Commit b7dcc00

[Model] Remove hardcoded image token ids from Pixtral (#11582)
Signed-off-by: Roger Wang <[email protected]>
1 parent d34be24 commit b7dcc00

File tree

1 file changed: +13 −14 lines

vllm/model_executor/models/pixtral.py

Lines changed: 13 additions & 14 deletions
@@ -45,13 +45,6 @@
 except ImportError:
     USE_XFORMERS_OPS = False
 
-# These token ids cannot be retrieved from model config
-# so we hardcode them here.
-PIXTRAL_12B_IMAGE_BREAK_ID = 12
-PIXTRAL_12B_IMAGE_END_ID = 13
-PIXTRAL_LARGE_IMAGE_BREAK_ID = 14
-PIXTRAL_LARGE_IMAGE_END_ID = 15
-
 
 def get_max_pixtral_image_tokens(ctx: InputContext):
     tokenizer = cached_get_tokenizer(
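
For reference, the removed constants paired up as follows (a plain-Python sketch; under the new scheme the same values are read from each checkpoint's params.json via VisionEncoderArgs, so variants with different special-token ids need no code change):

# Token-id pairs that were previously baked into the file.
HARDCODED_IDS = {
    "pixtral-12b":   {"image_break": 12, "image_end": 13},
    "pixtral-large": {"image_break": 14, "image_end": 15},
}
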
@@ -201,6 +194,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             if key in dataclass_fields
         }
 
+        if not ("image_break_token_id" in vision_args
+                and "image_end_token_id" in vision_args):
+            raise ValueError(
+                "'image_break_token_id' and 'image_end_token_id' not found "
+                "in the vision_encoder arguments. Please download the latest "
+                "version of 'params.json' from the model repository.")
+
         self.vision_args = VisionEncoderArgs(**vision_args)
 
         # init MistralForCausalLM
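
A minimal, self-contained sketch of what the new guard checks, where the dicts stand in for the parsed "vision_encoder" section of params.json and all id values are hypothetical:

# `stale` mimics a params.json that predates the new fields; `fresh` has them.
stale = {"image_token_id": 10}
fresh = {"image_token_id": 10, "image_break_token_id": 12,
         "image_end_token_id": 13}
for vision_args in (stale, fresh):
    print("image_break_token_id" in vision_args
          and "image_end_token_id" in vision_args)  # False, then True

With an up-to-date checkpoint the constructor proceeds; with a stale params.json it now fails fast instead of silently falling back to wrong hardcoded ids.
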
@@ -240,9 +240,8 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
 
         # NOTE: Image embeddings are split into separate tensors for each image
         # by the indices of `[IMG_END]` token.
-        image_end_condition = (image_tokens == PIXTRAL_12B_IMAGE_END_ID) | (
-            image_tokens == PIXTRAL_LARGE_IMAGE_END_ID)
-        split_indices = torch.where(image_end_condition)[0] + 1
+        image_end_mask = image_tokens == self.vision_args.image_end_token_id
+        split_indices = torch.where(image_end_mask)[0] + 1
         if len(split_indices) <= 1:
             # Do not split, return as tensor of shape [1, fs, hs]
             return image_embeds.unsqueeze(0)
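
To illustrate the split, here is a toy example with a made-up image_end_token_id of 13 and hs = 4; the splitting call itself lies outside this hunk, so torch.tensor_split below is purely illustrative:

import torch

image_end_token_id = 13                       # assumed value for illustration
image_tokens = torch.tensor([10, 10, 13, 10, 13])  # two images end at 2 and 4
image_embeds = torch.randn(5, 4)              # one embedding row per token

image_end_mask = image_tokens == image_end_token_id
split_indices = torch.where(image_end_mask)[0] + 1  # tensor([3, 5])
if len(split_indices) <= 1:
    result = [image_embeds.unsqueeze(0)]
else:
    result = torch.tensor_split(image_embeds, split_indices[:-1])
print([t.shape for t in result])  # [torch.Size([3, 4]), torch.Size([2, 4])]
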
@@ -265,10 +264,8 @@ def get_input_embeddings(
         inputs_embeds = merge_multimodal_embeddings(
             input_ids, inputs_embeds, multimodal_embeddings, [
                 self.vision_args.image_token_id,
-                PIXTRAL_12B_IMAGE_END_ID,
-                PIXTRAL_12B_IMAGE_BREAK_ID,
-                PIXTRAL_LARGE_IMAGE_BREAK_ID,
-                PIXTRAL_LARGE_IMAGE_END_ID,
+                self.vision_args.image_break_token_id,
+                self.vision_args.image_end_token_id,
             ])
         return inputs_embeds
 
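The merge itself is unchanged; only the placeholder id list now comes from self.vision_args. A toy sketch of the idea (not vLLM's merge_multimodal_embeddings; the token ids are hypothetical):

import torch

def merge_sketch(input_ids, inputs_embeds, mm_embeds, placeholder_ids):
    # Overwrite every placeholder position, in order, with a row of the
    # stacked multimodal embeddings.
    mask = torch.isin(input_ids, torch.tensor(placeholder_ids))
    out = inputs_embeds.clone()
    out[mask] = mm_embeds.to(out.dtype)
    return out

# Hypothetical ids: 10 = [IMG], 12 = [IMG_BREAK], 13 = [IMG_END].
input_ids = torch.tensor([1, 10, 12, 10, 13, 2])
merged = merge_sketch(input_ids, torch.zeros(6, 4), torch.randn(4, 4),
                      [10, 12, 13])
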
@@ -409,6 +406,8 @@ class VisionEncoderArgs:
     num_attention_heads: int
     rope_theta: float  # for rope-2D
     image_token_id: int
+    image_break_token_id: int
+    image_end_token_id: int
     adapter_bias: bool = True
 
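With the two new fields, constructing the dataclass from a checkpoint config looks roughly like this sketch (field set abridged, values hypothetical), mirroring the dataclass_fields filtering seen in the __init__ hunk above:

import dataclasses

@dataclasses.dataclass
class VisionEncoderArgs:  # abridged to the fields relevant to this commit
    image_token_id: int
    image_break_token_id: int
    image_end_token_id: int
    adapter_bias: bool = True

# Stand-in for config.vision_config.to_dict(); values are hypothetical.
raw = {"image_token_id": 10, "image_break_token_id": 12,
       "image_end_token_id": 13, "some_unrelated_key": "dropped"}
field_names = {f.name for f in dataclasses.fields(VisionEncoderArgs)}
args = VisionEncoderArgs(**{k: v for k, v in raw.items() if k in field_names})
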