from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler,
                                               PoolingType)
from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -38,19 +39,24 @@ def __init__(self, config: BertConfig):
        self.size = config.hidden_size
        self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                      config.hidden_size)
-        self.position_embeddings = VocabParallelEmbedding(
-            config.max_position_embeddings, config.hidden_size)
+
        self.token_type_embeddings = VocabParallelEmbedding(
            config.type_vocab_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
-        self.position_ids = nn.Parameter(
-            torch.empty((1, config.max_position_embeddings)), )

        self.position_embedding_type = config.position_embedding_type
-        if self.position_embedding_type != "absolute":
-            raise ValueError("Only 'absolute' position_embedding_type" +
-                             " is supported")
+        if self.position_embedding_type == "absolute":
+            self.position_embeddings = VocabParallelEmbedding(
+                config.max_position_embeddings, config.hidden_size)
+            self.position_ids = nn.Parameter(
+                torch.empty((1, config.max_position_embeddings)), )
+        elif self.position_embedding_type == "rotary":
+            self.position_embeddings = None
+            self.position_ids = None
+        else:
+            raise ValueError("Only 'absolute' and 'rotary' " +
+                             "position_embedding_type is supported")

    def forward(
        self,
@@ -64,17 +70,19 @@ def forward(
        # Input embeddings.
        inputs_embeds = self.word_embeddings(input_ids)

-        # Position embeddings.
-        position_embeddings = self.position_embeddings(position_ids)
-
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape,
                                         dtype=torch.long,
                                         device=inputs_embeds.device)

        token_type_embeddings = self.token_type_embeddings(token_type_ids)

-        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
+        embeddings = inputs_embeds + token_type_embeddings
+
+        if self.position_embedding_type == "absolute":
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings += position_embeddings
+
        embeddings = self.LayerNorm(embeddings)
        return embeddings
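Note: after this hunk, position information is added to the embedding output only for the "absolute" type; for "rotary" the output is just word plus token-type embeddings, and positions are consumed later inside self-attention. A minimal standalone sketch of the same branching, using plain nn.Embedding in place of vLLM's VocabParallelEmbedding and made-up sizes, purely for illustration:

    import torch
    import torch.nn as nn

    class ToyBertEmbedding(nn.Module):
        """Illustration only: mirrors the absolute-vs-rotary branch above."""

        def __init__(self, vocab_size=30522, hidden=768, max_pos=512,
                     type_vocab=2, pos_type="absolute"):
            super().__init__()
            self.word_embeddings = nn.Embedding(vocab_size, hidden)
            self.token_type_embeddings = nn.Embedding(type_vocab, hidden)
            self.LayerNorm = nn.LayerNorm(hidden, eps=1e-12)
            self.position_embedding_type = pos_type
            if pos_type == "absolute":
                self.position_embeddings = nn.Embedding(max_pos, hidden)
            elif pos_type == "rotary":
                # No position table; rotation is applied to q/k in attention.
                self.position_embeddings = None
            else:
                raise ValueError(f"unsupported position_embedding_type {pos_type}")

        def forward(self, input_ids, position_ids, token_type_ids=None):
            if token_type_ids is None:
                token_type_ids = torch.zeros_like(input_ids)
            emb = (self.word_embeddings(input_ids) +
                   self.token_type_embeddings(token_type_ids))
            if self.position_embedding_type == "absolute":
                emb = emb + self.position_embeddings(position_ids)
            return self.LayerNorm(emb)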
@@ -98,7 +106,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@support_torch_compile
class BertEncoder(nn.Module):

-    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 rotary_kwargs: Optional[dict] = None,
+                 prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
@@ -107,16 +118,18 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
            BertLayer(config=config,
                      cache_config=cache_config,
                      quant_config=quant_config,
+                      rotary_kwargs=rotary_kwargs,
                      prefix=f"{prefix}.layer.{layer_idx}")
            for layer_idx in range(config.num_hidden_layers)
        ])

    def forward(
        self,
+        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        for layer in self.layer:
-            hidden_states = layer(hidden_states)
+            hidden_states = layer(positions, hidden_states)
        return hidden_states
@@ -126,6 +139,7 @@ def __init__(self,
                 config: BertConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
+                 rotary_kwargs: Optional[dict] = None,
                 prefix: str = ""):
        super().__init__()
@@ -135,6 +149,7 @@ def __init__(self,
            layer_norm_eps=config.layer_norm_eps,
            cache_config=cache_config,
            quant_config=quant_config,
+            rotary_kwargs=rotary_kwargs,
            prefix=f"{prefix}.attention")

        self.intermediate = BertIntermediate(
@@ -150,8 +165,8 @@ def __init__(self,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.output")

-    def forward(self, hidden_states: torch.Tensor):
-        attn_output = self.attention(hidden_states)
+    def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor):
+        attn_output = self.attention(positions, hidden_states)
        intermediate_output = self.intermediate(attn_output)
        output = self.output(intermediate_output, attn_output)
        return output
@@ -166,6 +181,7 @@ def __init__(
        layer_norm_eps: float,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
+        rotary_kwargs: Optional[dict] = None,
        prefix: str = "",
    ):
        super().__init__()
@@ -174,6 +190,7 @@ def __init__(
                                      num_attention_heads=num_attention_heads,
                                      cache_config=cache_config,
                                      quant_config=quant_config,
+                                      rotary_kwargs=rotary_kwargs,
                                      prefix=f"{prefix}.output")

        self.output = BertSelfOutput(hidden_size=hidden_size,
@@ -183,9 +200,10 @@ def __init__(

    def forward(
        self,
+        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
-        self_output = self.self(hidden_states)
+        self_output = self.self(positions, hidden_states)
        return self.output(self_output, hidden_states)
@@ -197,6 +215,7 @@ def __init__(
        num_attention_heads: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
+        rotary_kwargs: Optional[dict] = None,
        prefix: str = "",
    ):
        super().__init__()
@@ -225,6 +244,11 @@ def __init__(
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj")

+        if rotary_kwargs:
+            self.rotary_emb = get_rope(**rotary_kwargs)
+        else:
+            self.rotary_emb = None
+
        self.attn = Attention(num_heads=self.num_heads,
                              head_size=self.head_dim,
                              scale=self.scaling,
@@ -236,10 +260,15 @@ def __init__(

    def forward(
        self,
+        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        if self.rotary_emb:
+            q, k = self.rotary_emb(positions, q, k)
+
        output = self.attn(q, k, v)
        return output
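Note: get_rope returns a rotary-embedding module whose forward takes (positions, query, key) and returns the rotated pair, which is exactly how the forward hunk above uses it. A rough usage sketch with placeholder values; the keyword names must match get_rope's signature in the installed vLLM version, and is_neox_style in particular is an assumption to check against the target checkpoint:

    import torch
    from vllm.model_executor.layers.rotary_embedding import get_rope

    head_dim = 64                           # per-head dim, e.g. 768 // 12
    rotary_emb = get_rope(head_size=head_dim,
                          rotary_dim=head_dim,   # rotate the full head
                          max_position=8192,     # placeholder length bound
                          base=10000,            # RoPE theta
                          is_neox_style=True)    # assumed layout

    # q/k are flattened to [num_tokens, num_heads * head_dim],
    # matching the qkv split in BertSelfAttention above.
    num_tokens, num_heads = 16, 12
    positions = torch.arange(num_tokens)
    q = torch.randn(num_tokens, num_heads * head_dim)
    k = torch.randn(num_tokens, num_heads * head_dim)
    q, k = rotary_emb(positions, q, k)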
@@ -321,11 +350,13 @@ def __init__(self,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 embedding_class: type = BertEmbedding,
+                 rotary_kwargs: Optional[dict] = None,
                 add_pooling_layer: bool = False):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.embeddings = embedding_class(config)
        self.encoder = BertEncoder(vllm_config=vllm_config,
+                                   rotary_kwargs=rotary_kwargs,
                                   prefix=f"{prefix}.encoder")
        self.pooler = BertPooler(config) if add_pooling_layer else None
@@ -347,7 +378,7 @@ def forward(
                seq_lens=attn_metadata.seq_lens_tensor,
                position_ids=position_ids,
                token_type_ids=token_type_ids)
-        return self.encoder(hidden_states)
+        return self.encoder(position_ids, hidden_states)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
@@ -401,6 +432,7 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        pooler_config = vllm_config.model_config.pooler_config
+        self.config = vllm_config.model_config.hf_config
        self.model = self._build_model(vllm_config=vllm_config,
                                       prefix=maybe_prefix(prefix, "model"))
        self._pooler = self._build_pooler(pooler_config)
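Note: BertModel now accepts a rotary_kwargs dict that is forwarded verbatim to get_rope in every attention layer. A hypothetical subclass of BertEmbeddingModel showing how that dict might be wired up; the config attribute names used to derive it (position_embedding_type, rope_theta, max_position_embeddings) are assumptions, not taken from this diff, and the referenced classes are the ones defined or imported in this module:

    class RotaryBertEmbeddingModel(BertEmbeddingModel):
        """Hypothetical: enables RoPE when the checkpoint's config asks for it."""

        def _build_model(self,
                         vllm_config: VllmConfig,
                         prefix: str = "") -> BertModel:
            config = vllm_config.model_config.hf_config
            rotary_kwargs = None
            if getattr(config, "position_embedding_type",
                       "absolute") == "rotary":
                head_dim = config.hidden_size // config.num_attention_heads
                rotary_kwargs = {
                    "head_size": head_dim,
                    "rotary_dim": head_dim,
                    "max_position": config.max_position_embeddings,
                    "base": getattr(config, "rope_theta", 10000),
                }
            return BertModel(vllm_config=vllm_config,
                             prefix=prefix,
                             rotary_kwargs=rotary_kwargs,
                             embedding_class=BertEmbedding)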