@@ -144,28 +144,24 @@ def __init__(
144
144
145
145
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
146
146
if self .model_config .uses_mrope :
147
- # NOTE: `mrope_positions` is implemented as a permuted tensor to
148
- # satisfy the following properties to allow `torch.compile` to work
149
- # properly:
150
- # - shape: (3, <variable>)
151
- # - stride: (1, 3)
152
- # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1921022256
147
+ # NOTE: `mrope_positions` is implemented with one additional dummy
148
+ # position on purpose to make it non-contiguous so that it can work
149
+ # with torch compile.
150
+ # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
153
151
154
152
# NOTE: When M-RoPE is enabled, position ids are 3D regardless of
155
153
# the modality of inputs. For text-only inputs, each dimension has
156
154
# identical position IDs, making M-RoPE functionally equivalent to
157
155
# 1D-RoPE.
158
156
# See page 5 of https://arxiv.org/abs/2409.12191
159
- self .mrope_positions = torch .zeros ((self .max_num_tokens , 3 ),
157
+ self .mrope_positions = torch .zeros ((3 , self .max_num_tokens + 1 ),
160
158
dtype = torch .int64 ,
161
159
device = self .device )
162
- self .mrope_positions_cpu = torch .zeros ((self .max_num_tokens , 3 ),
163
- dtype = torch .int64 ,
164
- device = "cpu" ,
165
- pin_memory = self .pin_memory )
166
-
167
- self .mrope_positions = self .mrope_positions .permute ((1 , 0 ))
168
- self .mrope_positions_cpu = self .mrope_positions_cpu .permute ((1 , 0 ))
160
+ self .mrope_positions_cpu = torch .zeros (
161
+ (3 , self .max_num_tokens + 1 ),
162
+ dtype = torch .int64 ,
163
+ device = "cpu" ,
164
+ pin_memory = self .pin_memory )
169
165
170
166
self .inputs_embeds = torch .zeros (
171
167
(self .max_num_tokens , self .hidden_size ),
0 commit comments