From c08251f5d2472d9faca8eb3e7a580444bb490f39 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Thu, 16 Jan 2025 16:17:51 +0000 Subject: [PATCH 01/20] init Signed-off-by: Roger Wang --- vllm/compilation/decorators.py | 19 ++- .../model_executor/layers/rotary_embedding.py | 44 +++++- vllm/model_executor/models/qwen2.py | 32 +++-- vllm/model_executor/models/qwen2_vl.py | 104 +++++++++------ vllm/v1/worker/gpu_input_batch.py | 3 + vllm/v1/worker/gpu_model_runner.py | 126 +++++++++++++++++- 6 files changed, 261 insertions(+), 67 deletions(-) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 10513111ea7..7e59cf87b50 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -18,12 +18,15 @@ logger = init_logger(__name__) _T = TypeVar("_T", bound=type[nn.Module]) +DimIndexes = Union[int, List[int]] +DimIndexesSelector = Callable[[torch.Tensor], DimIndexes] @overload def support_torch_compile( *, - dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]], + dynamic_arg_dims: Optional[Dict[str, Union[DimIndexes, + DimIndexesSelector]]], ) -> Callable[[_T], _T]: ... @@ -36,7 +39,8 @@ def support_torch_compile(cls: _T) -> _T: def support_torch_compile( cls: Optional[_T] = None, *, - dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None, + dynamic_arg_dims: Optional[Dict[str, Union[DimIndexes, + DimIndexesSelector]]] = None, ) -> Union[Callable[[_T], _T], _T]: """ A decorator to add support for compiling the forward method of a class. @@ -78,6 +82,9 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): - if it is a single integer, the corresponding dimension of the argument will be marked as dynamic. + - if it is a function returns a single integer, it will be called with + the tensor as argument, and the returned dimension will be marked + as dynamic. - if it is `None`, ignored. - if it is `IntermediateTensors`, all the tensors in the intermediate tensors will be marked as dynamic. @@ -129,7 +136,7 @@ def cls_decorator_helper(cls: _T) -> _T: def _support_torch_compile( cls: _T, - dynamic_arg_dims: Dict[str, Union[int, List[int]]], + dynamic_arg_dims: Dict[str, Union[DimIndexes, DimIndexesSelector]], ) -> _T: """ A decorator to add support for compiling the forward method of a class. 
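(A rough standalone sketch, not part of the patch itself, of what the callable form of dynamic_arg_dims enables: the same selector can pick the sequence dimension whether positions is the usual (seq_len,) tensor or the (3, seq_len) M-RoPE tensor. resolve_dims and seq_dim below are made-up names for illustration; the only API assumed is torch._dynamo.mark_dynamic, which the decorator already uses.)

import torch

def resolve_dims(tensor: torch.Tensor, dims):
    # A dims entry may be an int, a list of ints, or a callable that
    # inspects the tensor and returns either of those.
    if callable(dims):
        dims = dims(tensor)
    return [dims] if isinstance(dims, int) else list(dims)

def seq_dim(t: torch.Tensor) -> int:
    # Last dim is the sequence dim in both layouts below.
    return t.ndim - 1

# positions is (seq_len,) for text-only models, (3, seq_len) under M-RoPE.
text_positions = torch.zeros(128, dtype=torch.int64)
mrope_positions = torch.zeros(3, 128, dtype=torch.int64)

for t in (text_positions, mrope_positions):
    torch._dynamo.mark_dynamic(t, resolve_dims(t, seq_dim))
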
@@ -178,10 +185,12 @@ def __call__(self, *args, **kwargs): arg = bound_args.arguments.get(k) if arg is not None: if isinstance(arg, torch.Tensor): - torch._dynamo.mark_dynamic(arg, dims) + dims_ = dims(arg) if callable(dims) else dims + torch._dynamo.mark_dynamic(arg, dims_) elif isinstance(arg, IntermediateTensors): for tensor in arg.tensors.values(): - torch._dynamo.mark_dynamic(tensor, dims) + dims_ = dims(tensor) if callable(dims) else dims + torch._dynamo.mark_dynamic(tensor, dims_) else: raise ValueError( "Unsupported dynamic dimensions" diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 3fcd81a3c42..d071cfe888f 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -841,6 +841,37 @@ def get_input_positions( ) -> Tuple[List[List[int]], int]: """Get mrope input positions and delta value.""" + llm_positions, mrope_position_delta = \ + MRotaryEmbedding.get_input_positions_tensor( + input_tokens, + image_grid_thw, + video_grid_thw, + image_token_id, + video_token_id, + vision_start_token_id, + vision_end_token_id, + spatial_merge_size, + context_len, + seq_len, + ) + + return llm_positions.tolist(), mrope_position_delta + + @staticmethod + def get_input_positions_tensor( + input_tokens: List[int], + image_grid_thw: Union[List[List[int]], torch.Tensor], + video_grid_thw: Union[List[List[int]], torch.Tensor], + image_token_id: int, + video_token_id: int, + vision_start_token_id: int, + vision_end_token_id: int, + spatial_merge_size: int, + context_len: int = 0, + seq_len: Optional[int] = None, + ) -> Tuple[torch.Tensor, int]: + """Get mrope input positions and delta value.""" + if isinstance(image_grid_thw, torch.Tensor): image_grid_thw = image_grid_thw.tolist() if isinstance(video_grid_thw, torch.Tensor): @@ -916,7 +947,7 @@ def get_input_positions( len(input_tokens)).item() llm_positions = llm_positions[:, context_len:seq_len] - return llm_positions.tolist(), mrope_position_delta + return llm_positions, mrope_position_delta @staticmethod def get_next_input_positions( @@ -930,6 +961,17 @@ def get_next_input_positions( seq_len + mrope_position_delta)) for _ in range(3) ] + @staticmethod + def get_next_input_positions_tensor( + mrope_position_delta: int, + context_len: int, + seq_len: int, + ) -> torch.Tensor: + return torch.arange( + mrope_position_delta + context_len, + mrope_position_delta + seq_len, + ).expand(3, -1) + _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index b9c259ad73c..dfbcf860373 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -256,7 +256,14 @@ def forward( return hidden_states, residual -@support_torch_compile +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + # dim 1 for mrope in shape (3, seq_len), else dim 0 in shape (seq_len, ) + "positions": lambda tensor: tensor.ndim - 1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }) class Qwen2Model(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -279,7 +286,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): )) self.config = config - self.quant_config = quant_config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -365,18 +371,6 @@ def load_weights(self, weights: Iterable[Tuple[str, for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if 
(self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): - # Loading kv cache scales for quark and - # compressed-tensors quantization - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue @@ -431,6 +425,16 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): embedding_modules = {} embedding_padding_modules = [] + # BitandBytes specific attributes + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index d00e5d362c8..9f444bb4215 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -67,7 +67,7 @@ from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, - init_vllm_registered_model, maybe_prefix) + init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) from .vision import get_vit_attn_backend logger = init_logger(__name__) @@ -709,6 +709,7 @@ def _parse_video_data( return super()._parse_video_data(data) + class Qwen2VLProcessingInfo(BaseProcessingInfo): def get_hf_config(self): @@ -935,6 +936,7 @@ def get_dummy_processor_inputs( class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] ): + _placeholder_map: Optional[dict[str, list[int]]] = None def _get_data_parser(self) -> MultiModalDataParser: return Qwen2MultiModalDataParser() @@ -949,19 +951,23 @@ def _get_prompt_replacements( image_processor = self.info.get_image_processor( **hf_processor_mm_kwargs) - # NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has - # image_token and video_token registered - placeholder = { - "image": hf_processor.image_token, - "video": hf_processor.video_token, - } + if not self._placeholder_map: + # NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has + # image_token and video_token registered + encode_fn = hf_processor.tokenizer.encode + self._placeholder_map = { + "image": encode_fn(hf_processor.image_token), + "video": encode_fn(hf_processor.video_token), + } + placeholder = self._placeholder_map + merge_length = image_processor.merge_size**2 def get_replacement_qwen2vl(item_idx: int, modality: str): grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] assert isinstance(grid_thw, torch.Tensor) - num_tokens = grid_thw.prod() // merge_length + num_tokens = grid_thw.prod().item() // merge_length return placeholder[modality] * num_tokens return [ @@ -1038,6 +1044,16 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, embedding_modules = {} embedding_padding_modules = [] + # BitandBytes specific attributes + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + # To ensure correct 
weight loading and mapping. hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ "lm_head.": "language_model.lm_head.", @@ -1047,11 +1063,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config: Qwen2VLConfig = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config - assert not cache_config.enable_prefix_caching, \ - "Qwen2-VL currently does not support prefix caching" self.config = config self.multimodal_config = multimodal_config @@ -1192,8 +1205,24 @@ def _process_video_input(self, self.visual.dtype) video_embeds = self.visual(pixel_values_videos, grid_thw=video_input["video_grid_thw"]) + print(video_embeds.shape) return video_embeds + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if input_key == "pixel_values" and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input( + **kwargs) + if input_key == "pixel_values_videos" and "videos" not in modalities: # noqa E501 + modalities["videos"] = self._parse_and_validate_video_input( + **kwargs) + + return modalities + def _merge_multimodal_embeddings( self, input_ids: torch.Tensor, @@ -1208,25 +1237,26 @@ def _merge_multimodal_embeddings( def get_multimodal_embeddings( self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]: - image_input = self._parse_and_validate_image_input(**kwargs) - video_input = self._parse_and_validate_video_input(**kwargs) - if image_input is None and video_input is None: + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: return None - # We make a tuple of each embedding with its modality string. This is a - # temporary workaround for models to handle mixed modalities when - # get_multimodal_embeddings and get_input_embeddings are called - # separately. - # TODO(ywang96): Add support for mixed-modality inference for v1. - multimodal_embeddings: List[Tuple[NestedTensors, str]] = [] - - if image_input is not None: - image_embeds = self._process_image_input(image_input) - multimodal_embeddings.append((image_embeds, "image")) - if video_input is not None: - video_embeds = self._process_video_input(video_input) - multimodal_embeddings.append((video_embeds, "video")) - + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. 
+ for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += tuple(vision_embeddings) + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_video_input(video_input) + multimodal_embeddings += tuple(video_embeddings) + #print(multimodal_embeddings) return multimodal_embeddings def get_input_embeddings( @@ -1237,21 +1267,9 @@ def get_input_embeddings( ) -> torch.Tensor: inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: - for embeddings, modality in multimodal_embeddings: - if modality == "image": - inputs_embeds = self._merge_multimodal_embeddings( - input_ids, - inputs_embeds, - embeddings, - placeholder_token_id=self.config.image_token_id, - ) - if modality == "video": - inputs_embeds = self._merge_multimodal_embeddings( - input_ids, - inputs_embeds, - embeddings, - placeholder_token_id=self.config.video_token_id, - ) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + [self.config.image_token_id, self.config.video_token_id]) return inputs_embeds def forward( diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 40494e64b22..28d8e390538 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -30,6 +30,9 @@ class CachedRequestState: num_computed_tokens: int output_token_ids: List[int] + mrope_positions: Optional[torch.Tensor] = None + mrope_position_delta: Optional[int] = None + @property def num_tokens(self) -> int: return len(self.prompt_token_ids) + len(self.output_token_ids) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index de83640b27c..c73f68acd35 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -12,6 +12,7 @@ from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger +from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.sampling_params import SamplingType @@ -135,6 +136,24 @@ def __init__( self.positions = torch.zeros(self.max_num_tokens, dtype=torch.int64, device=self.device) + + if self.model_config.uses_mrope: + # a permuted mrope_positions tensor satisfying the following + # properties to allow `torch.compile` work properly: + # - shape: (3, ) + # - stride: (1, ) + + self.mrope_positions = torch.zeros((self.max_num_tokens, 3), + dtype=torch.int64, + device=self.device) + self.mrope_positions_cpu = torch.zeros((self.max_num_tokens, 3), + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory) + + self.mrope_positions = self.mrope_positions.permute((1, 0)) + self.mrope_positions_cpu = self.mrope_positions_cpu.permute((1, 0)) + self.inputs_embeds = torch.zeros( (self.max_num_tokens, self.hidden_size), dtype=self.dtype, @@ -242,6 +261,34 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: num_computed_tokens=new_req_data.num_computed_tokens, output_token_ids=[], ) + + if self.model_config.uses_mrope: + image_grid_thw = [] + video_grid_thw = [] + for mm_input in self.requests[req_id].mm_inputs: + if mm_input.get("image_grid_thw") is not None: + image_grid_thw.extend( + mm_input["image_grid_thw"].tolist()) + if 
mm_input.get("video_grid_thw") is not None: + video_grid_thw.extend( + mm_input["video_grid_thw"].tolist()) + + hf_config = self.model_config.hf_config + + self.requests[req_id].mrope_positions, \ + self.requests[req_id].mrope_position_delta = \ + MRotaryEmbedding.get_input_positions_tensor( + self.requests[req_id].prompt_token_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + image_token_id=hf_config.image_token_id, + video_token_id=hf_config.video_token_id, + vision_start_token_id=hf_config.vision_start_token_id, + vision_end_token_id=hf_config.vision_end_token_id, + spatial_merge_size=hf_config.vision_config. + spatial_merge_size, + ) + req_ids_to_add.append(req_id) # Update the cached states of the resumed requests. @@ -309,6 +356,10 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): arange, out=positions_np) + # Calculate M-RoPE positions. + if self.model_config.uses_mrope: + self._calc_mrope_positions(scheduler_output) + # Get token indices. # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] @@ -355,8 +406,14 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # Copy the tensors to the GPU. self.input_ids[:total_num_scheduled_tokens].copy_( self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) - self.positions[:total_num_scheduled_tokens].copy_( - self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) + if self.model_config.uses_mrope: + self.mrope_positions[:, :total_num_scheduled_tokens].copy_( + self.mrope_positions_cpu[:, :total_num_scheduled_tokens], + non_blocking=True) + else: + self.positions[:total_num_scheduled_tokens].copy_( + self.positions_cpu[:total_num_scheduled_tokens], + non_blocking=True) query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to( self.device, non_blocking=True) seq_start_loc = self.seq_start_loc_cpu[:num_reqs + 1].to( @@ -468,6 +525,61 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): logits_indices = query_start_loc[1:] - 1 return attn_metadata, logits_indices + def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): + mrope_pos_ptr = 0 + num_reqs = self.input_batch.num_reqs + for index, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + assert req_id is not None + + req = self.requests[req_id] + assert req.mrope_positions is not None + + num_computed_tokens = \ + self.input_batch.num_computed_tokens_cpu[index] + num_scheduled_tokens = \ + scheduler_output.num_scheduled_tokens[req_id] + num_prompt_tokens = len(req.prompt_token_ids) + + if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens: + prompt_part_len = max(0, + num_prompt_tokens - num_computed_tokens) + completion_part_len = max( + 0, num_scheduled_tokens - prompt_part_len) + else: + prompt_part_len = num_scheduled_tokens + completion_part_len = 0 + + assert num_scheduled_tokens == prompt_part_len + completion_part_len + + if prompt_part_len > 0: + # prompt's mrope_positions are pre-computed + dst_start = mrope_pos_ptr + dst_end = mrope_pos_ptr + prompt_part_len + src_start = num_computed_tokens + src_end = num_computed_tokens + prompt_part_len + + self.mrope_positions_cpu[:, dst_start:dst_end] = \ + req.mrope_positions[:,src_start:src_end] + + mrope_pos_ptr += prompt_part_len + + if completion_part_len > 0: + # compute completion's mrope_positions on-the-fly + dst_start = mrope_pos_ptr + dst_end = mrope_pos_ptr + completion_part_len + + self.mrope_positions_cpu[:, dst_start:dst_end] = \ + 
MRotaryEmbedding.get_next_input_positions_tensor( + req.mrope_position_delta, + context_len=num_computed_tokens + + prompt_part_len, + seq_len=num_computed_tokens + + prompt_part_len + + completion_part_len, + ) + + mrope_pos_ptr += completion_part_len + def _prepare_sampling( self, scheduler_output: "SchedulerOutput", @@ -614,9 +726,12 @@ def execute_model( # Run the decoder. # Use persistent buffers for CUDA graphs. with set_forward_context(attn_metadata, self.vllm_config): + positions = self.mrope_positions[:, :num_input_tokens] \ + if self.model_config.uses_mrope \ + else self.positions[:num_input_tokens] hidden_states = self.model( input_ids=input_ids, - positions=self.positions[:num_input_tokens], + positions=positions, kv_caches=self.kv_caches, attn_metadata=None, inputs_embeds=inputs_embeds, @@ -703,9 +818,12 @@ def _dummy_run( input_ids = self.input_ids[:num_tokens] inputs_embeds = None with set_forward_context(None, self.vllm_config): + positions = self.mrope_positions[:, :num_tokens] \ + if self.model_config.uses_mrope \ + else self.positions[:num_tokens] hidden_states = model( input_ids=input_ids, - positions=self.positions[:num_tokens], + positions=positions, kv_caches=kv_caches, attn_metadata=None, inputs_embeds=inputs_embeds, From ed31e68dcec7a7afbf272f105adc931e82de528d Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Thu, 16 Jan 2025 18:19:42 +0000 Subject: [PATCH 02/20] fix Signed-off-by: Roger Wang --- vllm/model_executor/models/qwen2.py | 1 + vllm/model_executor/models/qwen2_vl.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 35831b0c66c..18fe078052e 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -286,6 +286,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): )) self.config = config + self.quant_config = quant_config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 9f444bb4215..7fc6d298ec3 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -1205,7 +1205,6 @@ def _process_video_input(self, self.visual.dtype) video_embeds = self.visual(pixel_values_videos, grid_thw=video_input["video_grid_thw"]) - print(video_embeds.shape) return video_embeds def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: @@ -1256,7 +1255,7 @@ def get_multimodal_embeddings( video_input = modalities["videos"] video_embeddings = self._process_video_input(video_input) multimodal_embeddings += tuple(video_embeddings) - #print(multimodal_embeddings) + return multimodal_embeddings def get_input_embeddings( From b2bd0e3cf4667190d8368e380c9a3f67093696be Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Fri, 17 Jan 2025 07:13:25 +0000 Subject: [PATCH 03/20] e2e working Signed-off-by: Roger Wang --- vllm/model_executor/models/qwen2_vl.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 7fc6d298ec3..4afddaedfd7 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -67,11 +67,15 @@ from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, - init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) + 
init_vllm_registered_model, maybe_prefix, + merge_multimodal_embeddings) from .vision import get_vit_attn_backend logger = init_logger(__name__) +# For profile run +_MAX_FRAMES_PER_VIDEO = 16 + # === Vision Inputs === # @@ -611,6 +615,11 @@ def forward( # adapter x = self.merger(x) + + # split by individual data items + sizes = grid_thw.prod( + -1) // self.spatial_merge_size // self.spatial_merge_size + x = x.split(sizes.tolist()) return x def load_weights(self, weights: Iterable[Tuple[str, @@ -875,8 +884,8 @@ def get_num_frames_with_most_features(self, seq_len: int) -> int: max_image_tokens = self.get_max_image_tokens() * max_images max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens) - - num_frames = max(max_total_frames // max(max_videos, 1), 1) + num_frames = min(max(max_total_frames // max(max_videos, 1), 1), + _MAX_FRAMES_PER_VIDEO) # Temporary workaround for https://github.com/huggingface/transformers/issues/35412 if num_frames > 1 and num_frames % 2 == 1: @@ -1250,11 +1259,11 @@ def get_multimodal_embeddings( if modality == "images": image_input = modalities["images"] vision_embeddings = self._process_image_input(image_input) - multimodal_embeddings += tuple(vision_embeddings) + multimodal_embeddings += vision_embeddings if modality == "videos": video_input = modalities["videos"] video_embeddings = self._process_video_input(video_input) - multimodal_embeddings += tuple(video_embeddings) + multimodal_embeddings += video_embeddings return multimodal_embeddings From 473405996bac0fb3237ed8aa120b01eaf8e7f662 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Fri, 17 Jan 2025 08:17:11 +0000 Subject: [PATCH 04/20] typing Signed-off-by: Roger Wang --- vllm/model_executor/models/qwen2_vl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 4afddaedfd7..a3190556819 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -594,7 +594,7 @@ def forward( self, x: torch.Tensor, grid_thw: torch.Tensor, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, ...]: # patchify x = x.to(device=self.device, dtype=self.dtype) x = self.patch_embed(x) From d625fb714c4a21508684349d202220062376b318 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Fri, 17 Jan 2025 08:40:18 +0000 Subject: [PATCH 05/20] revert Signed-off-by: Roger Wang --- vllm/model_executor/models/qwen2.py | 10 ---------- vllm/model_executor/models/qwen2_vl.py | 1 - 2 files changed, 11 deletions(-) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 18fe078052e..9aaeeb7bf7e 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -437,16 +437,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): embedding_modules = {} embedding_padding_modules = [] - # BitandBytes specific attributes - bitsandbytes_stacked_params_mapping = { - # shard_name, weight_name, index - "q_proj": ("qkv_proj", 0), - "k_proj": ("qkv_proj", 1), - "v_proj": ("qkv_proj", 2), - "gate_proj": ("gate_up_proj", 0), - "up_proj": ("gate_up_proj", 1), - } - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index a3190556819..1de03e61285 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -718,7 +718,6 @@ def 
_parse_video_data( return super()._parse_video_data(data) - class Qwen2VLProcessingInfo(BaseProcessingInfo): def get_hf_config(self): From 4cd6e8faefdc139a36305699744c2e02c59e4db9 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Fri, 17 Jan 2025 08:43:30 +0000 Subject: [PATCH 06/20] remove unused code Signed-off-by: Roger Wang --- vllm/model_executor/models/qwen2_vl.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 1de03e61285..c4e0c785eb8 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -1230,17 +1230,6 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: return modalities - def _merge_multimodal_embeddings( - self, - input_ids: torch.Tensor, - inputs_embeds: torch.Tensor, - multimodal_embeddings: torch.Tensor, - placeholder_token_id: int, - ) -> torch.Tensor: - mask = (input_ids == placeholder_token_id) - inputs_embeds[mask, :] = multimodal_embeddings - return inputs_embeds - def get_multimodal_embeddings( self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]: From 6d7850b3386950db8e85181883d0c3ea95a35f9e Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Fri, 17 Jan 2025 09:32:30 +0000 Subject: [PATCH 07/20] address dynamic dims Signed-off-by: Roger Wang --- vllm/compilation/decorators.py | 26 +++++++++++++++----------- vllm/model_executor/models/qwen2.py | 5 +++-- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 7e59cf87b50..6956ab9c33f 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -18,15 +18,12 @@ logger = init_logger(__name__) _T = TypeVar("_T", bound=type[nn.Module]) -DimIndexes = Union[int, List[int]] -DimIndexesSelector = Callable[[torch.Tensor], DimIndexes] @overload def support_torch_compile( *, - dynamic_arg_dims: Optional[Dict[str, Union[DimIndexes, - DimIndexesSelector]]], + dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]], ) -> Callable[[_T], _T]: ... @@ -39,8 +36,7 @@ def support_torch_compile(cls: _T) -> _T: def support_torch_compile( cls: Optional[_T] = None, *, - dynamic_arg_dims: Optional[Dict[str, Union[DimIndexes, - DimIndexesSelector]]] = None, + dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None, ) -> Union[Callable[[_T], _T], _T]: """ A decorator to add support for compiling the forward method of a class. @@ -136,7 +132,7 @@ def cls_decorator_helper(cls: _T) -> _T: def _support_torch_compile( cls: _T, - dynamic_arg_dims: Dict[str, Union[DimIndexes, DimIndexesSelector]], + dynamic_arg_dims: Dict[str, Union[int, List[int]]], ) -> _T: """ A decorator to add support for compiling the forward method of a class. 
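(Patch 07 swaps the callable selector back out for plain integer dims that may be negative; the hunk that follows normalizes them against each tensor's rank before calling mark_dynamic. A minimal sketch of that convention, using a made-up helper and toy tensors rather than vLLM's own code:)

import torch

def normalize_dims(tensor: torch.Tensor, dims) -> list:
    dims = [dims] if isinstance(dims, int) else list(dims)
    # Negative entries count from the last dimension, as in ordinary indexing.
    return [tensor.ndim + d if d < 0 else d for d in dims]

positions_1d = torch.zeros(128, dtype=torch.int64)        # (seq_len,)
positions_mrope = torch.zeros(3, 128, dtype=torch.int64)  # (3, seq_len)

assert normalize_dims(positions_1d, -1) == [0]
assert normalize_dims(positions_mrope, -1) == [1]
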
@@ -184,13 +180,21 @@ def __call__(self, *args, **kwargs): for k, dims in dynamic_arg_dims.items(): arg = bound_args.arguments.get(k) if arg is not None: + dims = [dims] if isinstance(dims, int) else dims if isinstance(arg, torch.Tensor): - dims_ = dims(arg) if callable(dims) else dims - torch._dynamo.mark_dynamic(arg, dims_) + # In case dims is specified with negative indexing + dims = [ + arg.ndim + dim if dim < 0 else dim for dim in dims + ] + torch._dynamo.mark_dynamic(arg, dims) elif isinstance(arg, IntermediateTensors): for tensor in arg.tensors.values(): - dims_ = dims(tensor) if callable(dims) else dims - torch._dynamo.mark_dynamic(tensor, dims_) + # In case dims is specified with negative indexing + dims = [ + tensor.ndim + dim if dim < 0 else dim + for dim in dims + ] + torch._dynamo.mark_dynamic(tensor, dims) else: raise ValueError( "Unsupported dynamic dimensions" diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 9aaeeb7bf7e..82de1c35740 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -259,8 +259,9 @@ def forward( @support_torch_compile( dynamic_arg_dims={ "input_ids": 0, - # dim 1 for mrope in shape (3, seq_len), else dim 0 in shape (seq_len, ) - "positions": lambda tensor: tensor.ndim - 1, + # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, + # otherwise (seq_len, ). + "positions": -1, "intermediate_tensors": 0, "inputs_embeds": 0, }) From 191f941e5042635de756ecbd576509ffdfdfc7cd Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Fri, 17 Jan 2025 09:35:30 +0000 Subject: [PATCH 08/20] doc and co-author Co-authored-by: imkero Signed-off-by: Roger Wang --- docs/source/models/supported_models.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index d07cde3db5c..5b64a75d7fa 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -754,7 +754,7 @@ See [this page](#generative-models) for more information on how to use generativ - `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. - ✅︎ - ✅︎ - - + - ✅︎ * - `UltravoxModel` - Ultravox - T + AE+ From 641d3c271738b936db1d47675911f6b40287b8c9 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Fri, 17 Jan 2025 09:44:33 +0000 Subject: [PATCH 09/20] revert bnb Signed-off-by: Roger Wang --- vllm/model_executor/models/qwen2_vl.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index c4e0c785eb8..8979142c5ff 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -1052,16 +1052,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, embedding_modules = {} embedding_padding_modules = [] - # BitandBytes specific attributes - bitsandbytes_stacked_params_mapping = { - # shard_name, weight_name, index - "q_proj": ("qkv_proj", 0), - "k_proj": ("qkv_proj", 1), - "v_proj": ("qkv_proj", 2), - "gate_proj": ("gate_up_proj", 0), - "up_proj": ("gate_up_proj", 1), - } - # To ensure correct weight loading and mapping. 
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ "lm_head.": "language_model.lm_head.", From 51e395997ab87dd04aef98061a4e2f9398b922d5 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Fri, 17 Jan 2025 09:51:45 +0000 Subject: [PATCH 10/20] comment Signed-off-by: Roger Wang --- vllm/compilation/decorators.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 6956ab9c33f..38f284794b8 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -76,11 +76,8 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): During runtime, when we actually mark dimensions of tensors, it depends on the value of arguments: - - if it is a single integer, the corresponding dimension of the argument - will be marked as dynamic. - - if it is a function returns a single integer, it will be called with - the tensor as argument, and the returned dimension will be marked - as dynamic. + - if it is a single integer (can be negative), the corresponding dimension + of the argument will be marked as dynamic. - if it is `None`, ignored. - if it is `IntermediateTensors`, all the tensors in the intermediate tensors will be marked as dynamic. From 64e2932ac56e565faef1c17d837ef96d3472757b Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Sat, 18 Jan 2025 08:25:11 +0000 Subject: [PATCH 11/20] comment Signed-off-by: Roger Wang --- vllm/model_executor/models/qwen2_vl.py | 16 ++++++---------- vllm/v1/worker/gpu_model_runner.py | 5 +++++ 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 8979142c5ff..a9c93395d1d 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -944,7 +944,6 @@ def get_dummy_processor_inputs( class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] ): - _placeholder_map: Optional[dict[str, list[int]]] = None def _get_data_parser(self) -> MultiModalDataParser: return Qwen2MultiModalDataParser() @@ -959,15 +958,12 @@ def _get_prompt_replacements( image_processor = self.info.get_image_processor( **hf_processor_mm_kwargs) - if not self._placeholder_map: - # NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has - # image_token and video_token registered - encode_fn = hf_processor.tokenizer.encode - self._placeholder_map = { - "image": encode_fn(hf_processor.image_token), - "video": encode_fn(hf_processor.video_token), - } - placeholder = self._placeholder_map + # NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has + # image_token and video_token registered + placeholder = { + "image": hf_processor.image_token, + "video": hf_processor.video_token, + } merge_length = image_processor.merge_size**2 diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 240924ad484..50067dcec63 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -141,6 +141,7 @@ def __init__( dtype=torch.int64, device=self.device) + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.model_config.uses_mrope: # a permuted mrope_positions tensor satisfying the following # properties to allow `torch.compile` work properly: @@ -266,6 +267,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: output_token_ids=[], ) + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.model_config.uses_mrope: image_grid_thw = [] video_grid_thw = [] @@ -361,6 
+363,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): out=positions_np) # Calculate M-RoPE positions. + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.model_config.uses_mrope: self._calc_mrope_positions(scheduler_output) @@ -411,10 +414,12 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): self.input_ids[:total_num_scheduled_tokens].copy_( self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) if self.model_config.uses_mrope: + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) self.mrope_positions[:, :total_num_scheduled_tokens].copy_( self.mrope_positions_cpu[:, :total_num_scheduled_tokens], non_blocking=True) else: + # Common case (1D positions) self.positions[:total_num_scheduled_tokens].copy_( self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) From 31b1e67e07ff81f7da72294b8fb78f979db10fa1 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Sat, 18 Jan 2025 09:12:57 +0000 Subject: [PATCH 12/20] comment Signed-off-by: Roger Wang --- vllm/v1/worker/gpu_model_runner.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 50067dcec63..d7b7d7cf731 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -143,10 +143,12 @@ def __init__( # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.model_config.uses_mrope: - # a permuted mrope_positions tensor satisfying the following - # properties to allow `torch.compile` work properly: + # NOTE: `mrope_positions` is implemented as a permuted tensor to + # satisfy the following properties to allow `torch.compile` to work + # properly: # - shape: (3, ) - # - stride: (1, ) + # - stride: (1, 3) + # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1921022256 self.mrope_positions = torch.zeros((self.max_num_tokens, 3), dtype=torch.int64, From a71d8f087bde07a3901334e12c1fb75b16415346 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Sat, 18 Jan 2025 23:54:36 +0000 Subject: [PATCH 13/20] comment Signed-off-by: Roger Wang --- vllm/v1/worker/gpu_model_runner.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d7b7d7cf731..87a1cd7f9e6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -150,6 +150,11 @@ def __init__( # - stride: (1, 3) # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1921022256 + # NOTE: When M-RoPE is enabled, position ids are 3D regardless of + # the modality of inputs. For text-only inputs, each dimension has + # identical position IDs, making M-RoPE functionally equivalent to + # 1D-RoPE. 
+ # See page 5 of https://arxiv.org/abs/2409.12191 self.mrope_positions = torch.zeros((self.max_num_tokens, 3), dtype=torch.int64, device=self.device) From 7c5d95ac779cbca4fdc474c7bea3ff9620259a4f Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Sun, 19 Jan 2025 04:42:57 +0000 Subject: [PATCH 14/20] fix shape Signed-off-by: Roger Wang --- vllm/model_executor/models/qwen2_vl.py | 42 +++++++++++++------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index a9c93395d1d..9b65c3fed29 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -98,8 +98,6 @@ class Qwen2VLImageEmbeddingInputs(TypedDict): """Supported types: - List[`torch.Tensor`]: A list of tensors holding all images' features. Each tensor holds an image's features. - - `torch.Tensor`: A tensor holding all images' features - (concatenation of all images' feature tensors). Tensor shape: `(num_image_features, hidden_size)` - `num_image_features` varies based on @@ -138,8 +136,6 @@ class Qwen2VLVideoEmbeddingInputs(TypedDict): """Supported types: - List[`torch.Tensor`]: A list of tensors holding all videos' features. Each tensor holds an video's features. - - `torch.Tensor`: A tensor holding all videos' features - (concatenation of all videos' feature tensors). Tensor shape: `(num_image_features, hidden_size)` - `num_image_features` varies based on @@ -594,7 +590,7 @@ def forward( self, x: torch.Tensor, grid_thw: torch.Tensor, - ) -> tuple[torch.Tensor, ...]: + ) -> torch.Tensor: # patchify x = x.to(device=self.device, dtype=self.dtype) x = self.patch_embed(x) @@ -616,10 +612,6 @@ def forward( # adapter x = self.merger(x) - # split by individual data items - sizes = grid_thw.prod( - -1) // self.spatial_merge_size // self.spatial_merge_size - x = x.split(sizes.tolist()) return x def load_weights(self, weights: Iterable[Tuple[str, @@ -1180,26 +1172,34 @@ def _parse_and_validate_video_input( video_embeds=video_embeds, video_grid_thw=video_grid_thw) - def _process_image_input(self, - image_input: Qwen2VLImageInputs) -> torch.Tensor: + def _process_image_input( + self, image_input: Qwen2VLImageInputs) -> tuple[torch.Tensor, ...]: + if image_input["type"] == "image_embeds": - return image_input["image_embeds"].type(self.visual.dtype) + return tuple(image_input["image_embeds"].type(self.visual.dtype)) + grid_thw = image_input["image_grid_thw"] + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod() // merge_size // merge_size pixel_values = image_input["pixel_values"].type(self.visual.dtype) - image_embeds = self.visual(pixel_values, - grid_thw=image_input["image_grid_thw"]) - return image_embeds - def _process_video_input(self, - video_input: Qwen2VLVideoInputs) -> torch.Tensor: + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + + return image_embeds.split(sizes.tolist()) + + def _process_video_input( + self, video_input: Qwen2VLVideoInputs) -> tuple[torch.Tensor, ...]: if video_input["type"] == "video_embeds": - return video_input["video_embeds"].type(self.visual.dtype) + return tuple(video_input["video_embeds"].type(self.visual.dtype)) + grid_thw = video_input["video_grid_thw"] + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod() // merge_size // merge_size pixel_values_videos = video_input["pixel_values_videos"].type( self.visual.dtype) - video_embeds = self.visual(pixel_values_videos, - grid_thw=video_input["video_grid_thw"]) - return video_embeds + 
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) + + return video_embeds.split(sizes.tolist()) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: modalities = {} From 21da0a71431898a9bcc0bee224ac214f783b7e80 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Sun, 19 Jan 2025 05:13:00 +0000 Subject: [PATCH 15/20] fix Signed-off-by: Roger Wang --- vllm/model_executor/models/qwen2_vl.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 9b65c3fed29..646b1bf0dc7 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -98,6 +98,8 @@ class Qwen2VLImageEmbeddingInputs(TypedDict): """Supported types: - List[`torch.Tensor`]: A list of tensors holding all images' features. Each tensor holds an image's features. + - `torch.Tensor`: A tensor holding all images' features + (concatenation of all images' feature tensors). Tensor shape: `(num_image_features, hidden_size)` - `num_image_features` varies based on @@ -136,6 +138,8 @@ class Qwen2VLVideoEmbeddingInputs(TypedDict): """Supported types: - List[`torch.Tensor`]: A list of tensors holding all videos' features. Each tensor holds an video's features. + - `torch.Tensor`: A tensor holding all images' features + (concatenation of all images' feature tensors). Tensor shape: `(num_image_features, hidden_size)` - `num_image_features` varies based on @@ -1175,26 +1179,27 @@ def _parse_and_validate_video_input( def _process_image_input( self, image_input: Qwen2VLImageInputs) -> tuple[torch.Tensor, ...]: - if image_input["type"] == "image_embeds": - return tuple(image_input["image_embeds"].type(self.visual.dtype)) - grid_thw = image_input["image_grid_thw"] merge_size = self.visual.spatial_merge_size sizes = grid_thw.prod() // merge_size // merge_size - pixel_values = image_input["pixel_values"].type(self.visual.dtype) + if image_input["type"] == "image_embeds": + image_embeds = image_input["image_embeds"].type(self.visual.dtype) + return image_embeds.split(sizes.tolist()) + pixel_values = image_input["pixel_values"].type(self.visual.dtype) image_embeds = self.visual(pixel_values, grid_thw=grid_thw) return image_embeds.split(sizes.tolist()) def _process_video_input( self, video_input: Qwen2VLVideoInputs) -> tuple[torch.Tensor, ...]: - if video_input["type"] == "video_embeds": - return tuple(video_input["video_embeds"].type(self.visual.dtype)) - grid_thw = video_input["video_grid_thw"] merge_size = self.visual.spatial_merge_size sizes = grid_thw.prod() // merge_size // merge_size + if video_input["type"] == "video_embeds": + video_embeds = video_input["video_embeds"].type(self.visual.dtype) + return video_embeds.split(sizes.tolist()) + pixel_values_videos = video_input["pixel_values_videos"].type( self.visual.dtype) video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) From 4c37351b9f3f0804df1ce179f3ad6239d21565af Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Sun, 19 Jan 2025 05:16:17 +0000 Subject: [PATCH 16/20] comments Signed-off-by: Roger Wang --- vllm/model_executor/models/qwen2_vl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 646b1bf0dc7..9339f795ab5 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -138,8 +138,8 @@ class Qwen2VLVideoEmbeddingInputs(TypedDict): """Supported types: 
- List[`torch.Tensor`]: A list of tensors holding all videos' features. Each tensor holds an video's features. - - `torch.Tensor`: A tensor holding all images' features - (concatenation of all images' feature tensors). + - `torch.Tensor`: A tensor holding all videos' features + (concatenation of all videos' feature tensors). Tensor shape: `(num_image_features, hidden_size)` - `num_image_features` varies based on From e6340e57af103805504ca6a7fd5ab7c4e1c716c8 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Sun, 19 Jan 2025 05:25:20 +0000 Subject: [PATCH 17/20] cleanup and assert Signed-off-by: Roger Wang --- vllm/model_executor/models/qwen2_vl.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 9339f795ab5..18c36ef41b9 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -1180,8 +1180,10 @@ def _process_image_input( self, image_input: Qwen2VLImageInputs) -> tuple[torch.Tensor, ...]: grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod() // merge_size // merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + if image_input["type"] == "image_embeds": image_embeds = image_input["image_embeds"].type(self.visual.dtype) return image_embeds.split(sizes.tolist()) @@ -1193,9 +1195,12 @@ def _process_image_input( def _process_video_input( self, video_input: Qwen2VLVideoInputs) -> tuple[torch.Tensor, ...]: + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod() // merge_size // merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + if video_input["type"] == "video_embeds": video_embeds = video_input["video_embeds"].type(self.visual.dtype) return video_embeds.split(sizes.tolist()) From 85b7e87fe37b5c1a1fe5c0b3a9bac3b715d21321 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Sun, 19 Jan 2025 05:39:31 +0000 Subject: [PATCH 18/20] cleanup Signed-off-by: Roger Wang --- vllm/model_executor/models/qwen2_vl.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 18c36ef41b9..5051aa5844c 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -1181,15 +1181,16 @@ def _process_image_input( grid_thw = image_input["image_grid_thw"] assert grid_thw.ndim == 2 - merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size if image_input["type"] == "image_embeds": image_embeds = image_input["image_embeds"].type(self.visual.dtype) - return image_embeds.split(sizes.tolist()) + else: + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) - pixel_values = image_input["pixel_values"].type(self.visual.dtype) - image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + # Split concatenated embeddings for each image item. 
+ merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size return image_embeds.split(sizes.tolist()) @@ -1198,16 +1199,17 @@ def _process_video_input( grid_thw = video_input["video_grid_thw"] assert grid_thw.ndim == 2 - merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size if video_input["type"] == "video_embeds": video_embeds = video_input["video_embeds"].type(self.visual.dtype) - return video_embeds.split(sizes.tolist()) + else: + pixel_values_videos = video_input["pixel_values_videos"].type( + self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) - pixel_values_videos = video_input["pixel_values_videos"].type( - self.visual.dtype) - video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) + # Split concatenated embeddings for each video item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size return video_embeds.split(sizes.tolist()) From 09668f75bc65c36ac9e96611dc4bdf38fc828dbb Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 19 Jan 2025 07:43:11 +0000 Subject: [PATCH 19/20] Small improvements to the tests Signed-off-by: DarkLight1337 --- .../vision_language/test_qwen2_vl.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/tests/models/decoder_only/vision_language/test_qwen2_vl.py b/tests/models/decoder_only/vision_language/test_qwen2_vl.py index 16e256e040a..2fd22f0cc88 100644 --- a/tests/models/decoder_only/vision_language/test_qwen2_vl.py +++ b/tests/models/decoder_only/vision_language/test_qwen2_vl.py @@ -105,7 +105,7 @@ def batch_make_image_embeddings( pixel_values = preprocess_result["pixel_values"] image_grid_thw = preprocess_result["image_grid_thw"] - # pixel values to embeddinds & grid_thws + # pixel values to embeddings & grid_thws with torch.no_grad(): visual = llm.llm_engine.model_executor.driver_worker. 
\ model_runner.model.visual @@ -124,11 +124,10 @@ def batch_make_image_embeddings( for image_batch in image_batches_: cur_batch_image_count = len(image_batch) merge_size = image_processor.merge_size - cur_batch_embed_len = sum([ - grid_thw.prod() // merge_size // merge_size + cur_batch_embed_len = sum( + grid_thw.prod(-1) // merge_size // merge_size for grid_thw in image_grid_thw[image_counter:image_counter + - cur_batch_image_count] - ]) + cur_batch_image_count]) result.append({ "image_embeds": @@ -187,7 +186,7 @@ def batch_make_video_embeddings( pixel_values = preprocess_result["pixel_values_videos"] video_grid_thw = preprocess_result["video_grid_thw"] - # pixel values to embeddinds & grid_thws + # pixel values to embeddings & grid_thws with torch.no_grad(): visual = llm.llm_engine.model_executor.driver_worker.\ model_runner.model.visual @@ -206,11 +205,10 @@ def batch_make_video_embeddings( for video_batch in video_batches_: cur_batch_video_count = len(video_batch) merge_size = image_processor.merge_size - cur_batch_embed_len = sum([ - grid_thw.prod() // merge_size // merge_size + cur_batch_embed_len = sum( + grid_thw.prod(-1) // merge_size // merge_size for grid_thw in video_grid_thw[video_counter:video_counter + - cur_batch_video_count] - ]) + cur_batch_video_count]) result.append({ "video_embeds": From 562e0b7af509b9d669af5c067b189fa4820ef9d5 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 19 Jan 2025 07:43:48 +0000 Subject: [PATCH 20/20] Fix embedding inputs being ignored Signed-off-by: DarkLight1337 --- vllm/model_executor/models/llava_onevision.py | 6 ++++-- vllm/model_executor/models/qwen2_vl.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index c9283e0c5ba..6faa79f65d8 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -554,10 +554,12 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: # Preserve the order of modalities if there are multiple of them # from the order of kwargs. for input_key in kwargs: - if input_key == "pixel_values" and "images" not in modalities: + if input_key in ("pixel_values", + "image_embeds") and "images" not in modalities: modalities["images"] = self._parse_and_validate_image_input( **kwargs) - if input_key == "pixel_values_videos" and "videos" not in modalities: # noqa E501 + if input_key in ("pixel_values_videos", + "video_embeds") and "videos" not in modalities: modalities["videos"] = self._parse_and_validate_video_input( **kwargs) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 5051aa5844c..34d5c8ad089 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -1219,10 +1219,12 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: # Preserve the order of modalities if there are multiple of them # from the order of kwargs. for input_key in kwargs: - if input_key == "pixel_values" and "images" not in modalities: + if input_key in ("pixel_values", + "image_embeds") and "images" not in modalities: modalities["images"] = self._parse_and_validate_image_input( **kwargs) - if input_key == "pixel_values_videos" and "videos" not in modalities: # noqa E501 + if input_key in ("pixel_values_videos", + "video_embeds") and "videos" not in modalities: modalities["videos"] = self._parse_and_validate_video_input( **kwargs)
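(To close out the series, a self-contained sketch of the per-item split that the later patches converge on: grid_thw carries one (t, h, w) row per image or video in ViT patch units, and dividing each row's product by spatial_merge_size squared gives the number of embedding rows that item owns after the merger. The shapes below are toy values, not taken from a real checkpoint.)

import torch

# One (t, h, w) row per item, in ViT patch units (toy values).
grid_thw = torch.tensor([[1, 4, 4], [1, 8, 6]])
merge_size = 2  # spatial_merge_size of the vision encoder

# Embedding rows per item after the patch merger, as in _process_image_input.
sizes = grid_thw.prod(-1) // merge_size // merge_size  # tensor([ 4, 12])

hidden_size = 8
flat_embeds = torch.randn(int(sizes.sum()), hidden_size)

# Split the concatenated visual embeddings back into one tensor per item,
# which is what get_multimodal_embeddings now returns as a tuple.
per_item = flat_embeds.split(sizes.tolist())
assert [e.shape[0] for e in per_item] == sizes.tolist()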