
Commit 8e60afa

[Model][LoRA]LoRA support added for MiniCPMV2.6 (#8943)
Co-authored-by: DarkLight1337 <[email protected]>
1 parent b6d7392 commit 8e60afa
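
For context, the sketch below shows roughly how this change is meant to be exercised: loading MiniCPM-V 2.6 with LoRA enabled and attaching an adapter per request. It is a minimal, illustrative example, assuming the Hugging Face repo id openbmb/MiniCPM-V-2_6, a hypothetical local adapter directory, and the image-placeholder prompt style used in vLLM's MiniCPM-V examples; exact arguments may differ across vLLM versions.

from PIL import Image

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Hypothetical paths: adjust the adapter directory and image to your setup.
MODEL = "openbmb/MiniCPM-V-2_6"
ADAPTER_PATH = "/path/to/minicpmv26-lora-adapter"

llm = LLM(
    model=MODEL,
    trust_remote_code=True,
    enable_lora=True,   # the LoRA path this commit enables for MiniCPMV2_6
    max_lora_rank=64,   # must be >= the adapter's rank
)

# Placeholder prompt format as used in vLLM's MiniCPM-V examples.
image = Image.open("/path/to/example.jpg")
prompt = "(<image>./</image>)\nDescribe this image."

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=LoRARequest("minicpmv26_adapter", 1, ADAPTER_PATH),
)
print(outputs[0].outputs[0].text)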

File tree

3 files changed: +49, -880 lines


vllm/model_executor/models/idefics2_vision_model.py

Lines changed: 15 additions & 9 deletions
@@ -65,11 +65,10 @@ def __init__(self, config: Idefics2VisionConfig):
         self.position_embedding = nn.Embedding(self.num_positions,
                                                self.embed_dim)
 
-    def forward(
-        self,
-        pixel_values: torch.FloatTensor,
-        patch_attention_mask: torch.BoolTensor,
-    ) -> torch.Tensor:
+    def forward(self,
+                pixel_values: torch.FloatTensor,
+                patch_attention_mask: torch.BoolTensor,
+                tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor:
         batch_size, _, max_im_h, max_im_w = pixel_values.shape
         patch_embeds = self.patch_embedding(pixel_values)
         embeddings = patch_embeds.flatten(2).transpose(1, 2)
@@ -84,8 +83,13 @@ def forward(
                                  fill_value=0)
 
         for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
-            nb_patches_h = p_attn_mask[:, 0].sum()
-            nb_patches_w = p_attn_mask[0].sum()
+
+            if tgt_sizes is not None:
+                nb_patches_h = tgt_sizes[batch_idx][0]
+                nb_patches_w = tgt_sizes[batch_idx][1]
+            else:
+                nb_patches_h = p_attn_mask[:, 0].sum()
+                nb_patches_w = p_attn_mask[0].sum()
             fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
             fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
             bucket_coords_h = torch.bucketize(fractional_coords_h,
@@ -287,10 +291,12 @@ def forward(
         self,
         pixel_values,
         patch_attention_mask: Optional[torch.BoolTensor] = None,
-    ) -> torch.tensor:
+        tgt_sizes: Optional[torch.IntTensor] = None,
+    ) -> torch.Tensor:
         hidden_states = self.embeddings(
             pixel_values=pixel_values,
-            patch_attention_mask=patch_attention_mask)
+            patch_attention_mask=patch_attention_mask,
+            tgt_sizes=tgt_sizes)
         encoder_outputs = self.encoder(hidden_states)
         last_hidden_state = self.post_layernorm(encoder_outputs)
         return last_hidden_state
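
The tgt_sizes argument threaded through above feeds the patch-count bucketing that follows these lines. The standalone sketch below (not part of the diff) reproduces that bucketing, assuming boundaries is built the same way as in the surrounding Idefics2 embedding code; it shows that an explicit (height, width) patch grid from tgt_sizes can stand in for the counts previously derived from patch_attention_mask.

import torch

# Standalone illustration (not the vLLM class itself) of the position-id
# bucketing that follows the changed lines: given the real patch grid of one
# image, compute the ids used to index the learned position embedding.
def bucket_position_ids(nb_patches_h: int, nb_patches_w: int,
                        num_patches_per_side: int) -> torch.Tensor:
    boundaries = torch.arange(1 / num_patches_per_side, 1.0,
                              1 / num_patches_per_side)
    fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
    fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
    bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries,
                                      right=True)
    bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries,
                                      right=True)
    # One flattened position id per real (non-padded) patch.
    return (bucket_coords_h[:, None] * num_patches_per_side +
            bucket_coords_w).flatten()

# With tgt_sizes available, (nb_patches_h, nb_patches_w) come straight from
# tgt_sizes[batch_idx]; otherwise they are recovered by summing the mask.
print(bucket_position_ids(3, 5, num_patches_per_side=32))  # 15 position ids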

vllm/model_executor/models/minicpmv.py

Lines changed: 34 additions & 67 deletions
@@ -31,17 +31,15 @@
 import torch.types
 from PIL import Image
 from torch import nn
-from torch.nn.init import trunc_normal_
 from transformers import PretrainedConfig
 from typing_extensions import NotRequired
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
-from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.resampler import (Resampler2,
+from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
                                                   get_2d_sincos_pos_embed)
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -106,58 +104,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
 DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
 
 
-class BaseResampler(nn.Module):
-    """
-    A 2D perceiver-resampler network with one cross attention layers by
-        (grid_size**2) learnable queries and 2d sincos pos_emb
-    Outputs:
-        A tensor with the shape of (grid_size**2, embed_dim)
-    """
-
-    def __init__(
-        self,
-        num_queries: int,
-        embed_dim: int,
-        num_heads: int,
-        kv_dim: Optional[int] = None,
-        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
-    ) -> None:
-        super().__init__()
-
-        self.num_queries = num_queries
-        self.embed_dim = embed_dim
-        self.num_heads = num_heads
-
-        self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
-        trunc_normal_(self.query, std=0.02)
-        if kv_dim is not None and kv_dim != embed_dim:
-            self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
-        else:
-            # Maintain the same return value with ReplicatedLinear.forward
-            self.kv_proj = lambda *args, **kwargs: (
-                nn.Identity()(*args, **kwargs),
-                None,
-            )
-        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
-        self.ln_q = norm_layer(embed_dim)
-        self.ln_kv = norm_layer(embed_dim)
-        self.ln_post = norm_layer(embed_dim)
-        self.proj = nn.Parameter(
-            (embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))
-
-    def _init_weights(self, m: nn.Module) -> None:
-        if isinstance(m, nn.Linear):
-            trunc_normal_(m.weight, std=0.02)
-            if isinstance(m, nn.Linear) and m.bias is not None:
-                nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.constant_(m.bias, 0)
-            nn.init.constant_(m.weight, 1.0)
-
-    def _repeat(self, query, N: int):
-        return query.unsqueeze(1).repeat(1, N, 1)
-
-
 class Resampler2_5(BaseResampler):
 
     def __init__(
@@ -869,7 +815,35 @@ def is_default_weight_loading(self, name: str) -> bool:
         return "resampler" in name
 
 
-class MiniCPMV2_6(MiniCPMVBaseModel):
+class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        # vision encoder
+        "fc1",
+        "fc2",
+        "out_proj",
+        # language model
+        "qkv_proj",  # same name with vision encoder
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+        # resampler
+        "kv_proj",
+    ]
+
+    embedding_modules = {}
+    embedding_padding_modules = []
 
     def __init__(
         self,
@@ -894,15 +868,8 @@ def init_llm(
                          name="model")
 
     def init_vision_module(self) -> nn.Module:
-        # A custom version of SiglipVisionTransformer, won't work with TP
-        from vllm.model_executor.models.na_vit import SiglipVisionTransformer
 
-        if self.config._attn_implementation == "flash_attention_2":
-            self.config.vision_config._attn_implementation = "flash_attention_2"
-        else:
-            # not support sdpa
-            self.config.vision_config._attn_implementation = "eager"
-        model = SiglipVisionTransformer(self.config.vision_config)
+        model = Idefics2VisionTransformer(self.config.vision_config)
         if self.config.drop_vision_last_layer:
             model.encoder.layers = model.encoder.layers[:-1]
         return model
@@ -928,7 +895,7 @@ def get_vision_embedding(
             pixel_values,
             patch_attention_mask=patch_attn_mask,
             tgt_sizes=tgt_sizes,
-        ).last_hidden_state
+        )
         return vision_embedding
 
     def get_vision_hidden_states(
@@ -960,12 +927,12 @@ def get_vision_hidden_states(
             all_pixel_values.type(dtype),
             patch_attention_mask=patch_attn_mask,
             tgt_sizes=tgt_sizes,
-        ).last_hidden_state
+        )
 
         return self.resampler(vision_embedding, tgt_sizes)
 
     def is_default_weight_loading(self, name: str) -> bool:
-        return "resampler" in name or "vpm" in name
+        return "resampler" in name
 
 
 _SUPPORT_VERSION = {
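
The new packed_modules_mapping and supported_lora_modules attributes follow the convention vLLM's other LoRA-enabled models use: adapter weights trained against the unfused Hugging Face module names are regrouped under the fused vLLM modules, and anything outside the listed modules is rejected. The snippet below is a minimal illustration of that convention, not vLLM's actual LoRA loader; resolve_lora_target is a hypothetical helper.

from typing import Optional

# Minimal illustration of the convention above, not vLLM internals.
PACKED_MODULES_MAPPING = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}
SUPPORTED_LORA_MODULES = [
    "fc1", "fc2", "out_proj",      # vision encoder
    "qkv_proj", "o_proj",          # language model attention
    "gate_up_proj", "down_proj",   # language model MLP
    "kv_proj",                     # resampler
]

def resolve_lora_target(hf_module_name: str) -> Optional[str]:
    """Map an adapter's target module onto the fused module vLLM exposes,
    or return None if that module cannot carry LoRA weights here."""
    for fused, parts in PACKED_MODULES_MAPPING.items():
        if hf_module_name in parts:
            return fused
    return hf_module_name if hf_module_name in SUPPORTED_LORA_MODULES else None

assert resolve_lora_target("q_proj") == "qkv_proj"      # packed into qkv_proj
assert resolve_lora_target("down_proj") == "down_proj"  # already a vLLM module
assert resolve_lora_target("embed_tokens") is None      # not supported here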
