Skip to content

Commit 99d01a5

Browse files
ywang96 and imkero authored
[V1] Simplify M-RoPE (#12352)
Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: imkero <[email protected]>
1 parent d07efb3 commit 99d01a5

File tree

1 file changed

+10
-14
lines changed

1 file changed

+10
-14
lines changed

vllm/v1/worker/gpu_model_runner.py

Lines changed: 10 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -144,28 +144,24 @@ def __init__(
144144

145145
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
146146
if self.model_config.uses_mrope:
147-
# NOTE: `mrope_positions` is implemented as a permuted tensor to
148-
# satisfy the following properties to allow `torch.compile` to work
149-
# properly:
150-
# - shape: (3, <variable>)
151-
# - stride: (1, 3)
152-
# See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1921022256
147+
# NOTE: `mrope_positions` is implemented with one additional dummy
148+
# position on purpose to make it non-contiguous so that it can work
149+
# with torch compile.
150+
# See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
153151

154152
# NOTE: When M-RoPE is enabled, position ids are 3D regardless of
155153
# the modality of inputs. For text-only inputs, each dimension has
156154
# identical position IDs, making M-RoPE functionally equivalent to
157155
# 1D-RoPE.
158156
# See page 5 of https://arxiv.org/abs/2409.12191
159-
self.mrope_positions = torch.zeros((self.max_num_tokens, 3),
157+
self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1),
160158
dtype=torch.int64,
161159
device=self.device)
162-
self.mrope_positions_cpu = torch.zeros((self.max_num_tokens, 3),
163-
dtype=torch.int64,
164-
device="cpu",
165-
pin_memory=self.pin_memory)
166-
167-
self.mrope_positions = self.mrope_positions.permute((1, 0))
168-
self.mrope_positions_cpu = self.mrope_positions_cpu.permute((1, 0))
160+
self.mrope_positions_cpu = torch.zeros(
161+
(3, self.max_num_tokens + 1),
162+
dtype=torch.int64,
163+
device="cpu",
164+
pin_memory=self.pin_memory)
169165

170166
self.inputs_embeds = torch.zeros(
171167
(self.max_num_tokens, self.hidden_size),

0 commit comments

Comments
 (0)