-
Notifications
You must be signed in to change notification settings - Fork 89
/
Copy pathdistilbert.py
533 lines (432 loc) · 18.4 KB
/
distilbert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
#
# For licensing see accompanying LICENSE.md file.
# Copyright (C) 2022 Apple Inc. All Rights Reserved.
#
from ane_transformers.reference.layer_norm import LayerNormANE
import torch
import torch.nn as nn
from transformers.models.distilbert import modeling_distilbert
# DistilBERT's reference LayerNorm uses eps=1e-12, which is too small to be
# represented in float16 (the precision the ANE uses by default), so a larger
# epsilon is used for every ANE LayerNorm in this module.
EPS = 1e-7

# Raised when a caller tries to train (or pass labels to) an ANE model.
WARN_MSG_FOR_TRAINING_ATTEMPT = (
    "This model is optimized for on-device execution only. "
    "Please use the original implementation from Hugging Face for training")

# Raised when a caller asks for a dict output, which coremltools cannot convert.
WARN_MSG_FOR_DICT_RETURN = (
    "coremltools does not support dict outputs. Please set return_dict=False")
# Note: torch.nn.LayerNorm and ane_transformers.reference.layer_norm.LayerNormANE
# apply scale and bias terms in opposite orders. In order to accurately restore a
# state_dict trained using the former into the latter, we adjust the bias term.
def correct_for_bias_scale_order_inversion(state_dict, prefix, local_metadata,
                                           strict, missing_keys,
                                           unexpected_keys, error_msgs):
    """ load_state_dict pre-hook that rescales a LayerNorm bias by its weight.

    Replaces ``state_dict[prefix + 'bias']`` with
    ``state_dict[prefix + 'bias'] / state_dict[prefix + 'weight']``, so a bias
    that is added before the scale reproduces one that was added after it:
    ``(x + b / w) * w == x * w + b``. The mapping is only exact for non-zero
    weights; zero weight entries produce inf/nan biases.

    If either key is absent (e.g. a partial state_dict), the state_dict is left
    untouched so that the regular missing-key reporting of ``load_state_dict``
    applies instead of this hook raising a KeyError.

    The division is applied on every load. A state_dict saved from a module
    that already went through this conversion would therefore be rescaled a
    second time.

    Parameters follow the ``_register_load_state_dict_pre_hook`` signature. Only
    ``state_dict`` and ``prefix`` are used.

    Returns:
        The same (mutated) ``state_dict`` object.
    """
    bias_key = prefix + 'bias'
    weight_key = prefix + 'weight'
    if bias_key in state_dict and weight_key in state_dict:
        # Out-of-place division: the original bias tensor object is not modified.
        state_dict[bias_key] = state_dict[bias_key] / state_dict[weight_key]
    return state_dict
class LayerNormANE(LayerNormANE):
    """ LayerNormANE that can load weights trained with torch.nn.LayerNorm.

    Subclasses ane_transformers.reference.layer_norm.LayerNormANE, shadowing the
    imported name. It registers a state-dict pre-hook that rescales the bias
    term on load, compensating for the opposite scale/bias order.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        bias_fix_hook = correct_for_bias_scale_order_inversion
        self._register_load_state_dict_pre_hook(bias_fix_hook)
class Embeddings(modeling_distilbert.Embeddings):
    """ Embeddings module optimized for Apple Neural Engine.

    Behaves like the Hugging Face module except that ``LayerNorm`` is replaced
    by the LayerNormANE defined in this file, built with the float16-friendly
    ``EPS``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.LayerNorm = LayerNormANE(config.dim, eps=EPS)
class MultiHeadSelfAttention(modeling_distilbert.MultiHeadSelfAttention):
    """ MultiHeadSelfAttention module optimized for Apple Neural Engine

    The query/key/value/output projections of the Hugging Face module are
    replaced by 1x1 nn.Conv2d layers that operate on (bs, dim, 1, seq_length)
    tensors.
    """

    def __init__(self, config):
        super().__init__(config)
        # Replace each nn.Linear projection with a pointwise convolution of the
        # same width. The order matches the original construction order so
        # random initialisation is unchanged.
        for proj_name in ('q_lin', 'k_lin', 'v_lin', 'out_lin'):
            setattr(
                self, proj_name,
                nn.Conv2d(
                    in_channels=config.dim,
                    out_channels=config.dim,
                    kernel_size=1,
                ))

    def prune_heads(self, heads):
        raise NotImplementedError

    def forward(self,
                query,
                key,
                value,
                mask,
                head_mask=None,
                output_attentions=False):
        """
        Parameters:
            query: torch.tensor(bs, dim, 1, seq_length)
            key: torch.tensor(bs, dim, 1, seq_length)
            value: torch.tensor(bs, dim, 1, seq_length)
            mask: torch.tensor(bs, seq_length) or torch.tensor(bs, seq_length, 1, 1)
                bool / int64 masks mark positions to attend to with True / 1 and
                are converted to an additive mask (-1e4 on masked positions);
                float32 masks are added to the attention scores as-is.
            head_mask: must be None (head masking is not supported)
            output_attentions: also return the attention weights

        Returns:
            context: torch.tensor(bs, dim, 1, seq_length) Contextualized layer.
            weights: torch.tensor(bs, seq_length, n_heads, seq_length) Attention
                weights laid out as (batch, key position, head, query position).
                Optional: only if `output_attentions=True`

        Raises:
            TypeError: if `mask` has an unsupported dtype
            RuntimeError: if `mask` does not have shape (bs, seq_length, 1, 1)
                after unsqueezing
            NotImplementedError: if `head_mask` is given
        """
        # Parse tensor shapes for source and target sequences
        assert len(query.size()) == 4 and len(key.size()) == 4 and len(
            value.size()) == 4
        bs, dim, dummy, seqlen = query.size()
        # key and value are expected to share query's (bs, dim, 1, seqlen)
        # shape; this is not re-checked here.

        # Project q, k and v
        q = self.q_lin(query)
        k = self.k_lin(key)
        v = self.v_lin(value)

        # Validate mask and turn it into an additive float mask (bs, seqlen, 1, 1)
        if mask is not None:
            expected_mask_shape = [bs, seqlen, 1, 1]
            if mask.dtype == torch.bool:
                mask = mask.logical_not().float() * -1e4
            elif mask.dtype == torch.int64:
                mask = (1 - mask).float() * -1e4
            elif mask.dtype != torch.float32:
                raise TypeError(f"Unexpected dtype for mask: {mask.dtype}")
            if len(mask.size()) == 2:
                mask = mask.unsqueeze(2).unsqueeze(2)
            if list(mask.size()) != expected_mask_shape:
                raise RuntimeError(
                    f"Invalid shape for `mask` (Expected {expected_mask_shape}, got {list(mask.size())})"
                )

        if head_mask is not None:
            raise NotImplementedError

        # Compute scaled dot-product attention separately for every head
        dim_per_head = self.dim // self.n_heads
        mh_q = q.split(
            dim_per_head,
            dim=1)  # (bs, dim_per_head, 1, max_seq_length) * n_heads
        mh_k = k.transpose(1, 3).split(
            dim_per_head,
            dim=3)  # (bs, max_seq_length, 1, dim_per_head) * n_heads
        mh_v = v.split(
            dim_per_head,
            dim=1)  # (bs, dim_per_head, 1, max_seq_length) * n_heads

        normalize_factor = float(dim_per_head)**-0.5
        attn_weights = [
            torch.einsum('bchq,bkhc->bkhq', [qi, ki]) * normalize_factor
            for qi, ki in zip(mh_q, mh_k)
        ]  # (bs, max_seq_length, 1, max_seq_length) * n_heads

        if mask is not None:
            # The (bs, seqlen, 1, 1) mask broadcasts over the query axis (dim 3),
            # i.e. it masks key positions (dim 1).
            attn_weights = [aw + mask for aw in attn_weights]

        # Normalize over the key axis
        attn_weights = [aw.softmax(dim=1) for aw in attn_weights
                        ]  # (bs, max_seq_length, 1, max_seq_length) * n_heads
        attn = [
            torch.einsum('bkhq,bchk->bchq', wi, vi)
            for wi, vi in zip(attn_weights, mh_v)
        ]  # (bs, dim_per_head, 1, max_seq_length) * n_heads
        attn = torch.cat(attn, dim=1)  # (bs, dim, 1, max_seq_length)
        attn = self.out_lin(attn)

        if output_attentions:
            # attn_weights is a Python list of per-head tensors, so it has to be
            # joined with torch.cat (a list has no `.cat` method).
            # -> (bs, max_seq_length, n_heads, max_seq_length)
            return attn, torch.cat(attn_weights, dim=2)
        return (attn, )
class FFN(modeling_distilbert.FFN):
    """ FFN module optimized for Apple Neural Engine.

    Both feed-forward projections are 1x1 nn.Conv2d layers (dim -> hidden_dim ->
    dim) instead of nn.Linear.
    """

    def __init__(self, config):
        super().__init__(config)
        # Sequence axis of the (bs, dim, 1, seq_len) layout; presumably consumed
        # by the parent class's chunked forward -- TODO confirm in transformers.
        self.seq_len_dim = 3

        expand = nn.Conv2d(in_channels=config.dim,
                           out_channels=config.hidden_dim,
                           kernel_size=1)
        self.lin1 = expand

        contract = nn.Conv2d(in_channels=config.hidden_dim,
                             out_channels=config.dim,
                             kernel_size=1)
        self.lin2 = contract
class TransformerBlock(modeling_distilbert.TransformerBlock):
    """ Transformer layer with ANE-optimized sub-modules.

    The attention, feed-forward network and both layer norms of the Hugging Face
    block are swapped for the implementations defined in this file.
    """

    def __init__(self, config):
        super().__init__(config)
        # Construction order is kept identical so parameter initialisation
        # consumes the random number generator in the same sequence.
        self.attention = MultiHeadSelfAttention(config)
        self.sa_layer_norm = LayerNormANE(config.dim, eps=EPS)
        self.ffn = FFN(config)
        self.output_layer_norm = LayerNormANE(config.dim, eps=EPS)
class Transformer(modeling_distilbert.Transformer):
    """ Stack of ``config.n_layers`` ANE-optimized TransformerBlocks. """

    def __init__(self, config):
        super().__init__(config)
        blocks = [TransformerBlock(config) for _ in range(config.n_layers)]
        self.layer = nn.ModuleList(blocks)
class DistilBertModel(modeling_distilbert.DistilBertModel):
    """ DistilBERT encoder built from the ANE-optimized modules of this file.

    Head pruning is not supported.
    """

    def __init__(self, config):
        super().__init__(config)
        self.embeddings = Embeddings(config)
        self.transformer = Transformer(config)
        # Checkpoints store nn.Linear weights as 2-D tensors; this pre-hook
        # unsqueezes them to match the 1x1 nn.Conv2d parameters used here.
        self._register_load_state_dict_pre_hook(linear_to_conv2d_map)

    def _prune_heads(self, heads_to_prune):
        raise NotImplementedError
class DistilBertForMaskedLM(modeling_distilbert.DistilBertForMaskedLM):
    """ Masked language modeling head on the ANE-optimized DistilBERT (inference only).

    The vocabulary transform and projector are 1x1 nn.Conv2d layers and the
    head's layer norm is LayerNormANE.
    """

    def __init__(self, config):
        super().__init__(config)
        from transformers.activations import get_activation
        setattr(self, 'activation', get_activation(config.activation))
        setattr(self, 'distilbert', DistilBertModel(config))
        setattr(self, 'vocab_transform', nn.Conv2d(config.dim, config.dim, 1))
        setattr(self, 'vocab_layer_norm', LayerNormANE(config.dim, eps=EPS))
        setattr(self, 'vocab_projector',
                nn.Conv2d(config.dim, config.vocab_size, 1))

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """ Inference-only masked-LM forward pass.

        Returns:
            A tuple whose first element is the prediction logits of shape
            (bs, seq_len, vocab_size), followed by the remaining elements of the
            DistilBertModel output tuple.

        Raises:
            ValueError: in training mode, when `labels` is given, or when a dict
                return is requested (coremltools does not support dict outputs).
        """
        if self.training or labels is not None:
            raise ValueError(WARN_MSG_FOR_TRAINING_ATTEMPT)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if return_dict:
            raise ValueError(WARN_MSG_FOR_DICT_RETURN)

        dlbrt_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=False,
        )
        hidden_states = dlbrt_output[0]  # (bs, dim, 1, seq_len)
        prediction_logits = self.vocab_transform(
            hidden_states)  # (bs, dim, 1, seq_len)
        prediction_logits = self.activation(
            prediction_logits)  # (bs, dim, 1, seq_len)
        prediction_logits = self.vocab_layer_norm(
            prediction_logits)  # (bs, dim, 1, seq_len)
        prediction_logits = self.vocab_projector(
            prediction_logits)  # (bs, vocab_size, 1, seq_len)
        # Drop the singleton axis and move the vocabulary last, as the token
        # classification head does. The previous squeeze(-1).squeeze(-1) only
        # removed axes when seq_len == 1, so the output rank depended on seq_len.
        prediction_logits = prediction_logits.squeeze(2).transpose(
            1, 2)  # (bs, seq_len, vocab_size)
        return (prediction_logits, ) + dlbrt_output[1:]
class DistilBertForSequenceClassification(
        modeling_distilbert.DistilBertForSequenceClassification):
    """ Sequence classification head on the ANE-optimized DistilBERT (inference only).

    The first sequence position is classified with two 1x1 nn.Conv2d layers
    separated by a ReLU.
    """

    def __init__(self, config):
        super().__init__(config)
        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Conv2d(config.dim, config.dim, 1)
        self.classifier = nn.Conv2d(config.dim, config.num_labels, 1)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """ Inference-only classification forward pass.

        Returns:
            A tuple of logits of shape (bs, num_labels), followed by the remaining
            elements of the DistilBertModel output tuple.

        Raises:
            NotImplementedError: in training mode or when `labels` is given.
            ValueError: when a dict return is requested.
        """
        if self.training or labels is not None:
            raise NotImplementedError(WARN_MSG_FOR_TRAINING_ATTEMPT)
        if return_dict is None:
            return_dict = self.config.use_return_dict
        if return_dict:
            raise ValueError(WARN_MSG_FOR_DICT_RETURN)

        encoder_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=False,
        )
        # Keep only the first position, retaining the 4-D layout: (bs, dim, 1, 1)
        first_token = encoder_output[0][:, :, :, 0:1]
        projected = nn.functional.relu(self.pre_classifier(first_token))
        class_scores = self.classifier(projected)  # (bs, num_labels, 1, 1)
        class_scores = class_scores.squeeze(-1).squeeze(-1)  # (bs, num_labels)
        return (class_scores, ) + encoder_output[1:]
class DistilBertForQuestionAnswering(
        modeling_distilbert.DistilBertForQuestionAnswering):
    """ Extractive QA head on the ANE-optimized DistilBERT (inference only).

    A 1x1 nn.Conv2d maps every position to ``config.num_labels`` channels, which
    are split into start and end logits.
    """

    def __init__(self, config):
        super().__init__(config)
        setattr(self, 'distilbert', DistilBertModel(config))
        setattr(self, 'qa_outputs', nn.Conv2d(config.dim, config.num_labels,
                                              1))

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """ Inference-only QA forward pass.

        Returns:
            A tuple ``(start_logits, end_logits)``, each of shape
            (bs, max_query_len), followed by the remaining elements of the
            DistilBertModel output tuple.

        Raises:
            ValueError: in training mode, when start/end positions are given, or
                when a dict return is requested.
        """
        if self.training or start_positions is not None or end_positions is not None:
            raise ValueError(WARN_MSG_FOR_TRAINING_ATTEMPT)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if return_dict:
            raise ValueError(WARN_MSG_FOR_DICT_RETURN)

        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=False,
        )
        hidden_states = distilbert_output[0]  # (bs, dim, 1, max_query_len)
        # self.dropout comes from the parent class
        hidden_states = self.dropout(
            hidden_states)  # (bs, dim, 1, max_query_len)
        logits = self.qa_outputs(hidden_states)  # (bs, 2, 1, max_query_len)
        start_logits, end_logits = logits.split(
            1, dim=1)  # (bs, 1, 1, max_query_len) * 2
        # Squeeze only the two singleton axes. A bare .squeeze() would also drop
        # the batch axis when bs == 1 (or the sequence axis when it has length 1).
        start_logits = start_logits.squeeze(2).squeeze(
            1).contiguous()  # (bs, max_query_len)
        end_logits = end_logits.squeeze(2).squeeze(
            1).contiguous()  # (bs, max_query_len)
        return (start_logits, end_logits) + distilbert_output[1:]
class DistilBertForTokenClassification(
        modeling_distilbert.DistilBertForTokenClassification):
    """ Token classification head on the ANE-optimized DistilBERT (inference only).

    Every position is classified with a single 1x1 nn.Conv2d.
    """

    def __init__(self, config):
        super().__init__(config)
        self.distilbert = DistilBertModel(config)
        self.classifier = nn.Conv2d(config.hidden_size, config.num_labels, 1)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """ Inference-only token classification forward pass.

        Returns:
            A tuple of logits of shape (bs, seq_len, num_labels), followed by the
            remaining elements of the DistilBertModel output tuple.

        Raises:
            ValueError: in training mode, when `labels` is given, or when a dict
                return is requested.
        """
        if self.training or labels is not None:
            raise ValueError(WARN_MSG_FOR_TRAINING_ATTEMPT)
        if return_dict is None:
            return_dict = self.config.use_return_dict
        if return_dict:
            raise ValueError(WARN_MSG_FOR_DICT_RETURN)

        encoder_outputs = self.distilbert(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=False,
        )
        token_states = encoder_outputs[0]  # (bs, dim, 1, seq_len)
        token_scores = self.classifier(token_states)  # (bs, num_labels, 1, seq_len)
        # Drop the singleton axis, then put the label axis last
        token_scores = token_scores.squeeze(2)  # (bs, num_labels, seq_len)
        token_scores = token_scores.transpose(1, 2)  # (bs, seq_len, num_labels)
        return (token_scores, ) + encoder_outputs[1:]
class DistilBertForMultipleChoice(
        modeling_distilbert.DistilBertForMultipleChoice):
    """ Multiple-choice head on the ANE-optimized DistilBERT (inference only).

    Choices are folded into the batch axis, the first position of each is scored
    with 1x1 nn.Conv2d layers, and the scores are unfolded to
    (bs, num_choices).
    """

    def __init__(self, config):
        super().__init__(config)
        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Conv2d(config.dim, config.dim, 1)
        self.classifier = nn.Conv2d(config.dim, 1, 1)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """ Inference-only multiple-choice forward pass.

        `input_ids` / `attention_mask` are indexed as (bs, num_choices, seq_len)
        and `inputs_embeds` as (bs, num_choices, seq_len, dim), judging by the
        reshapes below.

        Returns:
            A tuple of logits of shape (bs, num_choices), followed by the
            remaining elements of the DistilBertModel output tuple.

        Raises:
            ValueError: in training mode, when `labels` is given, or when a dict
                return is requested.
        """
        if self.training or labels is not None:
            raise ValueError(WARN_MSG_FOR_TRAINING_ATTEMPT)
        if return_dict is None:
            return_dict = self.config.use_return_dict
        if return_dict:
            raise ValueError(WARN_MSG_FOR_DICT_RETURN)

        # Fold the choice axis into the batch axis
        if input_ids is not None:
            num_choices = input_ids.shape[1]
            input_ids = input_ids.view(-1, input_ids.size(-1))
        else:
            num_choices = inputs_embeds.shape[1]
        if attention_mask is not None:
            attention_mask = attention_mask.view(-1, attention_mask.size(-1))
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-2),
                                               inputs_embeds.size(-1))

        encoder_outputs = self.distilbert(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=False,
        )
        # First position of each (batch, choice) pair: (bs * num_choices, dim, 1, 1)
        first_token = encoder_outputs[0][:, :, :, 0:1]
        projected = nn.functional.relu(self.pre_classifier(first_token))
        choice_scores = self.classifier(projected)  # (bs * num_choices, 1, 1, 1)
        choice_scores = choice_scores.squeeze().view(
            -1, num_choices)  # (bs, num_choices)
        return (choice_scores, ) + encoder_outputs[1:]
def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
                         missing_keys, unexpected_keys, error_msgs):
    """ Unsqueeze twice to map nn.Linear weights to nn.Conv2d weights

    Used as a load_state_dict pre-hook on DistilBertModel. Every 2-D
    ``.weight`` tensor whose key names one of the projections that this file
    re-implements as a 1x1 nn.Conv2d is reshaped from (out, in) to
    (out, in, 1, 1). Tensors that are already 4-D are left alone.

    The covered projections are:
      * the attention and FFN projections (``q_lin``/``k_lin``/``v_lin``/
        ``out_lin``/``lin1``/``lin2``, all matched by ``'lin'``);
      * the classification heads (``pre_classifier``/``classifier``);
      * the masked-LM head (``vocab_transform``/``vocab_projector``);
      * the QA head (``qa_outputs``).

    Every key of ``state_dict`` is scanned, not only keys under ``prefix``, so
    that head weights living outside the ``distilbert.`` sub-module can be
    converted too. NOTE(review): this only reaches head keys if the loader
    passes the full state dict to this hook; confirm for the torch /
    transformers versions in use.

    The remaining parameters follow the ``_register_load_state_dict_pre_hook``
    signature and are unused.
    """
    conv2d_markers = ('lin', 'classifier', 'vocab_transform',
                      'vocab_projector', 'qa_outputs')
    # Only values are reassigned during iteration; the key set does not change.
    for k in state_dict:
        if '.weight' not in k:
            continue
        if not any(marker in k for marker in conv2d_markers):
            continue
        if len(state_dict[k].shape) == 2:
            state_dict[k] = state_dict[k][:, :, None, None]