-from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
-                    Union)
+from typing import (Callable, Iterable, List, Mapping, Optional, Set, Tuple,
+                    TypedDict, Union)
 
 import torch
 import torch.nn as nn
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_rank
-from vllm.inputs import InputContext
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -87,8 +86,8 @@ def __init__(
     def forward(
         self,
         pixel_values: torch.Tensor,
-        pixel_mask: Optional[torch.BoolTensor] = None,
-    ) -> Tuple[torch.Tensor, Optional[torch.BoolTensor]]:
+        pixel_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
 
         vit_oup = self.vision_model(
@@ -100,7 +99,8 @@ def forward(
 
         return vit_oup, image_atts
 
-    def _create_patch_attention_mask(self, pixel_mask):
+    def _create_patch_attention_mask(
+            self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor:
         if pixel_mask is None:
             return None
 
@@ -115,7 +115,8 @@ def _create_patch_attention_mask(self, pixel_mask):
         )
         return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
 
-    def _create_image_attention_mask(self, patch_attention_mask):
+    def _create_image_attention_mask(
+            self, patch_attention_mask: torch.Tensor) -> torch.Tensor:
         if patch_attention_mask is None:
             return None
 
@@ -125,13 +126,13 @@ def _create_image_attention_mask(self, patch_attention_mask):
 
 class FFN(nn.Module):
 
-    def __init__(self, embed_dim, ff_dim, output_dim):
+    def __init__(self, embed_dim: int, ff_dim: int, output_dim: int) -> None:
         super().__init__()
         self.linear_in = ColumnParallelLinear(embed_dim, ff_dim, bias=False)
         self.linear_out = RowParallelLinear(ff_dim, output_dim, bias=False)
         self.act = get_act_fn("gelu_new")
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.linear_in(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states, _ = self.linear_out(hidden_states)
@@ -140,7 +141,7 @@ def forward(self, hidden_states):
 
 class CrossAttention(nn.Module):
 
-    def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0):
+    def __init__(self, kv_dim: int, embed_dim: int, num_heads: int) -> None:
         super().__init__()
         self.num_heads = num_heads
         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
@@ -149,12 +150,16 @@ def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0):
 
         self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
         self.linear = nn.Linear(embed_dim, embed_dim)
-        self.dropout = nn.Dropout(drop_out_rate)
 
         self.layer_norm = nn.LayerNorm(embed_dim)
         self.ln_kv = nn.LayerNorm(kv_dim)
 
-    def forward(self, x, hidden_states, attn_mask=None, add_residual=False):
+    def forward(
+        self,
+        x: torch.Tensor,
+        hidden_states: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         normed_hidden_states = self.layer_norm(hidden_states)
         query = self.q_proj(normed_hidden_states).permute(1, 0, 2)
 
@@ -169,11 +174,7 @@ def forward(self, x, hidden_states, attn_mask=None, add_residual=False):
 
         attn_output = attn_output.permute(1, 0, 2)
 
-        if add_residual:
-            attn_output = hidden_states + self.dropout(
-                self.linear(attn_output))
-        else:
-            attn_output = self.dropout(self.linear(attn_output))
+        attn_output = self.linear(attn_output)
 
         return attn_output
 
@@ -201,14 +202,14 @@ class AriaProjector(nn.Module):
 
     def __init__(
         self,
-        patch_to_query_dict,
-        embed_dim,
-        num_heads,
-        kv_dim,
-        ff_dim,
-        output_dim,
-        norm_layer=nn.LayerNorm,
-    ):
+        patch_to_query_dict: dict[int, int],
+        embed_dim: int,
+        num_heads: int,
+        kv_dim: int,
+        ff_dim: int,
+        output_dim: int,
+        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
+    ) -> None:
         super().__init__()
         self.patch_to_query_dict = patch_to_query_dict
         self.embed_dim = embed_dim
@@ -224,7 +225,11 @@ def __init__(
         self.ln_ffn = norm_layer(embed_dim)
         self.ffn = FFN(embed_dim, ff_dim, output_dim)
 
-    def forward(self, x, attn_mask=None):
+    def forward(
+        self,
+        x: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         bs = x.shape[0]
         queries = self.query.unsqueeze(0).repeat(bs, 1, 1)
 
@@ -442,12 +447,17 @@ def build_mm_projector(config: PretrainedConfig):
     )
 
 
-def get_max_aria_image_tokens(ctx: InputContext):
-    hf_config = ctx.get_hf_config()
-    return max(hf_config.projector_patch_to_query_dict.values())
+class AriaMultiModalProcessor(BaseMultiModalProcessor):
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": None}
 
+    def _get_num_image_tokens(self) -> int:
+        hf_config = self.ctx.get_hf_config()
+        return max(hf_config.projector_patch_to_query_dict.values())
 
-class AriaMultiModalProcessor(BaseMultiModalProcessor):
+    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
+        return {"image": self._get_num_image_tokens()}
 
     def _get_mm_fields_config(
         self,
@@ -468,13 +478,13 @@ def _get_prompt_replacements(
         hf_config = self.ctx.get_hf_config()
         image_token_id = hf_config.image_token_index
 
-        max_image_tokens = get_max_aria_image_tokens(self.ctx)
+        num_image_tokens = self._get_num_image_tokens()
 
         return [
             PromptReplacement(
                 modality="image",
                 target=[image_token_id],
-                replacement=[image_token_id] * max_image_tokens,
+                replacement=[image_token_id] * num_image_tokens,
             )
         ]
 
@@ -504,7 +514,6 @@ def _get_dummy_mm_inputs(
         )
 
 
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_aria_image_tokens)
 @MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor)
 class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
     """