
Commit d93e653

Format bert or transformers code (microsoft#12646)
(1) Modify some lines to fit the 120-character line length limit.
(2) Adjust the parameter order of LaunchAttentionKernel.
(3) Format the code with Clang-Format in VS Code.
(4) Fix spelling errors.
Parent: dc486d1


46 files changed: +1371 additions, -945 deletions

onnxruntime/contrib_ops/cpu/bert/attention.cc (+77, -43)
@@ -117,7 +117,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
     }
 
     if (hidden_size % num_heads_ != 0) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "hidden_size should be divisiable by num_heads.");
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "hidden_size should be divisible by num_heads.");
     }
   } else {
     int qkv_sizes = 0;
@@ -129,12 +129,13 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
 
     if (qkv_hidden_sizes_[0] != qkv_hidden_sizes_[1]) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                            "qkv_hidden_sizes first element should be same as the second");
+                             "qkv_hidden_sizes first element should be same as the second");
     }
 
     for (size_t i = 0; i < qkv_hidden_sizes_.size(); i++) {
       if (qkv_hidden_sizes_[i] % num_heads_ != 0) {
-        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "hidden_size should be divisiable by num_heads:", qkv_hidden_sizes_[i]);
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               "hidden_size should be divisible by num_heads:", qkv_hidden_sizes_[i]);
       }
 
       qkv_sizes += static_cast<int>(qkv_hidden_sizes_[i]);
@@ -164,13 +165,16 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'past' dimension 0 shall have length of 2");
     }
     if (static_cast<int>(past_dims[1]) != batch_size) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'past' dimension 1 shall have same length as dimension 0 of input 0");
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Inputs 'past' dimension 1 shall have same length as dimension 0 of input 0");
     }
     if (static_cast<int>(past_dims[2]) != num_heads_) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'past' dimension 2 shall have length of num_heads", num_heads_);
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Inputs 'past' dimension 2 shall have length of num_heads", num_heads_);
     }
     if (static_cast<int>(past_dims[4]) != hidden_size / num_heads_) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'past' dimension 2 shall have length of ", hidden_size / num_heads_);
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Inputs 'past' dimension 2 shall have length of ", hidden_size / num_heads_);
     }
     past_sequence_length = static_cast<int>(past_dims[3]);
   }
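
For reference, the four checks in this hunk together pin down the layout of the optional 'past' input. Below is a minimal sketch with hypothetical sizes (batch 2, 12 heads, head size 64, past length 16); the numbers are illustrative, not taken from the commit.

#include <cstdint>
#include <vector>

// Hypothetical sizes, used only to show which dimension each check inspects.
const int64_t batch_size = 2, num_heads = 12, past_sequence_length = 16, head_size = 64;

// 'past' holds the stacked past key and value, hence the leading dimension of 2:
//   past_dims[0] == 2
//   past_dims[1] == batch_size
//   past_dims[2] == num_heads_
//   past_dims[3]  -> read as past_sequence_length
//   past_dims[4] == hidden_size / num_heads_  (the per-head size)
const std::vector<int64_t> past_dims = {2, batch_size, num_heads, past_sequence_length, head_size};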
@@ -179,31 +183,50 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
     const auto& mask_dims = mask_index->Shape().GetDims();
     if (mask_dims.size() == 1) {
       if (static_cast<int>(mask_dims[0]) != batch_size && static_cast<int>(mask_dims[0]) != 2 * batch_size) {
-        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 1D data shall have length of batch_size or 2 * batch_size");
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               "Inputs 'mask_index' with 1D data shall have length of batch_size or 2 * batch_size");
       }
     } else if (mask_dims.size() == 2) {
-      if (static_cast<int>(mask_dims[0]) != batch_size || static_cast<int>(mask_dims[1]) != past_sequence_length + sequence_length) {
+      if (static_cast<int>(mask_dims[0]) != batch_size ||
+          static_cast<int>(mask_dims[1]) != past_sequence_length + sequence_length) {
         // Add operator supports broadcasting. Here we handle a case with only one element in the 2nd dimension.
-        if ((static_cast<int>(mask_dims[0]) == batch_size || static_cast<int>(mask_dims[0]) == 1) && static_cast<int>(mask_dims[1]) == 1) {
-          // Mask will have same value after propogation, which has same effect as no mask.
+        if ((static_cast<int>(mask_dims[0]) == batch_size || static_cast<int>(mask_dims[0]) == 1) &&
+            static_cast<int>(mask_dims[1]) == 1) {
+          // Mask will have same value after propagation, which has same effect as no mask.
           mask_index = nullptr;
         } else {
-          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 2D data shall have shape batch_size x (past_sequence_length + sequence_length)");
+          return ORT_MAKE_STATUS(
+              ONNXRUNTIME, INVALID_ARGUMENT,
+              "Inputs 'mask_index' with 2D data shall have shape "
+              "batch_size x (past_sequence_length + sequence_length)");
         }
       }
     } else if (mask_dims.size() == 3) {
-      if (static_cast<int>(mask_dims[0]) != batch_size || mask_dims[1] != sequence_length || static_cast<int>(mask_dims[2]) != past_sequence_length + sequence_length) {
-        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 3D data shall have shape batch_size x sequence_length x (past_sequence_length + sequence_length)");
+      if (static_cast<int>(mask_dims[0]) != batch_size ||
+          mask_dims[1] != sequence_length ||
+          static_cast<int>(mask_dims[2]) != past_sequence_length + sequence_length) {
+        return ORT_MAKE_STATUS(
+            ONNXRUNTIME, INVALID_ARGUMENT,
+            "Inputs 'mask_index' with 3D data shall have shape "
+            "batch_size x sequence_length x (past_sequence_length + sequence_length)");
       }
     } else if (mask_dims.size() == 4) {
-      if (static_cast<int>(mask_dims[0]) != batch_size || mask_dims[1] != 1 || mask_dims[2] != mask_dims[3] || mask_dims[2] < static_cast<int64_t>(past_sequence_length) + sequence_length) {
-        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 4D data shall have shape batch_size x 1 x max_sequence_length x max_sequence_length)");
+      if (static_cast<int>(mask_dims[0]) != batch_size ||
+          mask_dims[1] != 1 ||
+          mask_dims[2] != mask_dims[3] ||
+          mask_dims[2] < static_cast<int64_t>(past_sequence_length) + sequence_length) {
+        return ORT_MAKE_STATUS(
+            ONNXRUNTIME, INVALID_ARGUMENT,
+            "Inputs 'mask_index' with 4D data shall have shape "
+            "batch_size x 1 x max_sequence_length x max_sequence_length)");
       }
       if (is_unidirectional_ == true) {
-        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 4D data shall have is_unidirectional_ set to false");
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               "Inputs 'mask_index' with 4D data shall have is_unidirectional_ set to false");
      }
     } else {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'mask_index' is expected to have 1, 2, 3 or 4 dimensions, got ",
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'mask_index' is expected to have 1, 2, 3 or 4 dimensions, got ",
                              mask_dims.size());
     }
   }
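
Taken together, the branches above accept four mask_index layouts. Here is a small illustrative sketch; the concrete numbers are hypothetical, chosen only to satisfy the checks.

#include <cstdint>
#include <vector>

constexpr int64_t B = 2;   // batch_size
constexpr int64_t S = 4;   // sequence_length
constexpr int64_t P = 3;   // past_sequence_length
constexpr int64_t M = 16;  // max_sequence_length; the 4D check requires M >= P + S

// Shapes that pass the 1D, 2D, 3D and 4D branches respectively.
const std::vector<int64_t> mask_1d = {B};           // {2 * B} is also accepted
const std::vector<int64_t> mask_2d = {B, P + S};    // {1, 1} or {B, 1} is treated as "no mask"
const std::vector<int64_t> mask_3d = {B, S, P + S};
const std::vector<int64_t> mask_4d = {B, 1, M, M};  // only valid when is_unidirectional_ is false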
@@ -212,24 +235,29 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
     const auto& extra_add_qk_dims = extra_add_qk->Shape().GetDims();
 
     if (extra_add_qk_dims.size() != 4) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' is expected to have 4 dimensions, got ",
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'extra_add_qk' is expected to have 4 dimensions, got ",
                              extra_add_qk_dims.size());
     }
 
     if (extra_add_qk_dims[0] != batch_size) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' dimension 0 should be same as batch_size, got ",
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'extra_add_qk' dimension 0 should be same as batch_size, got ",
                              extra_add_qk_dims[0]);
     }
     if (extra_add_qk_dims[1] != num_heads_) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' dimension 1 should be same as number of heads, got ",
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'extra_add_qk' dimension 1 should be same as number of heads, got ",
                              extra_add_qk_dims[1]);
     }
     if (extra_add_qk_dims[2] != sequence_length) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' dimension 2 should be same as sequence_length, got ",
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'extra_add_qk' dimension 2 should be same as sequence_length, got ",
                              extra_add_qk_dims[2]);
     }
     if (extra_add_qk_dims[3] != sequence_length) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' dimension 3 should be same as sequence_length, got ",
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'extra_add_qk' dimension 3 should be same as sequence_length, got ",
                              extra_add_qk_dims[3]);
     }
   }
@@ -322,7 +350,6 @@ template <typename T>
 Status Attention<T>::PrePack(const Tensor& weights, int input_idx, AllocatorPtr alloc,
                              /*out*/ bool& is_packed,
                              /*out*/ PrePackedWeights* prepacked_weights) {
-
   /* The PrePack() massages the weights to speed up Compute(), there is an option to
    * use shared prepacked weights in which case prepacked_weights parameter would be non-null.
    *
@@ -375,9 +402,14 @@ Status Attention<T>::PrePack(const Tensor& weights, int input_idx, AllocatorPtr
   const size_t qkv_head_size[3] = {q_hidden_size / num_heads_, k_hidden_size / num_heads_, v_hidden_size / num_heads_};
   const size_t weight_matrix_col_size = q_hidden_size + k_hidden_size + v_hidden_size;
 
-  if (!IsPackWeightsSuccessful(0, alloc, qkv_head_size[0], input_hidden_size, weights_data, weight_matrix_col_size, prepacked_weights) ||
-      !IsPackWeightsSuccessful(1, alloc, qkv_head_size[1], input_hidden_size, weights_data + (num_heads_ * qkv_head_size[0]), weight_matrix_col_size, prepacked_weights) ||
-      !IsPackWeightsSuccessful(2, alloc, qkv_head_size[2], input_hidden_size, weights_data + (num_heads_ * (qkv_head_size[0] + qkv_head_size[1])), weight_matrix_col_size, prepacked_weights)) {
+  if (!IsPackWeightsSuccessful(0, alloc, qkv_head_size[0], input_hidden_size,
+                               weights_data, weight_matrix_col_size, prepacked_weights) ||
+      !IsPackWeightsSuccessful(1, alloc, qkv_head_size[1], input_hidden_size,
+                               weights_data + (num_heads_ * qkv_head_size[0]),
+                               weight_matrix_col_size, prepacked_weights) ||
+      !IsPackWeightsSuccessful(2, alloc, qkv_head_size[2], input_hidden_size,
+                               weights_data + (num_heads_ * (qkv_head_size[0] + qkv_head_size[1])),
+                               weight_matrix_col_size, prepacked_weights)) {
     if (prepacked_weights == nullptr) {
       FreePackedWeights(packed_weights_, qkv_hidden_sizes_.size());
     }
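
The three IsPackWeightsSuccessful calls above partition the fused weight matrix, D = input_hidden_size rows by (H1 + H2 + H3) columns, into its Q, K and V column blocks via pointer offsets into weights_data. A hedged sketch of that offset arithmetic follows; the helper name is illustrative, not part of the kernel's API.

#include <cstddef>

// Illustrative only: reproduces the offsets passed above, i.e. 0 for Q,
// num_heads * qkv_head_size[0] for K, and num_heads * (qkv_head_size[0] + qkv_head_size[1]) for V,
// which select column blocks of the row-major [D x (H1 + H2 + H3)] fused weight matrix.
size_t QkvWeightColumnOffset(size_t qkv_index, size_t num_heads, const size_t qkv_head_size[3]) {
  size_t offset = 0;
  for (size_t i = 0; i < qkv_index; ++i) {
    offset += num_heads * qkv_head_size[i];  // skip the preceding Q/K column blocks
  }
  return offset;
}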
@@ -469,7 +501,8 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
   // gemm_data(BS, NT) = input(BS, D) x weights(D, NT) + bias(NT)
   // D (input_hidden_size) is hidden dimension of input, where D could be larger than any of the hidden_sizes
   // (NH) when model is pruned. T = H1 + H2 + H3, where H1, H2, H3 are head sizes of Q, K, V respectively
-  auto gemm_data = allocator->Alloc(SafeInt<size_t>(batch_size) * sequence_length * (q_hidden_size + k_hidden_size + v_hidden_size) * element_size);
+  int qkv_hidden_size = (q_hidden_size + k_hidden_size + v_hidden_size);
+  auto gemm_data = allocator->Alloc(SafeInt<size_t>(batch_size) * sequence_length * qkv_hidden_size * element_size);
   BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(std::move(allocator)));
 
   auto Q = reinterpret_cast<T*>(gemm_data);
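
As the comment in this hunk notes, the fused QKV GEMM output has shape (batch_size * sequence_length) x (q_hidden_size + k_hidden_size + v_hidden_size). A quick worked example of the allocation size, using hypothetical BERT-base-like numbers rather than anything from the commit:

#include <cstddef>
#include <iostream>

int main() {
  // Hypothetical values, for illustration only.
  const size_t batch_size = 1, sequence_length = 128;
  const size_t q_hidden_size = 768, k_hidden_size = 768, v_hidden_size = 768;
  const size_t element_size = sizeof(float);  // 4 bytes

  const size_t qkv_hidden_size = q_hidden_size + k_hidden_size + v_hidden_size;  // 2304
  const size_t gemm_bytes = batch_size * sequence_length * qkv_hidden_size * element_size;
  std::cout << gemm_bytes << " bytes\n";  // 1 * 128 * 2304 * 4 = 1,179,648 bytes (~1.1 MiB)
  return 0;
}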
@@ -523,12 +556,13 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
       // C: QKV[qkv_index] (BxNxSxT) (B.N.)S x T S x H
       if (is_prepack_) {
         uint8_t* packed_weight;
-        packed_weight = static_cast<uint8_t*>(packed_weights_[qkv_index].get()) + packed_weights_size_[qkv_index] * (weights_offset / head_size);
+        packed_weight = static_cast<uint8_t*>(packed_weights_[qkv_index].get()) +
+                        packed_weights_size_[qkv_index] * (weights_offset / head_size);
 
         MlasGemm(
             CblasNoTrans,               // TransA = no
             sequence_length,            // M = S
-            head_size,                 // N = H
+            head_size,                  // N = H
             input_hidden_size,          // K = D
             1.0f,                       // alpha
             input_data + input_offset,  // A
@@ -540,20 +574,20 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
             nullptr);                   // use single-thread
       } else {
         math::GemmEx<float, ThreadPool>(
-            CblasNoTrans,                                 // TransA = no
-            CblasNoTrans,                                 // TransB = no
-            sequence_length,                              // M = S
-            head_size,                                    // N = H
-            input_hidden_size,                            // K = D
-            1.0f,                                         // alpha
-            input_data + input_offset,                    // A
-            input_hidden_size,                            // lda = D
-            weights_data + weights_offset,                // B
-            q_hidden_size + k_hidden_size + v_hidden_size,// ldb = NH1 + NH2 + NH3
-            1.0f,                                         // beta
-            qkv_dest + qkv_offset,                        // C
-            head_size,                                    // ldc
-            nullptr                                       // use single-thread
+            CblasNoTrans,                                   // TransA = no
+            CblasNoTrans,                                   // TransB = no
+            sequence_length,                                // M = S
+            head_size,                                      // N = H
+            input_hidden_size,                              // K = D
+            1.0f,                                           // alpha
+            input_data + input_offset,                      // A
+            input_hidden_size,                              // lda = D
+            weights_data + weights_offset,                  // B
+            q_hidden_size + k_hidden_size + v_hidden_size,  // ldb = NH1 + NH2 + NH3
+            1.0f,                                           // beta
+            qkv_dest + qkv_offset,                          // C
+            head_size,                                      // ldc
+            nullptr                                         // use single-thread
             );
       }
     }
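
For readers less familiar with the BLAS-style arguments whose comments are realigned here, the loop below is a minimal row-major reference for what M, N, K, lda, ldb and ldc mean in this call; it is a sketch of the semantics, not onnxruntime's implementation.

// Reference semantics of C = alpha * A * B + beta * C for row-major storage,
// matching the comments above: M = S rows of A and C, N = H columns of B and C,
// K = D inner dimension, and lda/ldb/ldc are the row strides of A, B and C.
void GemmRowMajorReference(int M, int N, int K, float alpha,
                           const float* A, int lda,
                           const float* B, int ldb,
                           float beta, float* C, int ldc) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += A[m * lda + k] * B[k * ldb + n];
      }
      C[m * ldc + n] = alpha * acc + beta * C[m * ldc + n];
    }
  }
}

In the GemmEx call above, B is the fused D x (H1 + H2 + H3) weight matrix with the Q/K/V block selected via weights_offset, which is why ldb spans all three hidden sizes while ldc is just the per-head size H.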

onnxruntime/contrib_ops/cpu/bert/attention_base.h (+6, -5)
@@ -3,6 +3,7 @@
 
 #pragma once
 
+#include <vector>
 #include "core/common/common.h"
 #include "core/framework/op_kernel.h"
 
@@ -17,7 +18,7 @@ class AttentionBase {
                      const TensorShape& bias_shape,
                      const Tensor*& mask_index,  // For dummy mask with shape (1, 1) or (batch_size, 1), it will be updated to nullptr.
                      const Tensor* past,
-                     const Tensor *extra_add_qk,
+                     const Tensor* extra_add_qk,
                      const int max_threads_per_block) const;
 
   Tensor* GetPresent(OpKernelContext* context,
@@ -45,11 +46,11 @@ class AttentionBase {
                      const TensorShape& bias_shape,
                      const Tensor*& mask_index,  // For dummy mask with shape (1, 1) or (batch_size, 1), it will be updated to nullptr.
                      const Tensor* past,
-                     const Tensor *extra_add_qk) const;
+                     const Tensor* extra_add_qk) const;
 
-  int num_heads_; // number of attention heads
-  bool is_unidirectional_; // whether every token can only attend to previous tokens.
-  std::vector<int64_t> qkv_hidden_sizes_; // Q, K, V path hidden layer sizes
+  int num_heads_;                          // number of attention heads
+  bool is_unidirectional_;                 // whether every token can only attend to previous tokens.
+  std::vector<int64_t> qkv_hidden_sizes_;  // Q, K, V path hidden layer sizes
 };
 
 }  // namespace contrib
