45
45
except ImportError :
46
46
USE_XFORMERS_OPS = False
47
47
48
- # These token ids cannot be retrieved from model config
49
- # so we hardcode them here.
50
- PIXTRAL_12B_IMAGE_BREAK_ID = 12
51
- PIXTRAL_12B_IMAGE_END_ID = 13
52
- PIXTRAL_LARGE_IMAGE_BREAK_ID = 14
53
- PIXTRAL_LARGE_IMAGE_END_ID = 15
54
-
55
48
56
49
def get_max_pixtral_image_tokens (ctx : InputContext ):
57
50
tokenizer = cached_get_tokenizer (
@@ -201,6 +194,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
201
194
if key in dataclass_fields
202
195
}
203
196
197
+ if not ("image_break_token_id" in vision_args
198
+ and "image_end_token_id" in vision_args ):
199
+ raise ValueError (
200
+ "'image_break_token_id' and 'image_end_token_id' not found "
201
+ "in the vision_encoder arguments. Please download the latest "
202
+ "version of 'params.json' from the model repository." )
203
+
204
204
self .vision_args = VisionEncoderArgs (** vision_args )
205
205
206
206
# init MistralForCausalLM
@@ -240,9 +240,8 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
240
240
241
241
# NOTE: Image embeddings are split into separate tensors for each image
242
242
# right after each `[IMG_END]` token (split index = token index + 1).
243
- image_end_condition = (image_tokens == PIXTRAL_12B_IMAGE_END_ID ) | (
244
- image_tokens == PIXTRAL_LARGE_IMAGE_END_ID )
245
- split_indices = torch .where (image_end_condition )[0 ] + 1
243
+ image_end_mask = image_tokens == self .vision_args .image_end_token_id
244
+ split_indices = torch .where (image_end_mask )[0 ] + 1
246
245
if len (split_indices ) <= 1 :
247
246
# Do not split, return as tensor of shape [1, fs, hs]
248
247
return image_embeds .unsqueeze (0 )
@@ -265,10 +264,8 @@ def get_input_embeddings(
265
264
inputs_embeds = merge_multimodal_embeddings (
266
265
input_ids , inputs_embeds , multimodal_embeddings , [
267
266
self .vision_args .image_token_id ,
268
- PIXTRAL_12B_IMAGE_END_ID ,
269
- PIXTRAL_12B_IMAGE_BREAK_ID ,
270
- PIXTRAL_LARGE_IMAGE_BREAK_ID ,
271
- PIXTRAL_LARGE_IMAGE_END_ID ,
267
+ self .vision_args .image_break_token_id ,
268
+ self .vision_args .image_end_token_id ,
272
269
])
273
270
return inputs_embeds
274
271
@@ -409,6 +406,8 @@ class VisionEncoderArgs:
409
406
num_attention_heads : int
410
407
rope_theta : float # for rope-2D
411
408
image_token_id : int
409
+ image_break_token_id : int
410
+ image_end_token_id : int
412
411
adapter_bias : bool = True
413
412
414
413
0 commit comments