Commit d4d77ca

ywang96, imkero, and DarkLight1337 authored and committed

[V1] Add V1 support of Qwen2-VL (vllm-project#12128)

Signed-off-by: Roger Wang <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Co-authored-by: imkero <[email protected]>
Co-authored-by: DarkLight1337 <[email protected]>

1 parent: d298226

9 files changed (+292, -85 lines)

docs/source/models/supported_models.md

Lines changed: 1 addition & 1 deletion

@@ -754,7 +754,7 @@ See [this page](#generative-models) for more information on how to use generative
   - `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc.
   - ✅︎
   - ✅︎
-  -
+  - ✅︎
 * - `UltravoxModel`
   - Ultravox
   - T + A<sup>E+</sup>

tests/models/decoder_only/vision_language/test_qwen2_vl.py

Lines changed: 8 additions & 10 deletions

@@ -105,7 +105,7 @@ def batch_make_image_embeddings(
     pixel_values = preprocess_result["pixel_values"]
     image_grid_thw = preprocess_result["image_grid_thw"]
 
-    # pixel values to embeddinds & grid_thws
+    # pixel values to embeddings & grid_thws
     with torch.no_grad():
         visual = llm.llm_engine.model_executor.driver_worker. \
             model_runner.model.visual
@@ -124,11 +124,10 @@ def batch_make_image_embeddings(
     for image_batch in image_batches_:
         cur_batch_image_count = len(image_batch)
         merge_size = image_processor.merge_size
-        cur_batch_embed_len = sum([
-            grid_thw.prod() // merge_size // merge_size
+        cur_batch_embed_len = sum(
+            grid_thw.prod(-1) // merge_size // merge_size
             for grid_thw in image_grid_thw[image_counter:image_counter +
-                                           cur_batch_image_count]
-        ])
+                                           cur_batch_image_count])
 
         result.append({
             "image_embeds":
@@ -187,7 +186,7 @@ def batch_make_video_embeddings(
     pixel_values = preprocess_result["pixel_values_videos"]
     video_grid_thw = preprocess_result["video_grid_thw"]
 
-    # pixel values to embeddinds & grid_thws
+    # pixel values to embeddings & grid_thws
     with torch.no_grad():
         visual = llm.llm_engine.model_executor.driver_worker.\
             model_runner.model.visual
@@ -206,11 +205,10 @@ def batch_make_video_embeddings(
     for video_batch in video_batches_:
         cur_batch_video_count = len(video_batch)
         merge_size = image_processor.merge_size
-        cur_batch_embed_len = sum([
-            grid_thw.prod() // merge_size // merge_size
+        cur_batch_embed_len = sum(
+            grid_thw.prod(-1) // merge_size // merge_size
             for grid_thw in video_grid_thw[video_counter:video_counter +
-                                           cur_batch_video_count]
-        ])
+                                           cur_batch_video_count])
 
         result.append({
             "video_embeds":

vllm/compilation/decorators.py

Lines changed: 12 additions & 2 deletions

@@ -76,8 +76,8 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
     During runtime, when we actually mark dimensions of tensors,
     it depends on the value of arguments:
 
-    - if it is a single integer, the corresponding dimension of the argument
-      will be marked as dynamic.
+    - if it is a single integer (can be negative), the corresponding dimension
+      of the argument will be marked as dynamic.
     - if it is `None`, ignored.
     - if it is `IntermediateTensors`, all the tensors in the intermediate
       tensors will be marked as dynamic.
@@ -177,10 +177,20 @@ def __call__(self, *args, **kwargs):
             for k, dims in dynamic_arg_dims.items():
                 arg = bound_args.arguments.get(k)
                 if arg is not None:
+                    dims = [dims] if isinstance(dims, int) else dims
                     if isinstance(arg, torch.Tensor):
+                        # In case dims is specified with negative indexing
+                        dims = [
+                            arg.ndim + dim if dim < 0 else dim for dim in dims
+                        ]
                         torch._dynamo.mark_dynamic(arg, dims)
                     elif isinstance(arg, IntermediateTensors):
                         for tensor in arg.tensors.values():
+                            # In case dims is specified with negative indexing
+                            dims = [
+                                tensor.ndim + dim if dim < 0 else dim
+                                for dim in dims
+                            ]
                             torch._dynamo.mark_dynamic(tensor, dims)
                     else:
                         raise ValueError(
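The new lines normalize `dims` before handing them to `torch._dynamo.mark_dynamic`: a bare int is wrapped in a list and negative indices are rewritten relative to the tensor being marked. A standalone sketch of that normalization (not the vLLM code itself):

```python
import torch

def normalize_dims(tensor: torch.Tensor, dims) -> list:
    # A single int becomes a one-element list, matching the decorator change.
    dims = [dims] if isinstance(dims, int) else list(dims)
    # Negative indices are converted relative to the tensor's rank.
    return [tensor.ndim + d if d < 0 else d for d in dims]

x = torch.zeros(4, 16, 64)
print(normalize_dims(x, -1))       # [2]
print(normalize_dims(x, [0, -2]))  # [0, 1]

# The normalized indices are what ends up in torch._dynamo.mark_dynamic.
torch._dynamo.mark_dynamic(x, normalize_dims(x, [0, -2]))
```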

vllm/model_executor/layers/rotary_embedding.py

Lines changed: 43 additions & 1 deletion

@@ -841,6 +841,37 @@ def get_input_positions(
     ) -> Tuple[List[List[int]], int]:
         """Get mrope input positions and delta value."""
 
+        llm_positions, mrope_position_delta = \
+            MRotaryEmbedding.get_input_positions_tensor(
+                input_tokens,
+                image_grid_thw,
+                video_grid_thw,
+                image_token_id,
+                video_token_id,
+                vision_start_token_id,
+                vision_end_token_id,
+                spatial_merge_size,
+                context_len,
+                seq_len,
+            )
+
+        return llm_positions.tolist(), mrope_position_delta
+
+    @staticmethod
+    def get_input_positions_tensor(
+        input_tokens: List[int],
+        image_grid_thw: Union[List[List[int]], torch.Tensor],
+        video_grid_thw: Union[List[List[int]], torch.Tensor],
+        image_token_id: int,
+        video_token_id: int,
+        vision_start_token_id: int,
+        vision_end_token_id: int,
+        spatial_merge_size: int,
+        context_len: int = 0,
+        seq_len: Optional[int] = None,
+    ) -> Tuple[torch.Tensor, int]:
+        """Get mrope input positions and delta value."""
+
         if isinstance(image_grid_thw, torch.Tensor):
             image_grid_thw = image_grid_thw.tolist()
         if isinstance(video_grid_thw, torch.Tensor):
@@ -916,7 +947,7 @@ def get_input_positions(
                                 len(input_tokens)).item()
         llm_positions = llm_positions[:, context_len:seq_len]
 
-        return llm_positions.tolist(), mrope_position_delta
+        return llm_positions, mrope_position_delta
 
     @staticmethod
     def get_next_input_positions(
@@ -930,6 +961,17 @@ def get_next_input_positions(
                       seq_len + mrope_position_delta)) for _ in range(3)
         ]
 
+    @staticmethod
+    def get_next_input_positions_tensor(
+        mrope_position_delta: int,
+        context_len: int,
+        seq_len: int,
+    ) -> torch.Tensor:
+        return torch.arange(
+            mrope_position_delta + context_len,
+            mrope_position_delta + seq_len,
+        ).expand(3, -1)
+
 
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
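`get_next_input_positions_tensor` produces the same values as the list-based `get_next_input_positions`, just as a `(3, num_new_tokens)` tensor. A quick standalone equivalence check with illustrative numbers:

```python
import torch

def next_positions_list(delta: int, context_len: int, seq_len: int):
    # List form: three identical ranges, one per mrope section (t, h, w).
    return [
        list(range(context_len + delta, seq_len + delta)) for _ in range(3)
    ]

def next_positions_tensor(delta: int, context_len: int, seq_len: int):
    # Tensor form added here: same values, shape (3, seq_len - context_len).
    return torch.arange(delta + context_len, delta + seq_len).expand(3, -1)

print(next_positions_list(5, 10, 13))    # [[15, 16, 17], [15, 16, 17], [15, 16, 17]]
print(next_positions_tensor(5, 10, 13))  # tensor([[15, 16, 17], [15, 16, 17], [15, 16, 17]])
```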

vllm/model_executor/models/llava_onevision.py

Lines changed: 4 additions & 2 deletions

@@ -554,10 +554,12 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
         # Preserve the order of modalities if there are multiple of them
         # from the order of kwargs.
         for input_key in kwargs:
-            if input_key == "pixel_values" and "images" not in modalities:
+            if input_key in ("pixel_values",
+                             "image_embeds") and "images" not in modalities:
                 modalities["images"] = self._parse_and_validate_image_input(
                     **kwargs)
-            if input_key == "pixel_values_videos" and "videos" not in modalities:  # noqa E501
+            if input_key in ("pixel_values_videos",
+                             "video_embeds") and "videos" not in modalities:
                 modalities["videos"] = self._parse_and_validate_video_input(
                     **kwargs)
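This change lets precomputed `image_embeds`/`video_embeds` inputs select a modality exactly like raw `pixel_values`/`pixel_values_videos`, while still preserving the order in which the inputs were passed. A minimal sketch of that dispatch using a hypothetical helper (not the model code):

```python
def detect_modalities(**kwargs):
    modalities = {}
    # kwargs order decides the order in which modalities are recorded.
    for input_key in kwargs:
        if input_key in ("pixel_values",
                         "image_embeds") and "images" not in modalities:
            modalities["images"] = kwargs[input_key]
        if input_key in ("pixel_values_videos",
                         "video_embeds") and "videos" not in modalities:
            modalities["videos"] = kwargs[input_key]
    return modalities

print(list(detect_modalities(video_embeds="v", image_embeds="i")))
# ['videos', 'images'] -- order follows the keyword arguments
```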

vllm/model_executor/models/qwen2.py

Lines changed: 9 additions & 1 deletion

@@ -256,7 +256,15 @@ def forward(
         return hidden_states, residual
 
 
-@support_torch_compile
+@support_torch_compile(
+    dynamic_arg_dims={
+        "input_ids": 0,
+        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
+        # otherwise (seq_len, ).
+        "positions": -1,
+        "intermediate_tensors": 0,
+        "inputs_embeds": 0,
+    })
 class Qwen2Model(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
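The inline comment explains why `positions` is registered with `-1` rather than a fixed axis: with mrope (Qwen2-VL) the positions tensor gains a leading dimension of 3, while plain RoPE keeps it 1-D. A small PyTorch-only illustration of the two layouts (not vLLM code):

```python
import torch

rope_positions = torch.zeros(128, dtype=torch.long)      # (seq_len,)
mrope_positions = torch.zeros(3, 128, dtype=torch.long)  # (3, seq_len)

# dim -1 resolves to the last axis, which is seq_len in both layouts,
# so one dynamic_arg_dims entry covers Qwen2 with and without mrope.
print(rope_positions.ndim - 1)   # 0
print(mrope_positions.ndim - 1)  # 1
```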
