Skip to content

Commit a115f1f

Browse files
[Perf]: Optimize rotary_emb implementation to use Triton operator for improved inference performance
Signed-off-by: cynthieye <[email protected]> Co-authored-by: MagnetoWang <[email protected]>
1 parent 99ef59c commit a115f1f

File tree

1 file changed

+33
-4
lines changed

1 file changed

+33
-4
lines changed

vllm/model_executor/layers/rotary_embedding.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,13 @@
2828
import torch
2929
import torch.nn as nn
3030
from transformers import PretrainedConfig
31+
from transformers.utils import is_flash_attn_2_available
3132

3233
from vllm.model_executor.custom_op import CustomOp
3334
from vllm.platforms import current_platform
3435

36+
if is_flash_attn_2_available():
37+
from flash_attn.ops.triton.rotary import apply_rotary
3538

3639
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
3740
x1 = x[..., :x.shape[-1] // 2]
@@ -100,6 +103,10 @@ def __init__(
100103
cache = cache.to(dtype)
101104
self.cos_sin_cache: torch.Tensor
102105
self.register_buffer("cos_sin_cache", cache, persistent=False)
106+
if is_flash_attn_2_available():
107+
self._use_flash_attn = True
108+
else:
109+
self._use_flash_attn = False
103110

104111
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
105112
"""Compute the inverse frequency."""
@@ -141,14 +148,23 @@ def forward_native(
141148
query = query.view(num_tokens, -1, self.head_size)
142149
query_rot = query[..., :self.rotary_dim]
143150
query_pass = query[..., self.rotary_dim:]
144-
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
151+
if self._use_flash_attn:
152+
query_rot = apply_rotary(query_rot.unsqueeze(0), cos, sin,
153+
0).squeeze(0)
154+
else:
155+
query_rot = _apply_rotary_emb(query_rot, cos, sin,
156+
self.is_neox_style)
145157
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
146158

147159
key_shape = key.shape
148160
key = key.view(num_tokens, -1, self.head_size)
149161
key_rot = key[..., :self.rotary_dim]
150162
key_pass = key[..., self.rotary_dim:]
151-
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
163+
if self._use_flash_attn:
164+
key_rot = apply_rotary(key_rot.unsqueeze(0), cos, sin,
165+
0).squeeze(0)
166+
else:
167+
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
152168
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
153169
return query, key
154170

@@ -938,6 +954,10 @@ def __init__(
938954
self.mrope_section = mrope_section
939955
if self.mrope_section:
940956
assert sum(self.mrope_section) == rotary_dim // 2
957+
if is_flash_attn_2_available():
958+
self._use_flash_attn = True
959+
else:
960+
self._use_flash_attn = False
941961

942962
def forward(
943963
self,
@@ -977,14 +997,23 @@ def forward(
977997
query = query.view(num_tokens, -1, self.head_size)
978998
query_rot = query[..., :self.rotary_dim]
979999
query_pass = query[..., self.rotary_dim:]
980-
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
1000+
if self._use_flash_attn:
1001+
query_rot = apply_rotary(query_rot.unsqueeze(0), cos, sin,
1002+
0).squeeze(0)
1003+
else:
1004+
query_rot = _apply_rotary_emb(query_rot, cos, sin,
1005+
self.is_neox_style)
9811006
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
9821007

9831008
key_shape = key.shape
9841009
key = key.view(num_tokens, -1, self.head_size)
9851010
key_rot = key[..., :self.rotary_dim]
9861011
key_pass = key[..., self.rotary_dim:]
987-
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
1012+
if self._use_flash_attn:
1013+
key_rot = apply_rotary(key_rot.unsqueeze(0), cos, sin,
1014+
0).squeeze(0)
1015+
else:
1016+
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
9881017
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
9891018
return query, key
9901019

0 commit comments

Comments (0)