
Commit bf7c6fc

[Perf] Optimize rotary_emb implementation to use a Triton operator for improved inference performance
Signed-off-by: cynthieye <[email protected]>
Co-authored-by: MagnetoWang <[email protected]>
1 parent 99ef59c commit bf7c6fc

File tree

1 file changed (+24, -11 lines)

vllm/model_executor/layers/rotary_embedding.py

Lines changed: 24 additions & 11 deletions
@@ -28,6 +28,7 @@
 import torch
 import torch.nn as nn
 from transformers import PretrainedConfig
+from transformers.utils import is_flash_attn_2_available
 
 from vllm.model_executor.custom_op import CustomOp
 from vllm.platforms import current_platform
@@ -46,20 +47,12 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
     return x.flatten(-2)
 
 
-def _apply_rotary_emb(
+def _apply_rotary_emb_torch(
     x: torch.Tensor,
     cos: torch.Tensor,
     sin: torch.Tensor,
     is_neox_style: bool,
 ) -> torch.Tensor:
-    """
-    Args:
-        x: [num_tokens, num_heads, head_size]
-        cos: [num_tokens, head_size // 2]
-        sin: [num_tokens, head_size // 2]
-        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
-            positional embeddings.
-    """
     cos = cos.unsqueeze(-2).to(x.dtype)
     sin = sin.unsqueeze(-2).to(x.dtype)
     if is_neox_style:
@@ -75,6 +68,24 @@ def _apply_rotary_emb(
     return torch.stack((o1, o2), dim=-1).flatten(-2)
 
 
+def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
+                      is_neox_style: bool) -> torch.Tensor:
+    """
+    Args:
+        x: [num_tokens, num_heads, head_size]
+        cos: [num_tokens, head_size // 2]
+        sin: [num_tokens, head_size // 2]
+        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
+            positional embeddings.
+    """
+    if is_flash_attn_2_available():
+        from flash_attn.layers.rotary import apply_rotary_emb
+        return apply_rotary_emb(x.unsqueeze(0), cos, sin,
+                                not is_neox_style).squeeze(0)
+    else:
+        return _apply_rotary_emb_torch(x, cos, sin, is_neox_style)
+
+
 @CustomOp.register("rotary_embedding")
 class RotaryEmbedding(CustomOp):
     """Original rotary positional embedding."""
@@ -141,14 +152,16 @@ def forward_native(
         query = query.view(num_tokens, -1, self.head_size)
         query_rot = query[..., :self.rotary_dim]
         query_pass = query[..., self.rotary_dim:]
-        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
+        query_rot = _apply_rotary_emb_torch(query_rot, cos, sin,
+                                            self.is_neox_style)
         query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
 
         key_shape = key.shape
         key = key.view(num_tokens, -1, self.head_size)
         key_rot = key[..., :self.rotary_dim]
         key_pass = key[..., self.rotary_dim:]
-        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
+        key_rot = _apply_rotary_emb_torch(key_rot, cos, sin,
+                                          self.is_neox_style)
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
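
For readers who want to try the new dispatch outside vLLM, below is a minimal standalone sketch of the pattern this commit introduces: prefer flash-attn's fused (Triton-backed) rotary kernel when is_flash_attn_2_available() reports it, and fall back to a plain PyTorch rotation otherwise. The body of rotary_torch is an illustrative reconstruction (only parts of _apply_rotary_emb_torch are visible in this diff), and the helper names, shapes, and frequency setup in the demo are assumptions for demonstration only, not vLLM code.

# Standalone sketch (not vLLM code): prefer flash-attn's fused, Triton-backed
# rotary kernel when it is available, otherwise rotate in plain PyTorch.
import torch
from transformers.utils import is_flash_attn_2_available


def rotary_torch(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                 is_neox_style: bool) -> torch.Tensor:
    # Illustrative pure-PyTorch rotation. Shapes follow the docstring in the diff:
    # x: [num_tokens, num_heads, head_size], cos/sin: [num_tokens, head_size // 2].
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    if is_neox_style:
        x1, x2 = torch.chunk(x, 2, dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    if is_neox_style:
        return torch.cat((o1, o2), dim=-1)
    return torch.stack((o1, o2), dim=-1).flatten(-2)


def rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
           is_neox_style: bool = True) -> torch.Tensor:
    # Same dispatch as the new _apply_rotary_emb: flash-attn's kernel expects a
    # leading batch dimension and an `interleaved` flag, the inverse of Neox style.
    if is_flash_attn_2_available():
        from flash_attn.layers.rotary import apply_rotary_emb
        return apply_rotary_emb(x.unsqueeze(0), cos, sin,
                                not is_neox_style).squeeze(0)
    return rotary_torch(x, cos, sin, is_neox_style)


if __name__ == "__main__":
    # The flash-attn path requires CUDA tensors; use the GPU when present.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    num_tokens, num_heads, head_size = 8, 4, 64
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_size, 2, device=device)
                                  / head_size))
    angles = torch.arange(num_tokens, device=device)[:, None] * inv_freq[None, :]
    q = torch.randn(num_tokens, num_heads, head_size, device=device)
    out = rotary(q, angles.cos(), angles.sin())
    assert out.shape == q.shape

On a machine where flash-attn is installed (which implies CUDA is available), the call routes through the fused Triton kernel this commit targets; otherwise the result comes from the PyTorch fallback, matching the behaviour forward_native keeps by calling _apply_rotary_emb_torch directly.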

0 commit comments
