Add InternVL (2.5 MPO) #35968

Merged: 69 commits into huggingface:main, Apr 18, 2025

Conversation

yonigozlan (Member):

What does this PR do?

Add InternVL to Transformers.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@yonigozlan changed the title from "[WIP] Add InternVL (2.5 MPO)" to "Add InternVL (2.5 MPO)" on Feb 20, 2025
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Cyrilvallez (Member) left a comment:

Alright, very nice!! 🤗 Maybe we can push a bit to use modular more, especially on the Vision part though? Let me know if it could work! Otherwise it's in a very good state so we can merge it very soon 👌

Just about the checkpoints in the examples/docstrings/a bit everywhere, I see you used "yonigozlan/...". Should those point to the main repo instead? It's no issue in the tests, but in the docstrings etc it's best to use original checkpoints if any!

>>> print(processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True))
The images depict the Statue of Liberty and the Golden Gate Bridge.
```"""
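For context on the docstring example being discussed, here is a hedged end-to-end sketch of how such an example would read. The checkpoint name is an assumption based on the later decision to move the converted weights to the OpenGVLab org, and the image URL and chat-template call are illustrative; the merged docstrings may differ.

```python
import torch
from transformers import AutoProcessor, InternVLForConditionalGeneration

# Assumed checkpoint id; adjust to the actual repo once moved to the OpenGVLab org.
checkpoint = "OpenGVLab/InternVL3-1B-hf"

processor = AutoProcessor.from_pretrained(checkpoint)
model = InternVLForConditionalGeneration.from_pretrained(
    checkpoint, device_map="auto", torch_dtype=torch.bfloat16
)

# Chat-style input with one image; the image URL below is illustrative only.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
            {"type": "text", "text": "Describe this image."},
        ],
    }
]
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(model.device, dtype=torch.bfloat16)

generate_ids = model.generate(**inputs, max_new_tokens=50)
print(processor.decode(generate_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```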

Member:

Any other Llava-like that could use more inheritance maybe? 🤗
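For readers unfamiliar with the modular mechanism being referenced, here is a minimal sketch of the pattern. The file name and inheritance target are illustrative assumptions; the merged modular file may reuse different classes.

```python
# modular_internvl.py (illustrative sketch, not the PR's actual file):
# a modular definition subclasses an existing implementation and only overrides
# what differs; utils/modular_model_converter.py then expands this into a
# standalone modeling file with the inherited code copied in.
from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration


class InternVLForConditionalGeneration(LlavaForConditionalGeneration):
    pass
```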

Comment on lines 280 to 294

@slow
@require_torch_gpu
class InternVLQwen2IntegrationTest(unittest.TestCase):
    def setUp(self):
        self.small_model_checkpoint = "yonigozlan/InternVL3-1B-hf"
        self.medium_model_checkpoint = "yonigozlan/InternVL3-2B-hf"

    def tearDown(self):
        cleanup(torch_device, gc_collect=True)

    def test_qwen2_small_model_integration_generate(self):
        processor = AutoProcessor.from_pretrained(self.small_model_checkpoint)
        model = InternVLForConditionalGeneration.from_pretrained(
            self.small_model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16
Member:

Nice IntegrationTests! 🤗 Did you make sure that the outputs are the same on T4 by any chance so we don't have to potentially adjust later?

Member Author:

I used an A10; the CI runners have A10s, no?

Member:

I think they use both A10 and T4, but if you run them on A10 it's all good!
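For context, one hedged way to keep a single integration test green on both runner types is to key the expected string on the GPU's compute capability. The helper and expected strings below are hypothetical, not something this PR adds.

```python
import torch

# Hypothetical mapping: bfloat16 generations can differ slightly between
# T4 (compute capability 7.5) and A10 (8.6), so expectations are keyed per arch.
EXPECTED_OUTPUTS = {
    (7, 5): "expected generation captured on a T4",
    (8, 6): "expected generation captured on an A10",
}


def expected_output():
    capability = torch.cuda.get_device_capability()
    # Fall back to the A10 reference if an unknown GPU is used.
    return EXPECTED_OUTPUTS.get(capability, EXPECTED_OUTPUTS[(8, 6)])
```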

@yonigozlan (Member Author):

Thanks for the review @Cyrilvallez! I made the modifications and I'm now using modular more for InternVLVision.

> Just about the checkpoints in the examples/docstrings/a bit everywhere, I see you used "yonigozlan/...". Should those point to the main repo instead? It's no issue in the tests, but in the docstrings etc it's best to use original checkpoints if any!

Yes of course, I'll do a "replace all" once I've moved the checkpoints to the main repo :). I just need to convert them again, and I'll move them once you give me the green light!

@Cyrilvallez (Member) left a comment:

Alright, see the final comments! Then we're all good!
You also need to patch the checkpoints in the processor tests; they are currently failing because the checkpoint does not seem to exist on the Hub.

Comment on lines 191 to 192
        text_config=None,
        image_token_index=151667,
Member:

We just decided to standardize and use id instead of index here for the tokens, so let's change it before merging! See #37573

Member Author:

That means I'd have to override a big chunk of the forward function though, since Llava uses image_token_index right now... I'll change it here if #37573 is merged first, otherwise I'll ping in that PR :)
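A hedged sketch of the standardization being discussed, assuming a config that stores the token under `image_token_id`; the class name and the deprecated alias below are hypothetical and not necessarily part of this PR.

```python
from transformers import PretrainedConfig


class InternVLConfigSketch(PretrainedConfig):
    """Illustrative only: stores the image token under the standardized name."""

    def __init__(self, image_token_id=151667, **kwargs):
        self.image_token_id = image_token_id
        super().__init__(**kwargs)

    @property
    def image_token_index(self):
        # Hypothetical backward-compatible alias for code still using the old name.
        return self.image_token_id
```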

Comment on lines 86 to 113
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        proj_dropout = config.projection_dropout
        qk_norm = config.use_qk_norm

        # InternVLVision has no MHA, hence for `eager_attention_forward` call setting `num_key_value_groups` to 1.
        self.num_key_value_groups = 1

        # Needed for flash attention
        self.is_causal = False

        self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.projection_layer = nn.Linear(self.embed_dim, self.embed_dim)
        self.projection_dropout = nn.Dropout(proj_dropout) if proj_dropout > 0 else nn.Identity()

        self.q_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity()
        self.k_norm = nn.LayerNorm(self.embed_dim) if qk_norm else nn.Identity()
Member:

We should not need most of that here - also the Norms are mixed up

Member Author:

Oops, thanks for catching that! And I didn't know modular could partially override a method like that, that's very cool, thanks!
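For context on the mix-up flagged above, here is a small self-contained sketch of the intended fix: both q_norm and k_norm should use the same normalization when qk-norm is enabled. The local RMSNorm below is illustrative, not the library's InternVLVisionRMSNorm.

```python
import torch
from torch import nn


class RMSNorm(nn.Module):
    """Minimal RMSNorm, only to illustrate the qk-norm fix discussed above."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states):
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        return self.weight * hidden_states * torch.rsqrt(variance + self.eps)


embed_dim, use_qk_norm = 1024, True
# Using RMSNorm for q but LayerNorm for k was the inconsistency caught in review.
q_norm = RMSNorm(embed_dim) if use_qk_norm else nn.Identity()
k_norm = RMSNorm(embed_dim) if use_qk_norm else nn.Identity()
```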

Comment on lines +314 to +315
class InternVLVisionMLP(CLIPMLP):
    pass
Member:

Nice! I knew we had this one somewhere! 🤗

@Cyrilvallez (Member) left a comment:

Last picky comments! 🤗

Comment on lines 87 to 88
        # InternVLVision has no MHA, hence for `eager_attention_forward` call setting `num_key_value_groups` to 1.
        self.num_key_value_groups = 1
Member:

Eager is not using it since you removed the repeat_kv already! Let's remove it!
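For context, a minimal sketch (not the library's exact `eager_attention_forward`) of eager attention when query, key, and value all carry the same number of heads, which is why no `repeat_kv` and no `num_key_value_groups` are needed here:

```python
import torch
import torch.nn.functional as F


def eager_attention(query, key, value, scale, dropout_p=0.0, training=False):
    # query, key, value: (batch, num_heads, seq_len, head_dim) with identical
    # head counts, so keys/values need no repetition across groups.
    attn_weights = torch.matmul(query, key.transpose(-2, -1)) * scale
    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_weights = F.dropout(attn_weights, p=dropout_p, training=training)
    return torch.matmul(attn_weights, value)
```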

@Cyrilvallez (Member) left a comment:

Perfect, LGTM, nothing to add! Thanks a lot!! 🤗🤗

The last comment is related to @zucchini-nlp's comment, so let's apply it before merging! But then it's all good, feel free to merge!

@Cyrilvallez (Member):

Merging! Thanks again! 🤗🤗

@Cyrilvallez merged commit a245011 into huggingface:main on Apr 18, 2025
18 of 20 checks passed
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025
* initial commit

* add convert internvl

* add first end-to-end working internvl

* nit prompt and image proc

* add working chat template

* add conversion llama-based models

* add tests

* pass all tests

* fix isort

* fix modular after main merge

* add video processing for internvl

* add support for interlaced images and videos

* Remove processing and config from modular, add more tests

* add llama model tests

* Modify processor for compatibility with refactored got ocr image processor

* add comments in processor

* Add docs and nits

* change video processing to use custom sample_indices_fn

* rebase and fix tests

* add processor tests

* Add changes Raushan review

* Use the new attention interface for the vision model

* nits

* add support for custom video_load_backend

* remove mention to InternVLTokenizer

* refactor vision model to simplify logic

* refactor processor for better readability

* fix copies

* fix require av processor test

* refactor internVL vision

* Update processor and fix processing tests

* fix docstring

* update convert_weights for internvl3

* change image processor to fast by default

* remove do_center_crop=True in convert_weights

* force use_cache to True

* push_to_hub before reloading

* fix internVLVision for larger models

* update convert weight for qk norm

* fix convert_weights

* fix eos_token_id in convert

* update docs and integration tests

* make modifs after review

* fix wrong k_norm and reduce modular

* change image_token_index to image_token_id

* change checkpoint to OpenGVLab org

* last nits

* explicitly del self.num_key_value_groups

* add extra special tokens