|
28 | 28 | import torch
|
29 | 29 | import torch.nn as nn
|
30 | 30 | from transformers import PretrainedConfig
|
| 31 | +from transformers.utils import is_flash_attn_2_available |
31 | 32 |
|
32 | 33 | from vllm.model_executor.custom_op import CustomOp
|
33 | 34 | from vllm.platforms import current_platform
|
34 | 35 |
|
| 36 | +if is_flash_attn_2_available(): |
| 37 | + from flash_attn.ops.triton.rotary import apply_rotary |
35 | 38 |
|
36 | 39 | def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
37 | 40 | x1 = x[..., :x.shape[-1] // 2]
|
@@ -100,6 +103,10 @@ def __init__(
|
100 | 103 | cache = cache.to(dtype)
|
101 | 104 | self.cos_sin_cache: torch.Tensor
|
102 | 105 | self.register_buffer("cos_sin_cache", cache, persistent=False)
|
| 106 | + if is_flash_attn_2_available(): |
| 107 | + self._use_flash_attn = True |
| 108 | + else: |
| 109 | + self._use_flash_attn = False |
103 | 110 |
|
104 | 111 | def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
105 | 112 | """Compute the inverse frequency."""
|
@@ -141,14 +148,23 @@ def forward_native(
|
141 | 148 | query = query.view(num_tokens, -1, self.head_size)
|
142 | 149 | query_rot = query[..., :self.rotary_dim]
|
143 | 150 | query_pass = query[..., self.rotary_dim:]
|
144 |
| - query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) |
| 151 | + if self._use_flash_attn: |
| 152 | + query_rot = apply_rotary(query_rot.unsqueeze(0), cos, sin, |
| 153 | + 0).squeeze(0) |
| 154 | + else: |
| 155 | + query_rot = _apply_rotary_emb(query_rot, cos, sin, |
| 156 | + self.is_neox_style) |
145 | 157 | query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
146 | 158 |
|
147 | 159 | key_shape = key.shape
|
148 | 160 | key = key.view(num_tokens, -1, self.head_size)
|
149 | 161 | key_rot = key[..., :self.rotary_dim]
|
150 | 162 | key_pass = key[..., self.rotary_dim:]
|
151 |
| - key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) |
| 163 | + if self._use_flash_attn: |
| 164 | + key_rot = apply_rotary(key_rot.unsqueeze(0), cos, sin, |
| 165 | + 0).squeeze(0) |
| 166 | + else: |
| 167 | + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) |
152 | 168 | key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
153 | 169 | return query, key
|
154 | 170 |
|
@@ -938,6 +954,10 @@ def __init__(
|
938 | 954 | self.mrope_section = mrope_section
|
939 | 955 | if self.mrope_section:
|
940 | 956 | assert sum(self.mrope_section) == rotary_dim // 2
|
| 957 | + if is_flash_attn_2_available(): |
| 958 | + self._use_flash_attn = True |
| 959 | + else: |
| 960 | + self._use_flash_attn = False |
941 | 961 |
|
942 | 962 | def forward(
|
943 | 963 | self,
|
@@ -977,14 +997,23 @@ def forward(
|
977 | 997 | query = query.view(num_tokens, -1, self.head_size)
|
978 | 998 | query_rot = query[..., :self.rotary_dim]
|
979 | 999 | query_pass = query[..., self.rotary_dim:]
|
980 |
| - query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) |
| 1000 | + if self._use_flash_attn: |
| 1001 | + query_rot = apply_rotary(query_rot.unsqueeze(0), cos, sin, |
| 1002 | + 0).squeeze(0) |
| 1003 | + else: |
| 1004 | + query_rot = _apply_rotary_emb(query_rot, cos, sin, |
| 1005 | + self.is_neox_style) |
981 | 1006 | query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
982 | 1007 |
|
983 | 1008 | key_shape = key.shape
|
984 | 1009 | key = key.view(num_tokens, -1, self.head_size)
|
985 | 1010 | key_rot = key[..., :self.rotary_dim]
|
986 | 1011 | key_pass = key[..., self.rotary_dim:]
|
987 |
| - key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) |
| 1012 | + if self._use_flash_attn: |
| 1013 | + key_rot = apply_rotary(key_rot.unsqueeze(0), cos, sin, |
| 1014 | + 0).squeeze(0) |
| 1015 | + else: |
| 1016 | + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) |
988 | 1017 | key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
989 | 1018 | return query, key
|
990 | 1019 |
|
|
0 commit comments