28
28
import torch
29
29
import torch .nn as nn
30
30
from transformers import PretrainedConfig
31
+ from transformers .utils import is_flash_attn_2_available
31
32
32
33
from vllm .model_executor .custom_op import CustomOp
33
34
from vllm .platforms import current_platform
@@ -46,20 +47,12 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
46
47
return x .flatten (- 2 )
47
48
48
49
49
- def _apply_rotary_emb (
50
+ def _apply_rotary_emb_torch (
50
51
x : torch .Tensor ,
51
52
cos : torch .Tensor ,
52
53
sin : torch .Tensor ,
53
54
is_neox_style : bool ,
54
55
) -> torch .Tensor :
55
- """
56
- Args:
57
- x: [num_tokens, num_heads, head_size]
58
- cos: [num_tokens, head_size // 2]
59
- sin: [num_tokens, head_size // 2]
60
- is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
61
- positional embeddings.
62
- """
63
56
cos = cos .unsqueeze (- 2 ).to (x .dtype )
64
57
sin = sin .unsqueeze (- 2 ).to (x .dtype )
65
58
if is_neox_style :
@@ -75,6 +68,24 @@ def _apply_rotary_emb(
75
68
return torch .stack ((o1 , o2 ), dim = - 1 ).flatten (- 2 )
76
69
77
70
71
+ def _apply_rotary_emb (x : torch .Tensor , cos : torch .Tensor , sin : torch .Tensor ,
72
+ is_neox_style : bool ) -> torch .Tensor :
73
+ """
74
+ Args:
75
+ x: [num_tokens, num_heads, head_size]
76
+ cos: [num_tokens, head_size // 2]
77
+ sin: [num_tokens, head_size // 2]
78
+ is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
79
+ positional embeddings.
80
+ """
81
+ if is_flash_attn_2_available ():
82
+ from flash_attn .layers .rotary import apply_rotary_emb
83
+ return apply_rotary_emb (x .unsqueeze (0 ), cos , sin ,
84
+ not is_neox_style ).squeeze (0 )
85
+ else :
86
+ return _apply_rotary_emb_torch (x , cos , sin , is_neox_style )
87
+
88
+
78
89
@CustomOp .register ("rotary_embedding" )
79
90
class RotaryEmbedding (CustomOp ):
80
91
"""Original rotary positional embedding."""
@@ -141,14 +152,16 @@ def forward_native(
141
152
query = query .view (num_tokens , - 1 , self .head_size )
142
153
query_rot = query [..., :self .rotary_dim ]
143
154
query_pass = query [..., self .rotary_dim :]
144
- query_rot = _apply_rotary_emb (query_rot , cos , sin , self .is_neox_style )
155
+ query_rot = _apply_rotary_emb_torch (query_rot , cos , sin ,
156
+ self .is_neox_style )
145
157
query = torch .cat ((query_rot , query_pass ), dim = - 1 ).reshape (query_shape )
146
158
147
159
key_shape = key .shape
148
160
key = key .view (num_tokens , - 1 , self .head_size )
149
161
key_rot = key [..., :self .rotary_dim ]
150
162
key_pass = key [..., self .rotary_dim :]
151
- key_rot = _apply_rotary_emb (key_rot , cos , sin , self .is_neox_style )
163
+ key_rot = _apply_rotary_emb_torch (key_rot , cos , sin ,
164
+ self .is_neox_style )
152
165
key = torch .cat ((key_rot , key_pass ), dim = - 1 ).reshape (key_shape )
153
166
return query , key
154
167
0 commit comments