
[V1][Model] Add V1 support for Qwen2-VL #11668

Closed · wants to merge 4 commits
19 changes: 14 additions & 5 deletions vllm/compilation/decorators.py
@@ -18,12 +18,15 @@
logger = init_logger(__name__)

_T = TypeVar("_T", bound=type[nn.Module])
DimIndexes = Union[int, List[int]]
DimIndexesSelector = Callable[[torch.Tensor], DimIndexes]


@overload
def support_torch_compile(
*,
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]],
dynamic_arg_dims: Optional[Dict[str, Union[DimIndexes,
DimIndexesSelector]]],
) -> Callable[[_T], _T]:
...

@@ -36,7 +39,8 @@ def support_torch_compile(cls: _T) -> _T:
def support_torch_compile(
cls: Optional[_T] = None,
*,
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None,
dynamic_arg_dims: Optional[Dict[str, Union[DimIndexes,
DimIndexesSelector]]] = None,
) -> Union[Callable[[_T], _T], _T]:
"""
A decorator to add support for compiling the forward method of a class.
@@ -78,6 +82,9 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):

- if it is a single integer, the corresponding dimension of the argument
will be marked as dynamic.
- if it is a function that returns a single integer, it will be called
with the tensor as its argument, and the returned dimension will be
marked as dynamic.
- if it is `None`, ignored.
- if it is `IntermediateTensors`, all the tensors in the intermediate
tensors will be marked as dynamic.
@@ -129,7 +136,7 @@ def cls_decorator_helper(cls: _T) -> _T:

def _support_torch_compile(
cls: _T,
dynamic_arg_dims: Dict[str, Union[int, List[int]]],
dynamic_arg_dims: Dict[str, Union[DimIndexes, DimIndexesSelector]],
) -> _T:
"""
A decorator to add support for compiling the forward method of a class.
@@ -178,10 +185,12 @@ def __call__(self, *args, **kwargs):
arg = bound_args.arguments.get(k)
if arg is not None:
if isinstance(arg, torch.Tensor):
torch._dynamo.mark_dynamic(arg, dims)
dims_ = dims(arg) if callable(dims) else dims
torch._dynamo.mark_dynamic(arg, dims_)
elif isinstance(arg, IntermediateTensors):
for tensor in arg.tensors.values():
torch._dynamo.mark_dynamic(tensor, dims)
dims_ = dims(tensor) if callable(dims) else dims
torch._dynamo.mark_dynamic(tensor, dims_)
else:
raise ValueError(
"Unsupported dynamic dimensions"
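To illustrate the dispatch added in __call__ above (a minimal standalone sketch; the tensors and shapes are illustrative, not from this PR), a callable selector lets the dynamic dimension depend on the tensor's rank:

import torch

# Mirrors the new dispatch: a callable selector is invoked with the
# tensor, while a plain int (or list of ints) is used as-is.
dims = lambda t: t.ndim - 1  # the sequence dim is always the last dim

rope_positions = torch.zeros(128)      # shape (seq_len,)
mrope_positions = torch.zeros(3, 128)  # shape (3, seq_len)

print(dims(rope_positions) if callable(dims) else dims)   # 0
print(dims(mrope_positions) if callable(dims) else dims)  # 1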
44 changes: 43 additions & 1 deletion vllm/model_executor/layers/rotary_embedding.py
@@ -841,6 +841,37 @@ def get_input_positions(
) -> Tuple[List[List[int]], int]:
"""Get mrope input positions and delta value."""

llm_positions, mrope_position_delta = \
MRotaryEmbedding.get_input_positions_tensor(
input_tokens,
image_grid_thw,
video_grid_thw,
image_token_id,
video_token_id,
vision_start_token_id,
vision_end_token_id,
spatial_merge_size,
context_len,
seq_len,
)

return llm_positions.tolist(), mrope_position_delta

@staticmethod
def get_input_positions_tensor(
input_tokens: List[int],
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
image_token_id: int,
video_token_id: int,
vision_start_token_id: int,
vision_end_token_id: int,
spatial_merge_size: int,
context_len: int = 0,
seq_len: Optional[int] = None,
) -> Tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value."""

if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
if isinstance(video_grid_thw, torch.Tensor):
@@ -916,7 +947,7 @@ def get_input_positions(
len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]

return llm_positions.tolist(), mrope_position_delta
return llm_positions, mrope_position_delta

@staticmethod
def get_next_input_positions(
@@ -930,6 +961,17 @@ def get_next_input_positions(
seq_len + mrope_position_delta)) for _ in range(3)
]

@staticmethod
def get_next_input_positions_tensor(
mrope_position_delta: int,
context_len: int,
seq_len: int,
) -> torch.Tensor:
return torch.arange(
mrope_position_delta + context_len,
mrope_position_delta + seq_len,
).expand(3, -1)


_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}

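As a quick sanity check of the new tensor helper (a standalone sketch; the numbers are illustrative), the arange/expand combination yields one identical row of positions per mrope section:

import torch

def get_next_input_positions_tensor(mrope_position_delta: int,
                                    context_len: int,
                                    seq_len: int) -> torch.Tensor:
    # Same logic as the static method added above.
    return torch.arange(
        mrope_position_delta + context_len,
        mrope_position_delta + seq_len,
    ).expand(3, -1)

print(get_next_input_positions_tensor(5, 2, 6))
# tensor([[ 7,  8,  9, 10],
#         [ 7,  8,  9, 10],
#         [ 7,  8,  9, 10]])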
9 changes: 8 additions & 1 deletion vllm/model_executor/models/qwen2.py
@@ -256,7 +256,14 @@ def forward(
return hidden_states, residual


@support_torch_compile
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
# dim 1 for mrope in shape (3, seq_len), else dim 0 in shape (seq_len, )
"positions": lambda tensor: tensor.ndim - 1,
Review comment (Member): does -1 work here?

Reply (Contributor Author): The value here is passed through to PyTorch's implementation, torch._dynamo.mark_dynamic(tensor, dim), which seems to assume that dim is a non-negative integer.

https://github.com/pytorch/pytorch/blob/95b41d2aa43c606d65e127d4825c08baf9fcacd9/torch/_dynamo/decorators.py#L464

Reply (Member): You can do the conversion here, in the loop
for k, dims in dynamic_arg_dims.items():
Iterate over the dims and convert -1 to tensor.ndim - 1 (a sketch of this conversion follows this file's diff).

"intermediate_tensors": 0,
"inputs_embeds": 0,
})
class Qwen2Model(nn.Module):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
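A sketch of the reviewer's suggested conversion (resolve_dims is a hypothetical helper, and its placement inside __call__ in decorators.py is assumed; it is not part of this diff), normalizing negative dim indexes before they reach mark_dynamic:

import torch

def resolve_dims(arg: torch.Tensor, dims) -> list:
    # Hypothetical helper: invoke a selector if one is given, then map
    # negative indexes (e.g. -1) to non-negative ones, since
    # torch._dynamo.mark_dynamic assumes a non-negative dim.
    dims_ = dims(arg) if callable(dims) else dims
    if isinstance(dims_, int):
        dims_ = [dims_]
    return [d if d >= 0 else arg.ndim + d for d in dims_]

positions = torch.zeros(3, 128)
print(resolve_dims(positions, -1))                    # [1]
print(resolve_dims(positions, lambda t: t.ndim - 1))  # [1]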
23 changes: 13 additions & 10 deletions vllm/model_executor/models/qwen2_vl.py
@@ -709,6 +709,7 @@ def _parse_video_data(
return super()._parse_video_data(data)



class Qwen2VLProcessingInfo(BaseProcessingInfo):

def get_hf_config(self):
@@ -935,6 +936,7 @@ def get_dummy_processor_inputs(

class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
):
_placeholder_map: Optional[dict[str, list[int]]] = None

def _get_data_parser(self) -> MultiModalDataParser:
return Qwen2MultiModalDataParser()
@@ -949,19 +951,23 @@ def _get_prompt_replacements(
image_processor = self.info.get_image_processor(
**hf_processor_mm_kwargs)

# NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has
# image_token and video_token registered
placeholder = {
"image": hf_processor.image_token,
"video": hf_processor.video_token,
}
if not self._placeholder_map:
# NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has
# image_token and video_token registered
encode_fn = hf_processor.tokenizer.encode
self._placeholder_map = {
"image": encode_fn(hf_processor.image_token),
"video": encode_fn(hf_processor.video_token),
}
placeholder = self._placeholder_map

Review comment (Member) on lines +954 to +963: Also, we can set this at initialization time. (A sketch follows this file's diff.)

merge_length = image_processor.merge_size**2

def get_replacement_qwen2vl(item_idx: int, modality: str):
grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx]
assert isinstance(grid_thw, torch.Tensor)

num_tokens = grid_thw.prod() // merge_length
num_tokens = grid_thw.prod().item() // merge_length
return placeholder[modality] * num_tokens

return [
@@ -1057,11 +1063,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: Qwen2VLConfig = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
assert not cache_config.enable_prefix_caching, \
"Qwen2-VL currently does not support prefix caching"

self.config = config
self.multimodal_config = multimodal_config
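A sketch of the reviewer's suggestion to build the map at initialization time (the __init__ override and its signature are assumptions, not part of this PR):

class Qwen2VLMultiModalProcessor(
        BaseMultiModalProcessor[Qwen2VLProcessingInfo]):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Encode the placeholder tokens once at construction instead of
        # lazily inside _get_prompt_replacements.
        hf_processor = self.info.get_hf_processor()
        encode_fn = hf_processor.tokenizer.encode
        self._placeholder_map = {
            "image": encode_fn(hf_processor.image_token),
            "video": encode_fn(hf_processor.video_token),
        }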
3 changes: 3 additions & 0 deletions vllm/v1/worker/gpu_input_batch.py
@@ -30,6 +30,9 @@ class CachedRequestState:
num_computed_tokens: int
output_token_ids: List[int]

mrope_positions: Optional[torch.Tensor] = None
mrope_position_delta: Optional[int] = None

@property
def num_tokens(self) -> int:
return len(self.prompt_token_ids) + len(self.output_token_ids)
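Putting the pieces together (a hedged sketch; only the two helper signatures come from this diff, and the surrounding variables are assumed to be available in the V1 GPU worker), the new fields would be populated from the tensor-based mrope helpers:

# Prefill: compute mrope positions for the prompt once.
llm_positions, delta = MRotaryEmbedding.get_input_positions_tensor(
    input_tokens=req_state.prompt_token_ids,
    image_grid_thw=image_grid_thw,
    video_grid_thw=video_grid_thw,
    image_token_id=image_token_id,
    video_token_id=video_token_id,
    vision_start_token_id=vision_start_token_id,
    vision_end_token_id=vision_end_token_id,
    spatial_merge_size=spatial_merge_size,
)
req_state.mrope_positions = llm_positions
req_state.mrope_position_delta = delta

# Decode: extend positions for newly generated tokens.
next_positions = MRotaryEmbedding.get_next_input_positions_tensor(
    req_state.mrope_position_delta,
    context_len=req_state.num_computed_tokens,
    seq_len=req_state.num_tokens,
)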