Commit a245011

Add InternVL (2.5 MPO) (#35968)
* initial commit
* add convert internvl
* add first end-to-end working internvl
* nit prompt and image proc
* add working chat template
* add conversion llama-based models
* add tests
* pass all tests
* fix isort
* fix modular after main merge
* add video processing for internvl
* add support for interlaced images and videos
* Remove processing and config from modular, add more tests
* add llama model tests
* Modify processor for compatibility with refactored got ocr image processor
* add comments in processor
* Add docs and nits
* change video processing to use custom sample_indices_fn
* rebase and fix tests
* add processor tests
* Add changes Raushan review
* Use the new attention interface for the vision model
* nits
* add support for custom video_load_backend
* remove mention to InternVLTokenizer
* refactor vision model to simplify logic
* refactor processor for better readibility
* fix copies
* fix require av processor test
* refactor internVL vision
* Update processor and fix processing tests
* fix docstring
* update convert_weights for internvl3
* change image processor to fast by default
* remove do_center_crop=True in convert_weights
* force use_cache to True
* push_to_hub before reloading
* fix internVLVision for larger models
* update convert weight for qk norm
* fix convert_weights
* fix eos_token_id in convert
* update docs and integration tests
* make modifs after review
* fix wrong k_norm and reduce modular
* change image_token_index to image_token_id
* change checkpoint to OpenGVLab org
* last nits
* explicitely del self.num_key_value_groups
* add extra special tokens
1 parent b0c6ff5 commit a245011

20 files changed: +4447 -5 lines

docs/source/en/_toctree.yml (+2)

@@ -953,6 +953,8 @@
   title: InstructBLIP
 - local: model_doc/instructblipvideo
   title: InstructBlipVideo
+- local: model_doc/internvl
+  title: InternVL
 - local: model_doc/janus
   title: Janus
 - local: model_doc/kosmos-2

docs/source/en/model_doc/internvl.md (+349): large diff, not rendered by default.
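
Since the new model doc is not rendered here, the following is a minimal usage sketch of the classes this commit adds. The checkpoint name comes from the config docstrings further down; the image URL, prompt, and exact chat-template fields are illustrative assumptions rather than a copy of the doc page.

```python
import torch
from transformers import AutoProcessor, InternVLForConditionalGeneration

model_id = "OpenGVLab/InternVL3-1B-hf"  # checkpoint referenced in the config docstrings below
processor = AutoProcessor.from_pretrained(model_id)
model = InternVLForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

# Illustrative conversation; the commit message notes a working chat template.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
            {"type": "text", "text": "Describe this image."},
        ],
    }
]
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(model.device)

output = model.generate(**inputs, max_new_tokens=50)
print(processor.decode(output[0], skip_special_tokens=True))
```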

docs/source/en/perf_infer_gpu_one.md (+1 -1)

@@ -244,7 +244,7 @@ model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", device_m

 ### Benchmarks

-FlashAttention2 speeds up inference considerably especially for inputs with long sequences. However, since FlashAttention2 doesn't support computing attention scores with padding tokens, you must manually pad and unpad the attention scores for batched inference if a sequence contains padding tokens. The downside is batched generation is slower with padding tokens.
+FlashAttention2 speeds up inference considerably especially for inputs with long sequences. However, since FlashAttention2 doesn't support computing attention scores with padding tokens, you must manually pad and unpad the attention scores for batched inference if a sequence contains padding tokens. The downside is batched generation is slower with padding tokens.

 <hfoptions id="padded">
 <hfoption id="short sequence length">
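
For context on the paragraph touched above, a hedged sketch of batched generation with a padded batch under FlashAttention-2, assuming flash-attn is installed and a supported GPU is available. The model name is taken from the hunk header; the prompts and pad-token choice are illustrative. The slowdown the doc mentions comes from the extra pad/unpad work on padded sequences.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token  # assumption: the checkpoint ships without a pad token

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

prompts = ["Hello", "A much longer prompt that forces the short one to be padded"]
inputs = tokenizer(prompts, padding=True, return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.batch_decode(out, skip_special_tokens=True))
```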

src/transformers/image_utils.py (+12 -4)

@@ -18,7 +18,7 @@
 from contextlib import redirect_stdout
 from dataclasses import dataclass
 from io import BytesIO
-from typing import TYPE_CHECKING, Callable, Optional, Union
+from typing import Callable, Optional, Union
 from urllib.parse import urlparse

 import numpy as np
@@ -77,9 +77,8 @@
     pil_torch_interpolation_mapping = {}


-if TYPE_CHECKING:
-    if is_torch_available():
-        import torch
+if is_torch_available():
+    import torch


 logger = logging.get_logger(__name__)
@@ -162,6 +161,15 @@ def is_valid_list_of_images(images: list):
     return images and all(is_valid_image(image) for image in images)


+def concatenate_list(input_list):
+    if isinstance(input_list[0], list):
+        return [item for sublist in input_list for item in sublist]
+    elif isinstance(input_list[0], np.ndarray):
+        return np.concatenate(input_list, axis=0)
+    elif isinstance(input_list[0], torch.Tensor):
+        return torch.cat(input_list, dim=0)
+
+
 def valid_images(imgs):
     # If we have an list of images, make sure every image is valid
     if isinstance(imgs, (list, tuple)):
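
A short sketch of how the new `concatenate_list` helper behaves for the three input types it dispatches on; note that it falls through (returning `None`) if the first element is none of these.

```python
import numpy as np
import torch

from transformers.image_utils import concatenate_list  # available after this commit

print(concatenate_list([[1, 2], [3]]))                                # [1, 2, 3]
print(concatenate_list([np.ones((2, 4)), np.zeros((1, 4))]).shape)    # (3, 4)
print(concatenate_list([torch.ones(2, 4), torch.zeros(1, 4)]).shape)  # torch.Size([3, 4])
```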

src/transformers/models/__init__.py (+1)

@@ -143,6 +143,7 @@
     from .informer import *
     from .instructblip import *
     from .instructblipvideo import *
+    from .internvl import *
     from .jamba import *
     from .janus import *
     from .jetmoe import *

src/transformers/models/auto/configuration_auto.py (+5)

@@ -162,6 +162,8 @@
     ("informer", "InformerConfig"),
     ("instructblip", "InstructBlipConfig"),
     ("instructblipvideo", "InstructBlipVideoConfig"),
+    ("internvl", "InternVLConfig"),
+    ("internvl_vision", "InternVLVisionConfig"),
     ("jamba", "JambaConfig"),
     ("janus", "JanusConfig"),
     ("jetmoe", "JetMoeConfig"),
@@ -519,6 +521,8 @@
     ("informer", "Informer"),
     ("instructblip", "InstructBLIP"),
     ("instructblipvideo", "InstructBlipVideo"),
+    ("internvl", "InternVL"),
+    ("internvl_vision", "InternVLVision"),
     ("jamba", "Jamba"),
     ("janus", "Janus"),
     ("jetmoe", "JetMoe"),
@@ -797,6 +801,7 @@
     ("chinese_clip_vision_model", "chinese_clip"),
     ("rt_detr_resnet", "rt_detr"),
     ("granitevision", "llava_next"),
+    ("internvl_vision", "internvl"),
     ("qwen2_5_vl_text", "qwen2_5_vl"),
     ("qwen2_vl_text", "qwen2_vl"),
     ("sam_vision_model", "sam"),

src/transformers/models/auto/modeling_auto.py (+2)

@@ -151,6 +151,7 @@
     ("ijepa", "IJepaModel"),
     ("imagegpt", "ImageGPTModel"),
     ("informer", "InformerModel"),
+    ("internvl_vision", "InternVLVisionModel"),
     ("jamba", "JambaModel"),
     ("janus", "JanusModel"),
     ("jetmoe", "JetMoeModel"),
@@ -862,6 +863,7 @@
     ("idefics2", "Idefics2ForConditionalGeneration"),
     ("idefics3", "Idefics3ForConditionalGeneration"),
     ("instructblip", "InstructBlipForConditionalGeneration"),
+    ("internvl", "InternVLForConditionalGeneration"),
     ("janus", "JanusForConditionalGeneration"),
     ("kosmos-2", "Kosmos2ForConditionalGeneration"),
     ("llama4", "Llama4ForConditionalGeneration"),

src/transformers/models/auto/processing_auto.py (+1)

@@ -75,6 +75,7 @@
     ("idefics3", "Idefics3Processor"),
     ("instructblip", "InstructBlipProcessor"),
     ("instructblipvideo", "InstructBlipVideoProcessor"),
+    ("internvl", "InternVLProcessor"),
     ("janus", "JanusProcessor"),
     ("kosmos-2", "Kosmos2Processor"),
     ("layoutlmv2", "LayoutLMv2Processor"),

src/transformers/models/auto/tokenization_auto.py (+1)

@@ -258,6 +258,7 @@
     ("idefics3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
     ("instructblip", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
     ("instructblipvideo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
+    ("internvl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
     (
         "jamba",
         (

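Together, the processor and tokenizer registrations above let the Auto classes resolve an InternVL checkpoint end to end. A sketch using the checkpoint named in the config docstrings below; the resolved classes follow the mappings just shown:

```python
from transformers import AutoProcessor, AutoTokenizer

repo = "OpenGVLab/InternVL3-1B-hf"
processor = AutoProcessor.from_pretrained(repo)  # -> InternVLProcessor
tokenizer = AutoTokenizer.from_pretrained(repo)  # -> Qwen2TokenizerFast (with tokenizers installed)
print(type(processor).__name__, type(tokenizer).__name__)
```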

src/transformers/models/internvl/__init__.py (new file, +28)

# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_internvl import *
    from .modeling_internvl import *
    from .processing_internvl import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
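
This module follows the repository's lazy-import pattern: type checkers see the explicit star imports, while at runtime the module is replaced by a `_LazyModule`, so nothing heavy is imported until an attribute is first accessed. A small sketch of the resulting behavior:

```python
import transformers

# Attribute access triggers the lazy import of the relevant submodule.
config_cls = transformers.InternVLConfig
processor_cls = transformers.models.internvl.InternVLProcessor
print(config_cls.__name__, processor_cls.__name__)
```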

src/transformers/models/internvl/configuration_internvl.py (new file, +225)

# coding=utf-8
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from ...configuration_utils import PretrainedConfig
from ..auto import CONFIG_MAPPING, AutoConfig


class InternVLVisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`InternVLVisionModel`]. It is used to instantiate an InternVLVisionModel
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield
    a similar configuration to that of the InternVL3-1B.
    e.g. [OpenGVLab/InternVL3-1B-hf](https://huggingface.co/OpenGVLab/InternVL3-1B-hf)

    Args:
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to add a bias to the queries, keys and values.
        use_qk_norm (`bool`, *optional*, defaults to `False`):
            Whether to apply normalization to the queries and keys before the attention operation.
        intermediate_size (`int`, *optional*, defaults to 4096):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability for attention weights.
        projection_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability for the projection layer.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The type of normalization to use in the encoder. Can be `"layer_norm"` or `"rms_norm"`.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        image_size (`int` or `list[int]`, *optional*, defaults to `[448, 448]`):
            The size (resolution) of each image.
        patch_size (`int` or `list[int]`, *optional*, defaults to `[14, 14]`):
            The size (resolution) of each patch.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        use_mask_token (`bool`, *optional*, defaults to `False`):
            Whether to use a mask token for masked image modeling.
        use_absolute_position_embeddings (`bool`, *optional*, defaults to `True`):
            Whether to use BERT-style absolute position embeddings.
        layer_scale_init_value (`float`, *optional*, defaults to 0.1):
            Scale to use in the self-attention layers. 0.1 for base, 1e-5 for large. Set 0 to disable layer scale.
        use_mean_pooling (`bool`, *optional*, defaults to `True`):
            Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the
            CLS token, before applying the classification head.

    Example:

    ```python
    >>> from transformers import InternVLVisionConfig, InternVLVisionModel

    >>> # Initializing a InternVLVisionModel OpenGVLab/InternVL3-1B-hf style configuration
    >>> configuration = InternVLVisionConfig()

    >>> # Initializing a model (with random weights) from the OpenGVLab/InternVL3-1B-hf configuration
    >>> model = InternVLVisionModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "internvl_vision"
    base_config_key = "vision_config"

    def __init__(
        self,
        hidden_size=1024,
        num_hidden_layers=24,
        num_attention_heads=16,
        attention_bias=False,
        use_qk_norm=False,
        intermediate_size=4096,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_dropout=0.0,
        projection_dropout=0.0,
        initializer_range=0.02,
        norm_type="layer_norm",
        layer_norm_eps=1e-06,
        image_size=[448, 448],
        patch_size=[14, 14],
        num_channels=3,
        use_mask_token=False,
        use_absolute_position_embeddings=True,
        layer_scale_init_value=0.1,
        use_mean_pooling=True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.attention_bias = attention_bias
        self.use_qk_norm = use_qk_norm
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_dropout = attention_dropout
        self.projection_dropout = projection_dropout
        self.initializer_range = initializer_range
        self.norm_type = norm_type
        self.layer_norm_eps = layer_norm_eps

        image_size = image_size if isinstance(image_size, (list, tuple)) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, (list, tuple)) else (patch_size, patch_size)
        self.image_size = image_size
        self.patch_size = patch_size

        self.num_channels = num_channels
        self.use_mask_token = use_mask_token
        self.use_absolute_position_embeddings = use_absolute_position_embeddings
        self.layer_scale_init_value = layer_scale_init_value
        self.use_mean_pooling = use_mean_pooling


class InternVLConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`InternVLForConditionalGeneration`]. It is used to instantiate a
    InternVL model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of InternVL3-1B.
    e.g. [OpenGVLab/InternVL3-1B-hf](https://huggingface.co/OpenGVLab/InternVL3-1B-hf)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `InternVisonConfig`):
            The config object or dictionary of the vision backbone.
        text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Qwen2Config`):
            The config object or dictionary of the text backbone.
        image_token_id (`int`, *optional*, defaults to 151667):
            The image token index to encode the image prompt.
        image_seq_length (`int`, *optional*, defaults to 256):
            Number of image tokens to use per image patch.
        downsample_ratio (`float`, *optional*, defaults to 0.5):
            Factor by which to downsample the image.
        projector_hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the projector.
        vision_feature_layer (`int`, *optional*, defaults to -1):
            The index of the layer to use as the image features.
        vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
            The feature selection strategy used to select the vision feature from the vision backbone.
            Can be one of `"default"` or `"full"`.

    ```python
    >>> from transformers import InternVLForConditionalGeneration, InternVLConfig

    >>> # Initializing a InternVL style configuration
    >>> configuration = InternVLConfig()

    >>> # Initializing a model (with random weights) from the OpenGVLab/InternVL3-1B-hf configuration
    >>> model = InternVLForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "internvl"
    sub_configs = {"text_config": AutoConfig, "vision_config": InternVLVisionConfig}

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        image_token_id=151667,
        image_seq_length=256,
        downsample_ratio=0.5,
        projector_hidden_act="gelu",
        vision_feature_layer=-1,
        vision_feature_select_strategy="default",
        **kwargs,
    ):
        self.image_token_id = image_token_id
        self.image_seq_length = image_seq_length
        self.downsample_ratio = downsample_ratio
        self.projector_hidden_act = projector_hidden_act
        self.vision_feature_layer = vision_feature_layer
        self.vision_feature_select_strategy = vision_feature_select_strategy

        if isinstance(vision_config, dict):
            self.vision_config = InternVLVisionConfig(**vision_config)
        elif isinstance(vision_config, InternVLVisionConfig):
            self.vision_config = vision_config
        elif vision_config is None:
            self.vision_config = InternVLVisionConfig()

        if isinstance(text_config, dict):
            text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "qwen2"
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            text_config = CONFIG_MAPPING["qwen2"]()

        self.text_config = text_config

        super().__init__(**kwargs)


__all__ = ["InternVLVisionConfig", "InternVLConfig"]
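
A short sketch of the composition logic in `InternVLConfig.__init__`: a plain dict passed as `text_config` is routed through `CONFIG_MAPPING` and defaults to `"qwen2"` when no `model_type` is given, while `vision_config` falls back to a fresh `InternVLVisionConfig`. The sub-config field values here are illustrative.

```python
from transformers import InternVLConfig

config = InternVLConfig(
    vision_config={"hidden_size": 1024, "num_hidden_layers": 24},
    text_config={"hidden_size": 896, "num_attention_heads": 14},  # no model_type -> "qwen2"
)
print(type(config.vision_config).__name__)              # InternVLVisionConfig
print(type(config.text_config).__name__)                # Qwen2Config
print(config.image_token_id, config.image_seq_length)   # 151667 256
```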
