@@ -142,11 +142,15 @@ def __init__(self, config):
         self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)
 
         if config.qkv_bias:
-            self.q_bias = nn.Parameter(torch.zeros(self.embed_dim))
-            self.v_bias = nn.Parameter(torch.zeros(self.embed_dim))
+            q_bias = nn.Parameter(torch.zeros(self.embed_dim))
+            v_bias = nn.Parameter(torch.zeros(self.embed_dim))
         else:
-            self.q_bias = None
-            self.v_bias = None
+            q_bias = None
+            v_bias = None
+
+        if q_bias is not None:
+            qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
+            self.qkv.bias = nn.Parameter(qkv_bias)
 
         self.projection = nn.Linear(self.embed_dim, self.embed_dim)
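The init-time fold above preserves BLIP-2's bias scheme: the query and value projections carry a bias, the key projection does not. A minimal sketch (toy dimensions, random biases, all names invented for illustration) checking that the `(q_bias, zeros, v_bias)` layout behaves that way:

```python
import torch
from torch import nn

embed_dim = 8
qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=False)

# Hypothetical stand-ins for the learned biases; in the model these start
# as zero-initialized nn.Parameter tensors.
q_bias = torch.randn(embed_dim)
v_bias = torch.randn(embed_dim)

# Same layout as the patch: query bias, a frozen zero bias for the key
# projection, then the value bias.
qkv.bias = nn.Parameter(torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias)))

x = torch.randn(2, 4, embed_dim)
q, k, v = qkv(x).chunk(3, dim=-1)

# The key slice sees no bias, matching BLIP-2's q/v-only bias scheme.
w_q, w_k, w_v = qkv.weight.chunk(3, dim=0)
assert torch.allclose(q, x @ w_q.T + q_bias, atol=1e-6)
assert torch.allclose(k, x @ w_k.T, atol=1e-6)
assert torch.allclose(v, x @ w_v.T + v_bias, atol=1e-6)
```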
@@ -163,11 +167,7 @@ def forward(
 
         bsz, tgt_len, embed_dim = hidden_states.size()
 
-        qkv_bias = None
-        if self.q_bias is not None:
-            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
-
-        mixed_qkv = nn.functional.linear(input=hidden_states, weight=self.qkv.weight, bias=qkv_bias)
+        mixed_qkv = self.qkv(hidden_states)
 
         mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute(
             2, 0, 3, 1, 4
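With the bias attached to `self.qkv` at construction, the per-call `torch.cat` and the explicit `nn.functional.linear` become redundant, which is why the forward path collapses to a plain module call. A quick equivalence check under that assumption (toy shapes):

```python
import torch
from torch import nn

embed_dim = 6
qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
q_bias, v_bias = torch.randn(embed_dim), torch.randn(embed_dim)
qkv.bias = nn.Parameter(torch.cat((q_bias, torch.zeros(embed_dim), v_bias)))

hidden_states = torch.randn(2, 4, embed_dim)

# Old path: rebuild the concatenated bias on every call.
old = nn.functional.linear(
    input=hidden_states,
    weight=qkv.weight,
    bias=torch.cat((q_bias, torch.zeros(embed_dim), v_bias)),
)
# New path: the module already carries the folded bias.
new = qkv(hidden_states)
assert torch.allclose(old, new)
```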
@@ -285,6 +285,7 @@ class Blip2PreTrainedModel(PreTrainedModel):
         r"language_model.encoder.embed_tokens.weight",
         r"language_model.decoder.embed_tokens.weight",
     ]
+    _no_split_modules = ["Blip2Attention", "T5Block", "OPTDecoderLayer"]
 
     def _init_weights(self, module):
         """Initialize the weights"""
@@ -1335,7 +1336,8 @@ def forward(
 
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
-        attention_mask = torch.cat([language_model_attention_mask, attention_mask], dim=1)
+        expected_device = language_model_attention_mask.device
+        attention_mask = torch.cat([language_model_attention_mask, attention_mask.to(expected_device)], dim=1)
 
         if self.config.use_decoder_only_language_model:
             outputs = self.language_model(
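The `.to(expected_device)` guard matters because, under `device_map="auto"`, the mask covering the query tokens and the user-supplied attention mask can land on different GPUs, and `torch.cat` raises a `RuntimeError` on mixed-device inputs. A standalone sketch of the fix (two CUDA devices are assumed purely for illustration):

```python
import torch

# Masks produced on different shards of a multi-GPU model.
language_model_attention_mask = torch.ones(1, 32, device="cuda:0")
attention_mask = torch.ones(1, 16, device="cuda:1")

# Without the .to(...), torch.cat would fail with a device-mismatch error.
expected_device = language_model_attention_mask.device
attention_mask = torch.cat(
    [language_model_attention_mask, attention_mask.to(expected_device)], dim=1
)
```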
@@ -1352,10 +1354,11 @@ def forward(
             logits = logits[:, -labels.size(1) :, :]
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
+            shift_labels = labels[..., 1:].contiguous().to(logits.device)
 
             # Flatten the tokens
             loss_fct = CrossEntropyLoss(reduction="mean")
+
             loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))
         else:
             outputs = self.language_model(
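The same device rule applies to the loss: `CrossEntropyLoss` needs `shift_labels` on the logits' device, since the logits may come off a different shard than the input labels. A self-contained sketch of the shifted next-token loss (toy vocabulary and shapes):

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size, bsz, seq_len = 11, 2, 5
logits = torch.randn(bsz, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (bsz, seq_len))

# Shift so that tokens < n predict n; move labels to the logits' device
# before flattening, mirroring the patch above.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous().to(logits.device)

loss_fct = CrossEntropyLoss(reduction="mean")
loss = loss_fct(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
```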