diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index 2edb610ddf9..eb1bde9ec00 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -754,7 +754,7 @@ See [this page](#generative-models) for more information on how to use generativ
- `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc.
- ✅︎
- ✅︎
- -
+ - ✅︎
* - `UltravoxModel`
- Ultravox
- T + A<sup>E+</sup>
diff --git a/tests/models/decoder_only/vision_language/test_qwen2_vl.py b/tests/models/decoder_only/vision_language/test_qwen2_vl.py
index 16e256e040a..2fd22f0cc88 100644
--- a/tests/models/decoder_only/vision_language/test_qwen2_vl.py
+++ b/tests/models/decoder_only/vision_language/test_qwen2_vl.py
@@ -105,7 +105,7 @@ def batch_make_image_embeddings(
pixel_values = preprocess_result["pixel_values"]
image_grid_thw = preprocess_result["image_grid_thw"]
- # pixel values to embeddinds & grid_thws
+ # pixel values to embeddings & grid_thws
with torch.no_grad():
visual = llm.llm_engine.model_executor.driver_worker. \
model_runner.model.visual
@@ -124,11 +124,10 @@ def batch_make_image_embeddings(
for image_batch in image_batches_:
cur_batch_image_count = len(image_batch)
merge_size = image_processor.merge_size
- cur_batch_embed_len = sum([
- grid_thw.prod() // merge_size // merge_size
+ cur_batch_embed_len = sum(
+ grid_thw.prod(-1) // merge_size // merge_size
for grid_thw in image_grid_thw[image_counter:image_counter +
- cur_batch_image_count]
- ])
+ cur_batch_image_count])
result.append({
"image_embeds":
@@ -187,7 +186,7 @@ def batch_make_video_embeddings(
pixel_values = preprocess_result["pixel_values_videos"]
video_grid_thw = preprocess_result["video_grid_thw"]
- # pixel values to embeddinds & grid_thws
+ # pixel values to embeddings & grid_thws
with torch.no_grad():
visual = llm.llm_engine.model_executor.driver_worker.\
model_runner.model.visual
@@ -206,11 +205,10 @@ def batch_make_video_embeddings(
for video_batch in video_batches_:
cur_batch_video_count = len(video_batch)
merge_size = image_processor.merge_size
- cur_batch_embed_len = sum([
- grid_thw.prod() // merge_size // merge_size
+ cur_batch_embed_len = sum(
+ grid_thw.prod(-1) // merge_size // merge_size
for grid_thw in video_grid_thw[video_counter:video_counter +
- cur_batch_video_count]
- ])
+ cur_batch_video_count])
result.append({
"video_embeds":
diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py
index 10513111ea7..38f284794b8 100644
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -76,8 +76,8 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
During runtime, when we actually mark dimensions of tensors,
it depends on the value of arguments:
- - if it is a single integer, the corresponding dimension of the argument
- will be marked as dynamic.
+ - if it is a single integer (can be negative), the corresponding dimension
+ of the argument will be marked as dynamic.
- if it is `None`, ignored.
- if it is `IntermediateTensors`, all the tensors in the intermediate
tensors will be marked as dynamic.
@@ -177,10 +177,20 @@ def __call__(self, *args, **kwargs):
for k, dims in dynamic_arg_dims.items():
arg = bound_args.arguments.get(k)
if arg is not None:
+ dims = [dims] if isinstance(dims, int) else dims
if isinstance(arg, torch.Tensor):
+ # In case dims is specified with negative indexing
+ dims = [
+ arg.ndim + dim if dim < 0 else dim for dim in dims
+ ]
torch._dynamo.mark_dynamic(arg, dims)
elif isinstance(arg, IntermediateTensors):
for tensor in arg.tensors.values():
+ # In case dims is specified with negative indexing
+ dims = [
+ tensor.ndim + dim if dim < 0 else dim
+ for dim in dims
+ ]
torch._dynamo.mark_dynamic(tensor, dims)
else:
raise ValueError(
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 3fcd81a3c42..d071cfe888f 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -841,6 +841,37 @@ def get_input_positions(
) -> Tuple[List[List[int]], int]:
"""Get mrope input positions and delta value."""
+ llm_positions, mrope_position_delta = \
+ MRotaryEmbedding.get_input_positions_tensor(
+ input_tokens,
+ image_grid_thw,
+ video_grid_thw,
+ image_token_id,
+ video_token_id,
+ vision_start_token_id,
+ vision_end_token_id,
+ spatial_merge_size,
+ context_len,
+ seq_len,
+ )
+
+ return llm_positions.tolist(), mrope_position_delta
+
+ @staticmethod
+ def get_input_positions_tensor(
+ input_tokens: List[int],
+ image_grid_thw: Union[List[List[int]], torch.Tensor],
+ video_grid_thw: Union[List[List[int]], torch.Tensor],
+ image_token_id: int,
+ video_token_id: int,
+ vision_start_token_id: int,
+ vision_end_token_id: int,
+ spatial_merge_size: int,
+ context_len: int = 0,
+ seq_len: Optional[int] = None,
+ ) -> Tuple[torch.Tensor, int]:
+ """Get mrope input positions and delta value."""
+
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
if isinstance(video_grid_thw, torch.Tensor):
@@ -916,7 +947,7 @@ def get_input_positions(
len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]
- return llm_positions.tolist(), mrope_position_delta
+ return llm_positions, mrope_position_delta
@staticmethod
def get_next_input_positions(
@@ -930,6 +961,17 @@ def get_next_input_positions(
seq_len + mrope_position_delta)) for _ in range(3)
]
+ @staticmethod
+ def get_next_input_positions_tensor(
+ mrope_position_delta: int,
+ context_len: int,
+ seq_len: int,
+ ) -> torch.Tensor:
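+        # For text-only continuation, all three M-RoPE dimensions share
+        # the same position ids, so one offset range is expanded over dim 0.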
+ return torch.arange(
+ mrope_position_delta + context_len,
+ mrope_position_delta + seq_len,
+ ).expand(3, -1)
+
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py
index c9283e0c5ba..6faa79f65d8 100644
--- a/vllm/model_executor/models/llava_onevision.py
+++ b/vllm/model_executor/models/llava_onevision.py
@@ -554,10 +554,12 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for input_key in kwargs:
- if input_key == "pixel_values" and "images" not in modalities:
+ if input_key in ("pixel_values",
+ "image_embeds") and "images" not in modalities:
modalities["images"] = self._parse_and_validate_image_input(
**kwargs)
- if input_key == "pixel_values_videos" and "videos" not in modalities: # noqa E501
+ if input_key in ("pixel_values_videos",
+ "video_embeds") and "videos" not in modalities:
modalities["videos"] = self._parse_and_validate_video_input(
**kwargs)
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index d015f60c6d0..82de1c35740 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -256,7 +256,15 @@ def forward(
return hidden_states, residual
-@support_torch_compile
+@support_torch_compile(
+ dynamic_arg_dims={
+ "input_ids": 0,
+ # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
+ # otherwise (seq_len, ).
+ "positions": -1,
+ "intermediate_tensors": 0,
+ "inputs_embeds": 0,
+ })
class Qwen2Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index d00e5d362c8..34d5c8ad089 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -67,11 +67,15 @@
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper,
- init_vllm_registered_model, maybe_prefix)
+ init_vllm_registered_model, maybe_prefix,
+ merge_multimodal_embeddings)
from .vision import get_vit_attn_backend
logger = init_logger(__name__)
+# Cap on the number of frames per video assumed during the profile run
+_MAX_FRAMES_PER_VIDEO = 16
+
# === Vision Inputs === #
@@ -135,7 +139,7 @@ class Qwen2VLVideoEmbeddingInputs(TypedDict):
- List[`torch.Tensor`]: A list of tensors holding all videos' features.
Each tensor holds an video's features.
- `torch.Tensor`: A tensor holding all videos' features
- (concatenation of all videos' feature tensors).
+ (concatenation of all videos' feature tensors).
Tensor shape: `(num_image_features, hidden_size)`
- `num_image_features` varies based on
@@ -611,6 +615,7 @@ def forward(
# adapter
x = self.merger(x)
+
return x
def load_weights(self, weights: Iterable[Tuple[str,
@@ -874,8 +879,8 @@ def get_num_frames_with_most_features(self, seq_len: int) -> int:
max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len -
max_image_tokens)
-
- num_frames = max(max_total_frames // max(max_videos, 1), 1)
+ num_frames = min(max(max_total_frames // max(max_videos, 1), 1),
+ _MAX_FRAMES_PER_VIDEO)
# Temporary workaround for https://github.com/huggingface/transformers/issues/35412
if num_frames > 1 and num_frames % 2 == 1:
@@ -955,13 +960,14 @@ def _get_prompt_replacements(
"image": hf_processor.image_token,
"video": hf_processor.video_token,
}
+
merge_length = image_processor.merge_size**2
def get_replacement_qwen2vl(item_idx: int, modality: str):
grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx]
assert isinstance(grid_thw, torch.Tensor)
- num_tokens = grid_thw.prod() // merge_length
+ num_tokens = grid_thw.prod().item() // merge_length
return placeholder[modality] * num_tokens
return [
@@ -1047,11 +1053,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: Qwen2VLConfig = vllm_config.model_config.hf_config
- cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
- assert not cache_config.enable_prefix_caching, \
- "Qwen2-VL currently does not support prefix caching"
self.config = config
self.multimodal_config = multimodal_config
@@ -1173,59 +1176,82 @@ def _parse_and_validate_video_input(
video_embeds=video_embeds,
video_grid_thw=video_grid_thw)
- def _process_image_input(self,
- image_input: Qwen2VLImageInputs) -> torch.Tensor:
+ def _process_image_input(
+ self, image_input: Qwen2VLImageInputs) -> tuple[torch.Tensor, ...]:
+
+ grid_thw = image_input["image_grid_thw"]
+ assert grid_thw.ndim == 2
+
if image_input["type"] == "image_embeds":
- return image_input["image_embeds"].type(self.visual.dtype)
+ image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+ else:
+ pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+ image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+
+ # Split concatenated embeddings for each image item.
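+        # Each item contributes t * h * w // merge_size**2 embedding rows.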
+ merge_size = self.visual.spatial_merge_size
+ sizes = grid_thw.prod(-1) // merge_size // merge_size
+
+ return image_embeds.split(sizes.tolist())
+
+ def _process_video_input(
+ self, video_input: Qwen2VLVideoInputs) -> tuple[torch.Tensor, ...]:
- pixel_values = image_input["pixel_values"].type(self.visual.dtype)
- image_embeds = self.visual(pixel_values,
- grid_thw=image_input["image_grid_thw"])
- return image_embeds
+ grid_thw = video_input["video_grid_thw"]
+ assert grid_thw.ndim == 2
- def _process_video_input(self,
- video_input: Qwen2VLVideoInputs) -> torch.Tensor:
if video_input["type"] == "video_embeds":
- return video_input["video_embeds"].type(self.visual.dtype)
+ video_embeds = video_input["video_embeds"].type(self.visual.dtype)
+ else:
+ pixel_values_videos = video_input["pixel_values_videos"].type(
+ self.visual.dtype)
+ video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
- pixel_values_videos = video_input["pixel_values_videos"].type(
- self.visual.dtype)
- video_embeds = self.visual(pixel_values_videos,
- grid_thw=video_input["video_grid_thw"])
- return video_embeds
+ # Split concatenated embeddings for each video item.
+ merge_size = self.visual.spatial_merge_size
+ sizes = grid_thw.prod(-1) // merge_size // merge_size
- def _merge_multimodal_embeddings(
- self,
- input_ids: torch.Tensor,
- inputs_embeds: torch.Tensor,
- multimodal_embeddings: torch.Tensor,
- placeholder_token_id: int,
- ) -> torch.Tensor:
- mask = (input_ids == placeholder_token_id)
- inputs_embeds[mask, :] = multimodal_embeddings
- return inputs_embeds
+ return video_embeds.split(sizes.tolist())
+
+ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+ modalities = {}
+
+ # Preserve the order of modalities if there are multiple of them
+ # from the order of kwargs.
+ for input_key in kwargs:
+ if input_key in ("pixel_values",
+ "image_embeds") and "images" not in modalities:
+ modalities["images"] = self._parse_and_validate_image_input(
+ **kwargs)
+ if input_key in ("pixel_values_videos",
+ "video_embeds") and "videos" not in modalities:
+ modalities["videos"] = self._parse_and_validate_video_input(
+ **kwargs)
+
+ return modalities
def get_multimodal_embeddings(
self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]:
- image_input = self._parse_and_validate_image_input(**kwargs)
- video_input = self._parse_and_validate_video_input(**kwargs)
- if image_input is None and video_input is None:
+ modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
+ if not modalities:
return None
- # We make a tuple of each embedding with its modality string. This is a
- # temporary workaround for models to handle mixed modalities when
- # get_multimodal_embeddings and get_input_embeddings are called
- # separately.
- # TODO(ywang96): Add support for mixed-modality inference for v1.
- multimodal_embeddings: List[Tuple[NestedTensors, str]] = []
-
- if image_input is not None:
- image_embeds = self._process_image_input(image_input)
- multimodal_embeddings.append((image_embeds, "image"))
- if video_input is not None:
- video_embeds = self._process_video_input(video_input)
- multimodal_embeddings.append((video_embeds, "video"))
+        # multimodal_embeddings is a tuple of tensors, with each tensor
+        # corresponding to a multimodal data item (image or video).
+ multimodal_embeddings: tuple[torch.Tensor, ...] = ()
+
+ # NOTE: It is important to iterate over the keys in this dictionary
+ # to preserve the order of the modalities.
+ for modality in modalities:
+ if modality == "images":
+ image_input = modalities["images"]
+ vision_embeddings = self._process_image_input(image_input)
+ multimodal_embeddings += vision_embeddings
+ if modality == "videos":
+ video_input = modalities["videos"]
+ video_embeddings = self._process_video_input(video_input)
+ multimodal_embeddings += video_embeddings
return multimodal_embeddings
@@ -1237,21 +1263,9 @@ def get_input_embeddings(
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
- for embeddings, modality in multimodal_embeddings:
- if modality == "image":
- inputs_embeds = self._merge_multimodal_embeddings(
- input_ids,
- inputs_embeds,
- embeddings,
- placeholder_token_id=self.config.image_token_id,
- )
- if modality == "video":
- inputs_embeds = self._merge_multimodal_embeddings(
- input_ids,
- inputs_embeds,
- embeddings,
- placeholder_token_id=self.config.video_token_id,
- )
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids, inputs_embeds, multimodal_embeddings,
+ [self.config.image_token_id, self.config.video_token_id])
return inputs_embeds
def forward(
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 40494e64b22..28d8e390538 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -30,6 +30,9 @@ class CachedRequestState:
num_computed_tokens: int
output_token_ids: List[int]
+ mrope_positions: Optional[torch.Tensor] = None
+ mrope_position_delta: Optional[int] = None
+
@property
def num_tokens(self) -> int:
return len(self.prompt_token_ids) + len(self.output_token_ids)
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index aa63d9414c2..87a1cd7f9e6 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -14,6 +14,7 @@
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
+from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.sampling_params import SamplingType
@@ -139,6 +140,32 @@ def __init__(
self.positions = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=self.device)
+
+        # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
+ if self.model_config.uses_mrope:
+ # NOTE: `mrope_positions` is implemented as a permuted tensor to
+ # satisfy the following properties to allow `torch.compile` to work
+ # properly:
+            # - shape: (3, <variable>)
+ # - stride: (1, 3)
+ # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1921022256
+
+ # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
+ # the modality of inputs. For text-only inputs, each dimension has
+ # identical position IDs, making M-RoPE functionally equivalent to
+ # 1D-RoPE.
+ # See page 5 of https://arxiv.org/abs/2409.12191
+ self.mrope_positions = torch.zeros((self.max_num_tokens, 3),
+ dtype=torch.int64,
+ device=self.device)
+ self.mrope_positions_cpu = torch.zeros((self.max_num_tokens, 3),
+ dtype=torch.int64,
+ device="cpu",
+ pin_memory=self.pin_memory)
+
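+            # The permutes below give (3, max_num_tokens) views with the
+            # stride (1, 3) described above, without an extra copy.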
+ self.mrope_positions = self.mrope_positions.permute((1, 0))
+ self.mrope_positions_cpu = self.mrope_positions_cpu.permute((1, 0))
+
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
@@ -246,6 +273,35 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
num_computed_tokens=new_req_data.num_computed_tokens,
output_token_ids=[],
)
+
+            # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
+ if self.model_config.uses_mrope:
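+                # Gather grid metadata from all multimodal inputs of this
+                # request to precompute its prompt M-RoPE positions.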
+ image_grid_thw = []
+ video_grid_thw = []
+ for mm_input in self.requests[req_id].mm_inputs:
+ if mm_input.get("image_grid_thw") is not None:
+ image_grid_thw.extend(
+ mm_input["image_grid_thw"].tolist())
+ if mm_input.get("video_grid_thw") is not None:
+ video_grid_thw.extend(
+ mm_input["video_grid_thw"].tolist())
+
+ hf_config = self.model_config.hf_config
+
+ self.requests[req_id].mrope_positions, \
+ self.requests[req_id].mrope_position_delta = \
+ MRotaryEmbedding.get_input_positions_tensor(
+ self.requests[req_id].prompt_token_ids,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ image_token_id=hf_config.image_token_id,
+ video_token_id=hf_config.video_token_id,
+ vision_start_token_id=hf_config.vision_start_token_id,
+ vision_end_token_id=hf_config.vision_end_token_id,
+ spatial_merge_size=hf_config.vision_config.
+ spatial_merge_size,
+ )
+
req_ids_to_add.append(req_id)
# Update the cached states of the resumed requests.
@@ -313,6 +369,11 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
arange,
out=positions_np)
+ # Calculate M-RoPE positions.
+        # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
+ if self.model_config.uses_mrope:
+ self._calc_mrope_positions(scheduler_output)
+
# Get token indices.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
@@ -359,8 +420,16 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
# Copy the tensors to the GPU.
self.input_ids[:total_num_scheduled_tokens].copy_(
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
- self.positions[:total_num_scheduled_tokens].copy_(
- self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
+ if self.model_config.uses_mrope:
+            # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
+ self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
+ self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
+ non_blocking=True)
+ else:
+ # Common case (1D positions)
+ self.positions[:total_num_scheduled_tokens].copy_(
+ self.positions_cpu[:total_num_scheduled_tokens],
+ non_blocking=True)
query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
self.device, non_blocking=True)
seq_start_loc = self.seq_start_loc_cpu[:num_reqs + 1].to(
@@ -472,6 +541,61 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
logits_indices = query_start_loc[1:] - 1
return attn_metadata, logits_indices
+ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
+ mrope_pos_ptr = 0
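+        # Column offset into mrope_positions_cpu; each request's positions
+        # are written back-to-back along the token axis.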
+ num_reqs = self.input_batch.num_reqs
+ for index, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
+ assert req_id is not None
+
+ req = self.requests[req_id]
+ assert req.mrope_positions is not None
+
+ num_computed_tokens = \
+ self.input_batch.num_computed_tokens_cpu[index]
+ num_scheduled_tokens = \
+ scheduler_output.num_scheduled_tokens[req_id]
+ num_prompt_tokens = len(req.prompt_token_ids)
+
+ if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
+ prompt_part_len = max(0,
+ num_prompt_tokens - num_computed_tokens)
+ completion_part_len = max(
+ 0, num_scheduled_tokens - prompt_part_len)
+ else:
+ prompt_part_len = num_scheduled_tokens
+ completion_part_len = 0
+
+ assert num_scheduled_tokens == prompt_part_len + completion_part_len
+
+ if prompt_part_len > 0:
+ # prompt's mrope_positions are pre-computed
+ dst_start = mrope_pos_ptr
+ dst_end = mrope_pos_ptr + prompt_part_len
+ src_start = num_computed_tokens
+ src_end = num_computed_tokens + prompt_part_len
+
+ self.mrope_positions_cpu[:, dst_start:dst_end] = \
+                    req.mrope_positions[:, src_start:src_end]
+
+ mrope_pos_ptr += prompt_part_len
+
+ if completion_part_len > 0:
+ # compute completion's mrope_positions on-the-fly
+ dst_start = mrope_pos_ptr
+ dst_end = mrope_pos_ptr + completion_part_len
+
+ self.mrope_positions_cpu[:, dst_start:dst_end] = \
+ MRotaryEmbedding.get_next_input_positions_tensor(
+ req.mrope_position_delta,
+ context_len=num_computed_tokens +
+ prompt_part_len,
+ seq_len=num_computed_tokens +
+ prompt_part_len +
+ completion_part_len,
+ )
+
+ mrope_pos_ptr += completion_part_len
+
def _prepare_sampling(
self,
scheduler_output: "SchedulerOutput",
@@ -618,9 +742,12 @@ def execute_model(
# Run the decoder.
# Use persistent buffers for CUDA graphs.
with set_forward_context(attn_metadata, self.vllm_config):
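+            # M-RoPE models take 2D (3, num_tokens) positions; other models
+            # use the flat 1D position buffer.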
+ positions = self.mrope_positions[:, :num_input_tokens] \
+ if self.model_config.uses_mrope \
+ else self.positions[:num_input_tokens]
hidden_states = self.model(
input_ids=input_ids,
- positions=self.positions[:num_input_tokens],
+ positions=positions,
kv_caches=self.kv_caches,
attn_metadata=None,
inputs_embeds=inputs_embeds,
@@ -707,9 +834,12 @@ def _dummy_run(
input_ids = self.input_ids[:num_tokens]
inputs_embeds = None
with set_forward_context(None, self.vllm_config):
+ positions = self.mrope_positions[:, :num_tokens] \
+ if self.model_config.uses_mrope \
+ else self.positions[:num_tokens]
hidden_states = model(
input_ids=input_ids,
- positions=self.positions[:num_tokens],
+ positions=positions,
kv_caches=kv_caches,
attn_metadata=None,
inputs_embeds=inputs_embeds,