
Commit c5b958c

Merge pull request #54 from younesbelkada/add-blip2-accelerate
add `accelerate` support for `blip2`
2 parents 429576a + 037543b commit c5b958c

File tree

1 file changed: 14 additions, 11 deletions

src/transformers/models/blip_2/modeling_blip_2.py

Lines changed: 14 additions & 11 deletions
@@ -142,11 +142,15 @@ def __init__(self, config):
         self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)

         if config.qkv_bias:
-            self.q_bias = nn.Parameter(torch.zeros(self.embed_dim))
-            self.v_bias = nn.Parameter(torch.zeros(self.embed_dim))
+            q_bias = nn.Parameter(torch.zeros(self.embed_dim))
+            v_bias = nn.Parameter(torch.zeros(self.embed_dim))
         else:
-            self.q_bias = None
-            self.v_bias = None
+            q_bias = None
+            v_bias = None
+
+        if q_bias is not None:
+            qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
+            self.qkv.bias = nn.Parameter(qkv_bias)

         self.projection = nn.Linear(self.embed_dim, self.embed_dim)
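
The hunk above moves the q/v bias handling out of forward(): the two bias parameters are concatenated (with a zero block for the keys) and registered once as the bias of the fused qkv projection, so device-placement tools such as accelerate see an ordinary registered parameter on the Linear. A minimal sketch of the equivalence with illustrative sizes, not library code:

import torch
import torch.nn as nn

embed_dim = 4
torch.manual_seed(0)

# Old pattern: keep q_bias / v_bias as separate parameters and rebuild the
# fused bias on every forward call.
qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
q_bias = nn.Parameter(torch.randn(embed_dim))
v_bias = nn.Parameter(torch.randn(embed_dim))

x = torch.randn(2, embed_dim)
runtime_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
old_out = nn.functional.linear(input=x, weight=qkv.weight, bias=runtime_bias)

# New pattern: register the concatenated bias on the Linear once, in __init__.
qkv.bias = nn.Parameter(torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias)))
new_out = qkv(x)

# Same numerical result, but the bias now lives on `qkv` itself, so it is moved
# to the right device together with the weight.
assert torch.allclose(old_out, new_out)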

@@ -163,11 +167,7 @@ def forward(

         bsz, tgt_len, embed_dim = hidden_states.size()

-        qkv_bias = None
-        if self.q_bias is not None:
-            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
-
-        mixed_qkv = nn.functional.linear(input=hidden_states, weight=self.qkv.weight, bias=qkv_bias)
+        mixed_qkv = self.qkv(hidden_states)

         mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute(
             2, 0, 3, 1, 4
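
With the fused bias registered in __init__, the forward pass can simply call self.qkv(hidden_states). The unchanged context lines then split the fused projection per head; a small illustration of that reshape/permute with made-up sizes:

import torch

# Hypothetical shapes, only to show how the fused projection is split per head.
bsz, tgt_len, embed_dim, num_heads = 2, 5, 8, 2
mixed_qkv = torch.randn(bsz, tgt_len, 3 * embed_dim)

# (bsz, tgt_len, 3, num_heads, head_dim) -> (3, bsz, num_heads, tgt_len, head_dim)
mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, num_heads, embed_dim // num_heads).permute(2, 0, 3, 1, 4)
query, key, value = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
print(query.shape)  # torch.Size([2, 2, 5, 4])
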
@@ -285,6 +285,7 @@ class Blip2PreTrainedModel(PreTrainedModel):
         r"language_model.encoder.embed_tokens.weight",
         r"language_model.decoder.embed_tokens.weight",
     ]
+    _no_split_modules = ["Blip2Attention", "T5Block", "OPTDecoderLayer"]

     def _init_weights(self, module):
         """Initialize the weights"""
@@ -1335,7 +1336,8 @@ def forward(

         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
-        attention_mask = torch.cat([language_model_attention_mask, attention_mask], dim=1)
+        expected_device = language_model_attention_mask.device
+        attention_mask = torch.cat([language_model_attention_mask, attention_mask.to(expected_device)], dim=1)

         if self.config.use_decoder_only_language_model:
             outputs = self.language_model(
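
When accelerate places the vision/Q-Former side and the language model on different devices, the two attention masks can end up on different devices, and torch.cat requires them to match. A minimal, CPU-safe illustration of the mismatch that the .to(expected_device) above avoids:

import torch

# Tensors produced by submodules dispatched to different devices must be moved
# to a common device before concatenation. Guarded so it also runs on CPU-only
# machines.
a = torch.ones(1, 3)   # e.g. mask coming from the language-model side
b = torch.ones(1, 2)   # e.g. mask built from the input ids
if torch.cuda.is_available():
    a = a.to("cuda:0")

# Without the .to(), torch.cat([a, b], dim=1) raises a device-mismatch error
# whenever a and b live on different devices.
merged = torch.cat([a, b.to(a.device)], dim=1)
print(merged.device, merged.shape)
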
@@ -1352,10 +1354,11 @@ def forward(
                 logits = logits[:, -labels.size(1) :, :]
                 # Shift so that tokens < n predict n
                 shift_logits = logits[..., :-1, :].contiguous()
-                shift_labels = labels[..., 1:].contiguous()
+                shift_labels = labels[..., 1:].contiguous().to(logits.device)

                 # Flatten the tokens
                 loss_fct = CrossEntropyLoss(reduction="mean")
+
                 loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))
         else:
             outputs = self.language_model(
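
The same device alignment applies to the loss: labels may arrive on a different device than the logits produced by the dispatched language model, so they are moved with .to(logits.device). A self-contained sketch of the shift-and-flatten causal loss with illustrative sizes:

import torch
from torch.nn import CrossEntropyLoss

# Illustrative sizes only: shift logits/labels by one position so that the token
# at position t is predicted from positions < t, then flatten both for the loss.
vocab_size, bsz, seq_len = 11, 2, 6
logits = torch.randn(bsz, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (bsz, seq_len))

shift_logits = logits[..., :-1, :].contiguous()
# Moving labels to logits.device mirrors the change in the hunk above; it is a
# no-op on a single device but required when the language model lives elsewhere.
shift_labels = labels[..., 1:].contiguous().to(logits.device)

loss_fct = CrossEntropyLoss(reduction="mean")
loss = loss_fct(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(loss.item())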
