Commit 73354b5

[Perf]Optimize MRotaryEmbedding implementation to use cuda operator for improved inference performance
Signed-off-by: cynthieye <[email protected]> Co-authored-by: MagnetoWang <[email protected]>
1 parent 99ef59c commit 73354b5

1 file changed: +6 -0 lines changed

vllm/model_executor/layers/rotary_embedding.py

Lines changed: 6 additions & 0 deletions
@@ -956,6 +956,12 @@ def forward(
         """
         assert positions.ndim == 1 or positions.ndim == 2
 
+        if current_platform.is_cuda_alike():
+            from vllm import _custom_ops as ops
+            ops.rotary_embedding(positions, query, key, self.head_size,
+                                 self.cos_sin_cache, self.is_neox_style)
+            return query, key
+
         num_tokens = positions.shape[-1]
         cos_sin = self.cos_sin_cache[positions]
         cos, sin = cos_sin.chunk(2, dim=-1)
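
The added branch follows a common dispatch pattern: when a fused operator is available, rotate `query` and `key` in place and return early; otherwise fall through to the reference path that gathers `cos`/`sin` rows from the cache. A minimal pure-Python sketch of that pattern is below. It is not vLLM code: `build_cos_sin_cache`, `rotate_neox`, `split_row`, `fused_rotary_inplace`, and the `fast_path_available` flag are hypothetical names chosen for illustration, plain lists stand in for tensors, and only the NeoX-style layout with a single head per token is assumed.

```python
import math


def build_cos_sin_cache(max_pos, rot_dim, base=10000.0):
    """Row p holds cos values followed by sin values (assumed layout)."""
    inv_freq = [base ** (-2.0 * i / rot_dim) for i in range(rot_dim // 2)]
    cache = []
    for p in range(max_pos):
        angles = [p * f for f in inv_freq]
        cache.append([math.cos(a) for a in angles] + [math.sin(a) for a in angles])
    return cache


def split_row(row):
    """Split a cache row into its cos half and sin half."""
    half = len(row) // 2
    return row[:half], row[half:]


def rotate_neox(vec, cos, sin):
    """NeoX-style rotation: pair element i with element i + half."""
    half = len(vec) // 2
    x1, x2 = vec[:half], vec[half:]
    o1 = [a * c - b * s for a, b, c, s in zip(x1, x2, cos, sin)]
    o2 = [b * c + a * s for a, b, c, s in zip(x1, x2, cos, sin)]
    return o1 + o2


def fused_rotary_inplace(positions, query, key, cache):
    """Hypothetical stand-in for a fused kernel: overwrites query and key."""
    for t, p in enumerate(positions):
        cos, sin = split_row(cache[p])
        query[t][:] = rotate_neox(query[t], cos, sin)
        key[t][:] = rotate_neox(key[t], cos, sin)


def forward(positions, query, key, cache, fast_path_available):
    if fast_path_available:
        # Fast path: mutate the inputs and return the same objects.
        fused_rotary_inplace(positions, query, key, cache)
        return query, key
    # Reference path: leave the inputs untouched and build new outputs.
    new_q, new_k = [], []
    for t, p in enumerate(positions):
        cos, sin = split_row(cache[p])
        new_q.append(rotate_neox(query[t], cos, sin))
        new_k.append(rotate_neox(key[t], cos, sin))
    return new_q, new_k
```

One design point the sketch makes visible: the fast path returns the caller's own buffers after modifying them, while the reference path in this sketch allocates new ones, so callers that keep using the original `query` and `key` after the call would observe different behaviour on the two paths.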
