
[V1] Add V1 support of Qwen2-VL #12128

Merged (25 commits) on Jan 19, 2025

docs/source/models/supported_models.md (2 changes: 1 addition & 1 deletion)
@@ -754,7 +754,7 @@ See [this page](#generative-models) for more information on how to use generativ
- `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc.
- ✅︎
- ✅︎
- -
+ - ✅︎
* - `UltravoxModel`
- Ultravox
- T + A<sup>E+</sup>
tests/models/decoder_only/vision_language/test_qwen2_vl.py (18 changes: 8 additions & 10 deletions)
@@ -105,7 +105,7 @@ def batch_make_image_embeddings(
pixel_values = preprocess_result["pixel_values"]
image_grid_thw = preprocess_result["image_grid_thw"]

- # pixel values to embeddinds & grid_thws
+ # pixel values to embeddings & grid_thws
with torch.no_grad():
visual = llm.llm_engine.model_executor.driver_worker. \
model_runner.model.visual
@@ -124,11 +124,10 @@ def batch_make_image_embeddings(
for image_batch in image_batches_:
cur_batch_image_count = len(image_batch)
merge_size = image_processor.merge_size
- cur_batch_embed_len = sum([
-     grid_thw.prod() // merge_size // merge_size
+ cur_batch_embed_len = sum(
+     grid_thw.prod(-1) // merge_size // merge_size
for grid_thw in image_grid_thw[image_counter:image_counter +
-     cur_batch_image_count]
- ])
+     cur_batch_image_count])

result.append({
"image_embeds":
@@ -187,7 +186,7 @@ def batch_make_video_embeddings(
pixel_values = preprocess_result["pixel_values_videos"]
video_grid_thw = preprocess_result["video_grid_thw"]

- # pixel values to embeddinds & grid_thws
+ # pixel values to embeddings & grid_thws
with torch.no_grad():
visual = llm.llm_engine.model_executor.driver_worker.\
model_runner.model.visual
@@ -206,11 +205,10 @@ def batch_make_video_embeddings(
for video_batch in video_batches_:
cur_batch_video_count = len(video_batch)
merge_size = image_processor.merge_size
- cur_batch_embed_len = sum([
-     grid_thw.prod() // merge_size // merge_size
+ cur_batch_embed_len = sum(
+     grid_thw.prod(-1) // merge_size // merge_size
for grid_thw in video_grid_thw[video_counter:video_counter +
-     cur_batch_video_count]
- ])
+     cur_batch_video_count])

result.append({
"video_embeds":
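For reference on the two hunks above: the expression computes `t*h*w // merge_size**2` embedding vectors per grid, and `prod(-1)` reduces only the last axis. A minimal sketch (illustrative helper name and grid sizes, not part of the PR):

```python
import torch

# Expected number of embeddings for a (t, h, w) patch grid, given the
# processor's spatial merge size. prod(-1) multiplies along the last axis,
# so the same call also works row-wise on a batched (num_images, 3) tensor.
def expected_embed_len(grid_thw: torch.Tensor, merge_size: int) -> torch.Tensor:
    return grid_thw.prod(-1) // merge_size // merge_size

print(expected_embed_len(torch.tensor([1, 98, 146]), merge_size=2))   # tensor(3577)
print(expected_embed_len(torch.tensor([[1, 98, 146],
                                       [1, 36, 64]]), merge_size=2))  # tensor([3577, 576])
```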
vllm/compilation/decorators.py (14 changes: 12 additions & 2 deletions)
@@ -76,8 +76,8 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
During runtime, when we actually mark dimensions of tensors,
it depends on the value of arguments:

- - if it is a single integer, the corresponding dimension of the argument
- will be marked as dynamic.
+ - if it is a single integer (can be negative), the corresponding dimension
+ of the argument will be marked as dynamic.
- if it is `None`, ignored.
- if it is `IntermediateTensors`, all the tensors in the intermediate
tensors will be marked as dynamic.
@@ -177,10 +177,20 @@ def __call__(self, *args, **kwargs):
for k, dims in dynamic_arg_dims.items():
arg = bound_args.arguments.get(k)
if arg is not None:
+ dims = [dims] if isinstance(dims, int) else dims
if isinstance(arg, torch.Tensor):
+ # In case dims is specified with negative indexing
+ dims = [
+     arg.ndim + dim if dim < 0 else dim for dim in dims
+ ]
torch._dynamo.mark_dynamic(arg, dims)
elif isinstance(arg, IntermediateTensors):
for tensor in arg.tensors.values():
+ # In case dims is specified with negative indexing
+ dims = [
+     tensor.ndim + dim if dim < 0 else dim
+     for dim in dims
+ ]
torch._dynamo.mark_dynamic(tensor, dims)
else:
raise ValueError(
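To make the negative-indexing support above concrete: a dim of -1 always resolves to the last axis, which is what lets a single spec cover both the `(seq_len,)` and `(3, seq_len)` layouts of `positions`. A standalone sketch of the normalization (hypothetical helper, not the vLLM implementation):

```python
import torch

# Resolve possibly-negative dynamic-dim indices against a concrete tensor,
# mirroring the normalization added in this diff.
def resolve_dims(tensor: torch.Tensor, dims) -> list:
    dims = [dims] if isinstance(dims, int) else dims
    return [tensor.ndim + d if d < 0 else d for d in dims]

positions_plain = torch.zeros(128)     # (seq_len,) for text-only models
positions_mrope = torch.zeros(3, 128)  # (3, seq_len) when mrope is enabled

print(resolve_dims(positions_plain, -1))  # [0] -> the seq_len axis
print(resolve_dims(positions_mrope, -1))  # [1] -> still the seq_len axis
```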
vllm/model_executor/layers/rotary_embedding.py (44 changes: 43 additions & 1 deletion)
@@ -841,6 +841,37 @@ def get_input_positions(
) -> Tuple[List[List[int]], int]:
"""Get mrope input positions and delta value."""

+ llm_positions, mrope_position_delta = \
+     MRotaryEmbedding.get_input_positions_tensor(
+         input_tokens,
+         image_grid_thw,
+         video_grid_thw,
+         image_token_id,
+         video_token_id,
+         vision_start_token_id,
+         vision_end_token_id,
+         spatial_merge_size,
+         context_len,
+         seq_len,
+     )
+
+ return llm_positions.tolist(), mrope_position_delta
+
+ @staticmethod
+ def get_input_positions_tensor(
+     input_tokens: List[int],
+     image_grid_thw: Union[List[List[int]], torch.Tensor],
+     video_grid_thw: Union[List[List[int]], torch.Tensor],
+     image_token_id: int,
+     video_token_id: int,
+     vision_start_token_id: int,
+     vision_end_token_id: int,
+     spatial_merge_size: int,
+     context_len: int = 0,
+     seq_len: Optional[int] = None,
+ ) -> Tuple[torch.Tensor, int]:
+     """Get mrope input positions and delta value."""
+
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
if isinstance(video_grid_thw, torch.Tensor):
@@ -916,7 +947,7 @@ def get_input_positions(
len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]

- return llm_positions.tolist(), mrope_position_delta
+ return llm_positions, mrope_position_delta

@staticmethod
def get_next_input_positions(
@@ -930,6 +961,17 @@ def get_next_input_positions(
seq_len + mrope_position_delta)) for _ in range(3)
]

+ @staticmethod
+ def get_next_input_positions_tensor(
+     mrope_position_delta: int,
+     context_len: int,
+     seq_len: int,
+ ) -> torch.Tensor:
+     return torch.arange(
+         mrope_position_delta + context_len,
+         mrope_position_delta + seq_len,
+     ).expand(3, -1)
+

_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}

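A quick check of the new tensor helper (values chosen purely for illustration): it produces the same positions as the list-based `get_next_input_positions`, just as a single `(3, num_new_tokens)` tensor shared across the three mrope sections.

```python
import torch

mrope_position_delta, context_len, seq_len = -5, 10, 14

# Same arithmetic as get_next_input_positions_tensor in the diff above.
positions = torch.arange(
    mrope_position_delta + context_len,
    mrope_position_delta + seq_len,
).expand(3, -1)

print(positions.shape)  # torch.Size([3, 4])
print(positions[0])     # tensor([5, 6, 7, 8]), repeated for each of the 3 rows
```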
vllm/model_executor/models/llava_onevision.py (6 changes: 4 additions & 2 deletions)
@@ -554,10 +554,12 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for input_key in kwargs:
if input_key == "pixel_values" and "images" not in modalities:
if input_key in ("pixel_values",
"image_embeds") and "images" not in modalities:
modalities["images"] = self._parse_and_validate_image_input(
**kwargs)
if input_key == "pixel_values_videos" and "videos" not in modalities: # noqa E501
if input_key in ("pixel_values_videos",
"video_embeds") and "videos" not in modalities:
modalities["videos"] = self._parse_and_validate_video_input(
**kwargs)

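The dispatch above keys off kwarg names in their original order, and with this change precomputed embeddings select a modality just like raw pixel inputs. A standalone sketch of the same logic (hypothetical function, not the model code):

```python
# Decide which modalities are present, preserving kwarg order; both raw
# pixel inputs and precomputed embeddings now count toward a modality.
def detect_modalities(**kwargs):
    modalities = {}
    for input_key in kwargs:
        if input_key in ("pixel_values", "image_embeds") and "images" not in modalities:
            modalities["images"] = kwargs[input_key]
        if input_key in ("pixel_values_videos", "video_embeds") and "videos" not in modalities:
            modalities["videos"] = kwargs[input_key]
    return modalities

print(list(detect_modalities(video_embeds="V", image_embeds="I")))  # ['videos', 'images']
```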
vllm/model_executor/models/qwen2.py (10 changes: 9 additions & 1 deletion)
@@ -256,7 +256,15 @@ def forward(
return hidden_states, residual


- @support_torch_compile
+ @support_torch_compile(
+     dynamic_arg_dims={
+         "input_ids": 0,
+         # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
+         # otherwise (seq_len, ).
+         "positions": -1,
+         "intermediate_tensors": 0,
+         "inputs_embeds": 0,
+     })
class Qwen2Model(nn.Module):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
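For context on the decorator arguments above: `support_torch_compile` pairs each entry of `dynamic_arg_dims` with the bound arguments of `forward` and skips entries that are absent or `None`. A toy sketch of that pairing (argument names reused from the diff, binding logic simplified):

```python
import inspect

import torch

dynamic_arg_dims = {
    "input_ids": 0,
    "positions": -1,  # last axis: (seq_len,) or (3, seq_len) under mrope
    "intermediate_tensors": 0,
    "inputs_embeds": 0,
}

def forward(input_ids, positions, intermediate_tensors=None, inputs_embeds=None):
    ...

bound = inspect.signature(forward).bind(
    torch.zeros(128, dtype=torch.long),  # input_ids
    torch.zeros(3, 128),                 # positions with mrope enabled
)
for name, dims in dynamic_arg_dims.items():
    arg = bound.arguments.get(name)
    if arg is not None:
        print(f"mark {name}: dim {dims} of shape {tuple(arg.shape)}")
# mark input_ids: dim 0 of shape (128,)
# mark positions: dim -1 of shape (3, 128)
```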