Add InternVL (2.5 MPO) #35968


Merged
merged 69 commits into main from add-intern-vl on Apr 18, 2025
Changes from 64 commits

Commits (69)
9f14e29
initial commit
yonigozlan Jan 29, 2025
18a2907
add convert internvl
yonigozlan Jan 31, 2025
bb754e9
add first end-to-end working internvl
yonigozlan Jan 31, 2025
ce881fa
nit prompt and image proc
yonigozlan Feb 3, 2025
c8bc2e5
add working chat template
yonigozlan Feb 4, 2025
a8a6142
add conversion llama-based models
yonigozlan Feb 9, 2025
aa6b6fa
add tests
yonigozlan Feb 10, 2025
72a5482
pass all tests
yonigozlan Feb 11, 2025
7c4de89
Merge remote-tracking branch 'upstream/main' into add-intern-vl
yonigozlan Feb 11, 2025
f842255
fix isort
yonigozlan Feb 11, 2025
a275a1b
fix modular after main merge
yonigozlan Feb 11, 2025
747ec09
add video processing for internvl
yonigozlan Feb 11, 2025
e747692
add support for interlaced images and videos
yonigozlan Feb 12, 2025
3005a9f
Remove processing and config from modular, add more tests
yonigozlan Feb 12, 2025
d7f5d5f
add llama model tests
yonigozlan Feb 12, 2025
bc13ecb
Modify processor for compatibility with refactored got ocr image proc…
yonigozlan Feb 14, 2025
63b981a
add comments in processor
yonigozlan Feb 14, 2025
9dce4a7
Add docs and nits
yonigozlan Feb 18, 2025
8249051
Merge remote-tracking branch 'upstream/main' into add-intern-vl
yonigozlan Feb 20, 2025
9e93a45
change video processing to use custom sample_indices_fn
yonigozlan Feb 20, 2025
86d9049
rebase and fix tests
yonigozlan Feb 20, 2025
329dc54
add processor tests
yonigozlan Feb 20, 2025
120ae64
Add changes Raushan review
yonigozlan Feb 24, 2025
42848e2
Use the new attention interface for the vision model
yonigozlan Feb 24, 2025
5542660
Merge remote-tracking branch 'upstream/main' into add-intern-vl
yonigozlan Feb 24, 2025
5ea812b
nits
yonigozlan Mar 7, 2025
8330584
Merge remote-tracking branch 'upstream/main' into add-intern-vl
yonigozlan Mar 7, 2025
50fb839
add support for custom video_load_backend
yonigozlan Mar 7, 2025
4638c15
remove mention to InternVLTokenizer
yonigozlan Mar 7, 2025
95bd18f
refactor vision model to simplify logic
yonigozlan Mar 25, 2025
a091d20
refactor processor for better readibility
yonigozlan Mar 25, 2025
c224b04
Merge remote-tracking branch 'upstream/main' into add-intern-vl
yonigozlan Mar 25, 2025
f508160
fix copies
yonigozlan Mar 25, 2025
82fb703
fix require av processor test
yonigozlan Mar 25, 2025
f77c1ea
Merge branch 'main' into add-intern-vl
yonigozlan Mar 25, 2025
066ebeb
Merge remote-tracking branch 'upstream/main' into add-intern-vl
yonigozlan Apr 14, 2025
9f24691
refactor internVL vision
yonigozlan Apr 14, 2025
b9561cb
Merge remote-tracking branch 'upstream/main' into add-intern-vl
yonigozlan Apr 14, 2025
b10d88c
Update processor and fix processing tests
yonigozlan Apr 15, 2025
2a9f873
Merge remote-tracking branch 'upstream/main' into add-intern-vl
yonigozlan Apr 15, 2025
7a87bd0
fix docstring
yonigozlan Apr 15, 2025
4c21e8e
update convert_weights for internvl3
yonigozlan Apr 15, 2025
b3f61ba
Merge branch 'main' into add-intern-vl
yonigozlan Apr 15, 2025
a0b3e61
Merge remote-tracking branch 'upstream/main' into add-intern-vl
yonigozlan Apr 16, 2025
85f7579
change image processor to fast by default
yonigozlan Apr 16, 2025
c677d17
remove do_center_crop=True in convert_weights
yonigozlan Apr 16, 2025
bdb91dc
force use_cache to True
yonigozlan Apr 16, 2025
ac5b7fd
push_to_hub before reloading
yonigozlan Apr 16, 2025
31529a4
fix internVLVision for larger models
yonigozlan Apr 16, 2025
50b05a9
update convert weight for qk norm
yonigozlan Apr 16, 2025
b4857b6
Merge remote-tracking branch 'upstream/main' into add-intern-vl
yonigozlan Apr 16, 2025
8420c3d
fix convert_weights
yonigozlan Apr 17, 2025
57cbbd7
Merge remote-tracking branch 'upstream/main' into add-intern-vl
yonigozlan Apr 17, 2025
189d315
Merge branch 'add-intern-vl' of https://github.com/yonigozlan/transfo…
yonigozlan Apr 17, 2025
0501394
fix eos_token_id in convert
yonigozlan Apr 17, 2025
b4c13d5
update docs and integration tests
yonigozlan Apr 17, 2025
d536a23
Merge branch 'main' into add-intern-vl
yonigozlan Apr 17, 2025
8c70ded
Merge branch 'main' into add-intern-vl
yonigozlan Apr 17, 2025
1d2f943
make modifs after review
yonigozlan Apr 18, 2025
300fdaa
Merge branch 'add-intern-vl' of https://github.com/yonigozlan/transfo…
yonigozlan Apr 18, 2025
549085e
Merge remote-tracking branch 'upstream/main' into add-intern-vl
yonigozlan Apr 18, 2025
3456dc7
fix wrong k_norm and reduce modular
yonigozlan Apr 18, 2025
135ab88
Merge remote-tracking branch 'upstream/main' into add-intern-vl
yonigozlan Apr 18, 2025
e3ec223
change image_token_index to image_token_id
yonigozlan Apr 18, 2025
14beed0
change checkpoint to OpenGVLab org
yonigozlan Apr 18, 2025
496e3c8
last nits
yonigozlan Apr 18, 2025
3ba730a
explicitely del self.num_key_value_groups
yonigozlan Apr 18, 2025
231e9c8
Merge remote-tracking branch 'upstream/main' into add-intern-vl
yonigozlan Apr 18, 2025
015d3b2
add extra special tokens
yonigozlan Apr 18, 2025
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -953,6 +953,8 @@
title: InstructBLIP
- local: model_doc/instructblipvideo
title: InstructBlipVideo
- local: model_doc/internvl
title: InternVL
- local: model_doc/janus
title: Janus
- local: model_doc/kosmos-2
349 changes: 349 additions & 0 deletions docs/source/en/model_doc/internvl.md

Large diffs are not rendered by default.
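Since the new `internvl.md` doc (349 added lines) is not rendered on this page, the sketch below illustrates the kind of end-to-end usage such a model doc typically covers: loading a converted checkpoint with `AutoProcessor` and `InternVLForConditionalGeneration` and running image-conditioned generation through the chat template added in this PR. The checkpoint id and the exact processor arguments are assumptions, not taken from the rendered doc.

```python
from transformers import AutoProcessor, InternVLForConditionalGeneration

# Hypothetical checkpoint id taken from the config docstrings in this PR;
# later commits move the converted checkpoints to the OpenGVLab org.
checkpoint = "yonigozlan/InternVL2_5-1B-MPO-hf"

processor = AutoProcessor.from_pretrained(checkpoint)
model = InternVLForConditionalGeneration.from_pretrained(checkpoint, device_map="auto")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
            {"type": "text", "text": "Describe this image."},
        ],
    }
]
# Tokenize the multimodal chat turn and move it to the model's device.
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(model.device)

output = model.generate(**inputs, max_new_tokens=50)
print(processor.decode(output[0], skip_special_tokens=True))
```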

2 changes: 1 addition & 1 deletion docs/source/en/perf_infer_gpu_one.md
@@ -244,7 +244,7 @@ model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", device_m

### Benchmarks

FlashAttention2 speeds up inference considerably especially for inputs with long sequences. However, since FlashAttention2 doesn't support computing attention scores with padding tokens, you must manually pad and unpad the attention scores for batched inference if a sequence contains padding tokens. The downside is batched generation is slower with padding tokens.
FlashAttention2 speeds up inference considerably especially for inputs with long sequences. However, since FlashAttention2 doesn't support computing attention scores with padding tokens, you must manually pad and unpad the attention scores for batched inference if a sequence contains padding tokens. The downside is batched generation is slower with padding tokens.

<hfoptions id="padded">
<hfoption id="short sequence length">
16 changes: 12 additions & 4 deletions src/transformers/image_utils.py
@@ -18,7 +18,7 @@
from contextlib import redirect_stdout
from dataclasses import dataclass
from io import BytesIO
from typing import TYPE_CHECKING, Callable, Optional, Union
from typing import Callable, Optional, Union
from urllib.parse import urlparse

import numpy as np
@@ -77,9 +77,8 @@
pil_torch_interpolation_mapping = {}


if TYPE_CHECKING:
    if is_torch_available():
        import torch
if is_torch_available():
    import torch


logger = logging.get_logger(__name__)
@@ -162,6 +161,15 @@ def is_valid_list_of_images(images: list):
    return images and all(is_valid_image(image) for image in images)


def concatenate_list(input_list):
    if isinstance(input_list[0], list):
        return [item for sublist in input_list for item in sublist]
    elif isinstance(input_list[0], np.ndarray):
        return np.concatenate(input_list, axis=0)
    elif isinstance(input_list[0], torch.Tensor):
        return torch.cat(input_list, dim=0)


def valid_images(imgs):
    # If we have a list of images, make sure every image is valid
    if isinstance(imgs, (list, tuple)):
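A quick usage sketch of the `concatenate_list` helper added above, assuming a transformers install that includes this PR (the inputs are made up for illustration):

```python
import numpy as np
import torch

from transformers.image_utils import concatenate_list

# Nested Python lists are flattened one level...
print(concatenate_list([[1, 2], [3]]))  # [1, 2, 3]

# ...while lists of arrays/tensors are concatenated along the first axis.
print(concatenate_list([np.zeros((2, 3)), np.ones((1, 3))]).shape)    # (3, 3)
print(concatenate_list([torch.zeros(2, 3), torch.ones(1, 3)]).shape)  # torch.Size([3, 3])
```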
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -143,6 +143,7 @@
from .informer import *
from .instructblip import *
from .instructblipvideo import *
from .internvl import *
from .jamba import *
from .janus import *
from .jetmoe import *
5 changes: 5 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -162,6 +162,8 @@
("informer", "InformerConfig"),
("instructblip", "InstructBlipConfig"),
("instructblipvideo", "InstructBlipVideoConfig"),
("internvl", "InternVLConfig"),
("internvl_vision", "InternVLVisionConfig"),
("jamba", "JambaConfig"),
("janus", "JanusConfig"),
("jetmoe", "JetMoeConfig"),
@@ -519,6 +521,8 @@
("informer", "Informer"),
("instructblip", "InstructBLIP"),
("instructblipvideo", "InstructBlipVideo"),
("internvl", "InternVL"),
("internvl_vision", "InternVLVision"),
("jamba", "Jamba"),
("janus", "Janus"),
("jetmoe", "JetMoe"),
@@ -797,6 +801,7 @@
("chinese_clip_vision_model", "chinese_clip"),
("rt_detr_resnet", "rt_detr"),
("granitevision", "llava_next"),
("internvl_vision", "internvl"),
("qwen2_5_vl_text", "qwen2_5_vl"),
("qwen2_vl_text", "qwen2_vl"),
("sam_vision_model", "sam"),
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -151,6 +151,7 @@
("ijepa", "IJepaModel"),
("imagegpt", "ImageGPTModel"),
("informer", "InformerModel"),
("internvl_vision", "InternVLVisionModel"),
("jamba", "JambaModel"),
("janus", "JanusModel"),
("jetmoe", "JetMoeModel"),
@@ -862,6 +863,7 @@
("idefics2", "Idefics2ForConditionalGeneration"),
("idefics3", "Idefics3ForConditionalGeneration"),
("instructblip", "InstructBlipForConditionalGeneration"),
("internvl", "InternVLForConditionalGeneration"),
("janus", "JanusForConditionalGeneration"),
("kosmos-2", "Kosmos2ForConditionalGeneration"),
("llama4", "Llama4ForConditionalGeneration"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/processing_auto.py
@@ -75,6 +75,7 @@
("idefics3", "Idefics3Processor"),
("instructblip", "InstructBlipProcessor"),
("instructblipvideo", "InstructBlipVideoProcessor"),
("internvl", "InternVLProcessor"),
("janus", "JanusProcessor"),
("kosmos-2", "Kosmos2Processor"),
("layoutlmv2", "LayoutLMv2Processor"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/tokenization_auto.py
@@ -258,6 +258,7 @@
("idefics3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("instructblip", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("instructblipvideo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("internvl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
(
"jamba",
(
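With the auto-mapping entries added above, the standard Auto classes resolve InternVL checkpoints without model-specific imports. A minimal sketch, assuming a transformers build that includes this PR and reusing the hypothetical checkpoint id from the config docstrings:

```python
from transformers import AutoConfig, AutoProcessor, AutoTokenizer

checkpoint = "yonigozlan/InternVL2_5-1B-MPO-hf"  # assumption; the final checkpoints live under the OpenGVLab org

config = AutoConfig.from_pretrained(checkpoint)        # -> InternVLConfig
processor = AutoProcessor.from_pretrained(checkpoint)  # -> InternVLProcessor
tokenizer = AutoTokenizer.from_pretrained(checkpoint)  # -> Qwen2TokenizerFast (per the mapping above)

print(type(config).__name__, type(processor).__name__, type(tokenizer).__name__)
```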
28 changes: 28 additions & 0 deletions src/transformers/models/internvl/__init__.py
@@ -0,0 +1,28 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_internvl import *
    from .modeling_internvl import *
    from .processing_internvl import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
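The `_LazyModule` indirection above defers the heavy submodule imports until an attribute is first accessed. A small check of that behavior (a sketch, assuming a source checkout with this PR applied):

```python
import transformers.models.internvl as internvl

print(type(internvl).__name__)         # _LazyModule
config_cls = internvl.InternVLConfig   # first access triggers the real import of configuration_internvl
print(config_cls.model_type)           # "internvl"
```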
225 changes: 225 additions & 0 deletions src/transformers/models/internvl/configuration_internvl.py
@@ -0,0 +1,225 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from ...configuration_utils import PretrainedConfig
from ..auto import CONFIG_MAPPING, AutoConfig


class InternVLVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of an [`InternVLVisionModel`]. It is used to instantiate an InternVLVisionModel
according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield
a configuration similar to that of InternVL2_5-1B-MPO,
e.g. [yonigozlan/InternVL2_5-1B-MPO-hf](https://huggingface.co/yonigozlan/InternVL2_5-1B-MPO-hf).

Args:
hidden_size (`int`, *optional*, defaults to 1024):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 24):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to add a bias to the queries, keys and values.
use_qk_norm (`bool`, *optional*, defaults to `False`):
Whether to apply normalization to the queries and keys before the attention operation.
intermediate_size (`int`, *optional*, defaults to 4096):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` are supported.
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (`float`, *optional*, defaults to 0.0):
Dropout probability for attention weights.
projection_dropout (`float`, *optional*, defaults to 0.0):
Dropout probability for the projection layer.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The type of normalization to use in the encoder. Can be `"layer_norm"` or `"rms_norm"`.
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
image_size (`int` or `list[int]`, *optional*, defaults to `[448, 448]`):
The size (resolution) of each image.
patch_size (`int` or `list[int]`, *optional*, defaults to `[14, 14]`):
The size (resolution) of each patch.
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
use_mask_token (`bool`, *optional*, defaults to `False`):
Whether to use a mask token for masked image modeling.
use_absolute_position_embeddings (`bool`, *optional*, defaults to `True`):
Whether to use BERT-style absolute position embeddings.
layer_scale_init_value (`float`, *optional*, defaults to 0.1):
Scale to use in the self-attention layers. 0.1 for base, 1e-5 for large. Set 0 to disable layer scale.
use_mean_pooling (`bool`, *optional*, defaults to `True`):
Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the
CLS token, before applying the classification head.

Example:

```python
>>> from transformers import InternVLVisionConfig, InternVLVisionModel

>>> # Initializing an InternVLVisionModel yonigozlan/InternVL2_5-1B-MPO-hf style configuration
>>> configuration = InternVLVisionConfig()

>>> # Initializing a model (with random weights) from the yonigozlan/InternVL2_5-1B-MPO-hf configuration
>>> model = InternVLVisionModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
```"""

model_type = "internvl_vision"
base_config_key = "vision_config"

    def __init__(
        self,
        hidden_size=1024,
        num_hidden_layers=24,
        num_attention_heads=16,
        attention_bias=False,
        use_qk_norm=False,
        intermediate_size=4096,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_dropout=0.0,
        projection_dropout=0.0,
        initializer_range=0.02,
        norm_type="layer_norm",
        layer_norm_eps=1e-06,
        image_size=[448, 448],
        patch_size=[14, 14],
        num_channels=3,
        use_mask_token=False,
        use_absolute_position_embeddings=True,
        layer_scale_init_value=0.1,
        use_mean_pooling=True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.attention_bias = attention_bias
        self.use_qk_norm = use_qk_norm
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_dropout = attention_dropout
        self.projection_dropout = projection_dropout
        self.initializer_range = initializer_range
        self.norm_type = norm_type
        self.layer_norm_eps = layer_norm_eps

        image_size = image_size if isinstance(image_size, (list, tuple)) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, (list, tuple)) else (patch_size, patch_size)
        self.image_size = image_size
        self.patch_size = patch_size

        self.num_channels = num_channels
        self.use_mask_token = use_mask_token
        self.use_absolute_position_embeddings = use_absolute_position_embeddings
        self.layer_scale_init_value = layer_scale_init_value
        self.use_mean_pooling = use_mean_pooling


class InternVLConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of an [`InternVLForConditionalGeneration`]. It is used to instantiate an
InternVL model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a configuration similar to that of InternVL2_5-1B-MPO,
e.g. [yonigozlan/InternVL2_5-1B-MPO-hf](https://huggingface.co/yonigozlan/InternVL2_5-1B-MPO-hf).

Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.


Args:
vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `InternVLVisionConfig`):
The config object or dictionary of the vision backbone.
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Qwen2Config`):
The config object or dictionary of the text backbone.
image_token_id (`int`, *optional*, defaults to 151667):
The image token index to encode the image prompt.
image_seq_length (`int`, *optional*, defaults to 256):
Number of image tokens to use per image patch.
downsample_ratio (`float`, *optional*, defaults to 0.5):
Factor by which to downsample the image.
projector_hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the projector.
vision_feature_layer (`int`, *optional*, defaults to -1):
The index of the layer to use as the image features.
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`.

```python
>>> from transformers import InternVLForConditionalGeneration, InternVLConfig

>>> # Initializing an InternVL style configuration
>>> configuration = InternVLConfig()

>>> # Initializing a model (with random weights) from the yonigozlan/InternVL2_5-1B-MPO-hf configuration
>>> model = InternVLForConditionalGeneration(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
```"""

model_type = "internvl"
sub_configs = {"text_config": AutoConfig, "vision_config": InternVLVisionConfig}

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        image_token_id=151667,
        image_seq_length=256,
        downsample_ratio=0.5,
        projector_hidden_act="gelu",
        vision_feature_layer=-1,
        vision_feature_select_strategy="default",
        **kwargs,
    ):
        self.image_token_id = image_token_id
        self.image_seq_length = image_seq_length
        self.downsample_ratio = downsample_ratio
        self.projector_hidden_act = projector_hidden_act
        self.vision_feature_layer = vision_feature_layer
        self.vision_feature_select_strategy = vision_feature_select_strategy

        if isinstance(vision_config, dict):
            self.vision_config = InternVLVisionConfig(**vision_config)
        elif isinstance(vision_config, InternVLVisionConfig):
            self.vision_config = vision_config
        elif vision_config is None:
            self.vision_config = InternVLVisionConfig()

        if isinstance(text_config, dict):
            text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "qwen2"
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            text_config = CONFIG_MAPPING["qwen2"]()

        self.text_config = text_config

        super().__init__(**kwargs)


__all__ = ["InternVLVisionConfig", "InternVLConfig"]
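As a usage note for the constructor above, sub-configs can be passed either as config objects or as plain dicts, and a dict `text_config` without a `model_type` falls back to `qwen2`. A minimal sketch (the sizes below are made up for illustration, not real checkpoint values):

```python
from transformers import InternVLConfig, InternVLVisionConfig

vision_config = InternVLVisionConfig(hidden_size=1024, num_hidden_layers=24)
config = InternVLConfig(
    vision_config=vision_config,
    text_config={"hidden_size": 896, "num_hidden_layers": 24},  # no "model_type" -> defaults to "qwen2"
)

print(type(config.text_config).__name__)               # Qwen2Config
print(config.image_token_id, config.image_seq_length)  # 151667 256
```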