@@ -117,7 +117,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
117
117
}
118
118
119
119
if (hidden_size % num_heads_ != 0 ) {
120
- return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT, " hidden_size should be divisiable by num_heads." );
120
+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT, " hidden_size should be divisible by num_heads." );
121
121
}
122
122
} else {
123
123
int qkv_sizes = 0 ;
@@ -129,12 +129,13 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
129
129
130
130
if (qkv_hidden_sizes_[0 ] != qkv_hidden_sizes_[1 ]) {
131
131
return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
132
- " qkv_hidden_sizes first element should be same as the second" );
132
+ " qkv_hidden_sizes first element should be same as the second" );
133
133
}
134
134
135
135
for (size_t i = 0 ; i < qkv_hidden_sizes_.size (); i++) {
136
136
if (qkv_hidden_sizes_[i] % num_heads_ != 0 ) {
137
- return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT, " hidden_size should be divisiable by num_heads:" , qkv_hidden_sizes_[i]);
137
+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
138
+ " hidden_size should be divisible by num_heads:" , qkv_hidden_sizes_[i]);
138
139
}
139
140
140
141
qkv_sizes += static_cast <int >(qkv_hidden_sizes_[i]);
@@ -164,13 +165,16 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
164
165
return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT, " Inputs 'past' dimension 0 shall have length of 2" );
165
166
}
166
167
if (static_cast <int >(past_dims[1 ]) != batch_size) {
167
- return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT, " Inputs 'past' dimension 1 shall have same length as dimension 0 of input 0" );
168
+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
169
+ " Inputs 'past' dimension 1 shall have same length as dimension 0 of input 0" );
168
170
}
169
171
if (static_cast <int >(past_dims[2 ]) != num_heads_) {
170
- return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT, " Inputs 'past' dimension 2 shall have length of num_heads" , num_heads_);
172
+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
173
+ " Inputs 'past' dimension 2 shall have length of num_heads" , num_heads_);
171
174
}
172
175
if (static_cast <int >(past_dims[4 ]) != hidden_size / num_heads_) {
173
- return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT, " Inputs 'past' dimension 2 shall have length of " , hidden_size / num_heads_);
176
+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
177
+ " Inputs 'past' dimension 2 shall have length of " , hidden_size / num_heads_);
174
178
}
175
179
past_sequence_length = static_cast <int >(past_dims[3 ]);
176
180
}
@@ -179,31 +183,50 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
179
183
const auto & mask_dims = mask_index->Shape ().GetDims ();
180
184
if (mask_dims.size () == 1 ) {
181
185
if (static_cast <int >(mask_dims[0 ]) != batch_size && static_cast <int >(mask_dims[0 ]) != 2 * batch_size) {
182
- return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT, " Inputs 'mask_index' with 1D data shall have length of batch_size or 2 * batch_size" );
186
+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
187
+ " Inputs 'mask_index' with 1D data shall have length of batch_size or 2 * batch_size" );
183
188
}
184
189
} else if (mask_dims.size () == 2 ) {
185
- if (static_cast <int >(mask_dims[0 ]) != batch_size || static_cast <int >(mask_dims[1 ]) != past_sequence_length + sequence_length) {
190
+ if (static_cast <int >(mask_dims[0 ]) != batch_size ||
191
+ static_cast <int >(mask_dims[1 ]) != past_sequence_length + sequence_length) {
186
192
// Add operator supports broadcasting. Here we handle a case with only one element in the 2nd dimension.
187
- if ((static_cast <int >(mask_dims[0 ]) == batch_size || static_cast <int >(mask_dims[0 ]) == 1 ) && static_cast <int >(mask_dims[1 ]) == 1 ) {
188
- // Mask will have same value after propogation, which has same effect as no mask.
193
+ if ((static_cast <int >(mask_dims[0 ]) == batch_size || static_cast <int >(mask_dims[0 ]) == 1 ) &&
194
+ static_cast <int >(mask_dims[1 ]) == 1 ) {
195
+ // Mask will have same value after propagation, which has same effect as no mask.
189
196
mask_index = nullptr ;
190
197
} else {
191
- return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT, " Inputs 'mask_index' with 2D data shall have shape batch_size x (past_sequence_length + sequence_length)" );
198
+ return ORT_MAKE_STATUS (
199
+ ONNXRUNTIME, INVALID_ARGUMENT,
200
+ " Inputs 'mask_index' with 2D data shall have shape "
201
+ " batch_size x (past_sequence_length + sequence_length)" );
192
202
}
193
203
}
194
204
} else if (mask_dims.size () == 3 ) {
195
- if (static_cast <int >(mask_dims[0 ]) != batch_size || mask_dims[1 ] != sequence_length || static_cast <int >(mask_dims[2 ]) != past_sequence_length + sequence_length) {
196
- return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT, " Inputs 'mask_index' with 3D data shall have shape batch_size x sequence_length x (past_sequence_length + sequence_length)" );
205
+ if (static_cast <int >(mask_dims[0 ]) != batch_size ||
206
+ mask_dims[1 ] != sequence_length ||
207
+ static_cast <int >(mask_dims[2 ]) != past_sequence_length + sequence_length) {
208
+ return ORT_MAKE_STATUS (
209
+ ONNXRUNTIME, INVALID_ARGUMENT,
210
+ " Inputs 'mask_index' with 3D data shall have shape "
211
+ " batch_size x sequence_length x (past_sequence_length + sequence_length)" );
197
212
}
198
213
} else if (mask_dims.size () == 4 ) {
199
- if (static_cast <int >(mask_dims[0 ]) != batch_size || mask_dims[1 ] != 1 || mask_dims[2 ] != mask_dims[3 ] || mask_dims[2 ] < static_cast <int64_t >(past_sequence_length) + sequence_length) {
200
- return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT, " Inputs 'mask_index' with 4D data shall have shape batch_size x 1 x max_sequence_length x max_sequence_length)" );
214
+ if (static_cast <int >(mask_dims[0 ]) != batch_size ||
215
+ mask_dims[1 ] != 1 ||
216
+ mask_dims[2 ] != mask_dims[3 ] ||
217
+ mask_dims[2 ] < static_cast <int64_t >(past_sequence_length) + sequence_length) {
218
+ return ORT_MAKE_STATUS (
219
+ ONNXRUNTIME, INVALID_ARGUMENT,
220
+ " Inputs 'mask_index' with 4D data shall have shape "
221
+ " batch_size x 1 x max_sequence_length x max_sequence_length)" );
201
222
}
202
223
if (is_unidirectional_ == true ) {
203
- return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT, " Inputs 'mask_index' with 4D data shall have is_unidirectional_ set to false" );
224
+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
225
+ " Inputs 'mask_index' with 4D data shall have is_unidirectional_ set to false" );
204
226
}
205
227
} else {
206
- return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT, " Input 'mask_index' is expected to have 1, 2, 3 or 4 dimensions, got " ,
228
+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
229
+ " Input 'mask_index' is expected to have 1, 2, 3 or 4 dimensions, got " ,
207
230
mask_dims.size ());
208
231
}
209
232
}
@@ -212,24 +235,29 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
212
235
const auto & extra_add_qk_dims = extra_add_qk->Shape ().GetDims ();
213
236
214
237
if (extra_add_qk_dims.size () != 4 ) {
215
- return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT, " Input 'extra_add_qk' is expected to have 4 dimensions, got " ,
238
+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
239
+ " Input 'extra_add_qk' is expected to have 4 dimensions, got " ,
216
240
extra_add_qk_dims.size ());
217
241
}
218
242
219
243
if (extra_add_qk_dims[0 ] != batch_size) {
220
- return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT, " Input 'extra_add_qk' dimension 0 should be same as batch_size, got " ,
244
+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
245
+ " Input 'extra_add_qk' dimension 0 should be same as batch_size, got " ,
221
246
extra_add_qk_dims[0 ]);
222
247
}
223
248
if (extra_add_qk_dims[1 ] != num_heads_) {
224
- return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT, " Input 'extra_add_qk' dimension 1 should be same as number of heads, got " ,
249
+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
250
+ " Input 'extra_add_qk' dimension 1 should be same as number of heads, got " ,
225
251
extra_add_qk_dims[1 ]);
226
252
}
227
253
if (extra_add_qk_dims[2 ] != sequence_length) {
228
- return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT, " Input 'extra_add_qk' dimension 2 should be same as sequence_length, got " ,
254
+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
255
+ " Input 'extra_add_qk' dimension 2 should be same as sequence_length, got " ,
229
256
extra_add_qk_dims[2 ]);
230
257
}
231
258
if (extra_add_qk_dims[3 ] != sequence_length) {
232
- return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT, " Input 'extra_add_qk' dimension 3 should be same as sequence_length, got " ,
259
+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
260
+ " Input 'extra_add_qk' dimension 3 should be same as sequence_length, got " ,
233
261
extra_add_qk_dims[3 ]);
234
262
}
235
263
}
@@ -322,7 +350,6 @@ template <typename T>
322
350
Status Attention<T>::PrePack(const Tensor& weights, int input_idx, AllocatorPtr alloc,
323
351
/* out*/ bool & is_packed,
324
352
/* out*/ PrePackedWeights* prepacked_weights) {
325
-
326
353
/* The PrePack() massages the weights to speed up Compute(), there is an option to
327
354
* use shared prepacked weights in which case prepacked_weights parameter would be non-null.
328
355
*
@@ -375,9 +402,14 @@ Status Attention<T>::PrePack(const Tensor& weights, int input_idx, AllocatorPtr
375
402
const size_t qkv_head_size[3 ] = {q_hidden_size / num_heads_, k_hidden_size / num_heads_, v_hidden_size / num_heads_};
376
403
const size_t weight_matrix_col_size = q_hidden_size + k_hidden_size + v_hidden_size;
377
404
378
- if (!IsPackWeightsSuccessful (0 , alloc, qkv_head_size[0 ], input_hidden_size, weights_data, weight_matrix_col_size, prepacked_weights) ||
379
- !IsPackWeightsSuccessful (1 , alloc, qkv_head_size[1 ], input_hidden_size, weights_data + (num_heads_ * qkv_head_size[0 ]), weight_matrix_col_size, prepacked_weights) ||
380
- !IsPackWeightsSuccessful (2 , alloc, qkv_head_size[2 ], input_hidden_size, weights_data + (num_heads_ * (qkv_head_size[0 ] + qkv_head_size[1 ])), weight_matrix_col_size, prepacked_weights)) {
405
+ if (!IsPackWeightsSuccessful (0 , alloc, qkv_head_size[0 ], input_hidden_size,
406
+ weights_data, weight_matrix_col_size, prepacked_weights) ||
407
+ !IsPackWeightsSuccessful (1 , alloc, qkv_head_size[1 ], input_hidden_size,
408
+ weights_data + (num_heads_ * qkv_head_size[0 ]),
409
+ weight_matrix_col_size, prepacked_weights) ||
410
+ !IsPackWeightsSuccessful (2 , alloc, qkv_head_size[2 ], input_hidden_size,
411
+ weights_data + (num_heads_ * (qkv_head_size[0 ] + qkv_head_size[1 ])),
412
+ weight_matrix_col_size, prepacked_weights)) {
381
413
if (prepacked_weights == nullptr ) {
382
414
FreePackedWeights (packed_weights_, qkv_hidden_sizes_.size ());
383
415
}
@@ -469,7 +501,8 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
469
501
// gemm_data(BS, NT) = input(BS, D) x weights(D, NT) + bias(NT)
470
502
// D (input_hidden_size) is hidden dimension of input, where D could be larger than any of the hidden_sizes
471
503
// (NH) when model is pruned. T = H1 + H2 + H3, where H1, H2, H3 are head sizes of Q, K, V respectively
472
- auto gemm_data = allocator->Alloc (SafeInt<size_t >(batch_size) * sequence_length * (q_hidden_size + k_hidden_size + v_hidden_size) * element_size);
504
+ int qkv_hidden_size = (q_hidden_size + k_hidden_size + v_hidden_size);
505
+ auto gemm_data = allocator->Alloc (SafeInt<size_t >(batch_size) * sequence_length * qkv_hidden_size * element_size);
473
506
BufferUniquePtr gemm_buffer (gemm_data, BufferDeleter (std::move (allocator)));
474
507
475
508
auto Q = reinterpret_cast <T*>(gemm_data);
@@ -523,12 +556,13 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
523
556
// C: QKV[qkv_index] (BxNxSxT) (B.N.)S x T S x H
524
557
if (is_prepack_) {
525
558
uint8_t * packed_weight;
526
- packed_weight = static_cast <uint8_t *>(packed_weights_[qkv_index].get ()) + packed_weights_size_[qkv_index] * (weights_offset / head_size);
559
+ packed_weight = static_cast <uint8_t *>(packed_weights_[qkv_index].get ()) +
560
+ packed_weights_size_[qkv_index] * (weights_offset / head_size);
527
561
528
562
MlasGemm (
529
563
CblasNoTrans, // TransA = no
530
564
sequence_length, // M = S
531
- head_size, // N = H
565
+ head_size, // N = H
532
566
input_hidden_size, // K = D
533
567
1 .0f , // alpha
534
568
input_data + input_offset, // A
@@ -540,20 +574,20 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
540
574
nullptr ); // use single-thread
541
575
} else {
542
576
math::GemmEx<float , ThreadPool>(
543
- CblasNoTrans, // TransA = no
544
- CblasNoTrans, // TransB = no
545
- sequence_length, // M = S
546
- head_size, // N = H
547
- input_hidden_size, // K = D
548
- 1 .0f , // alpha
549
- input_data + input_offset, // A
550
- input_hidden_size, // lda = D
551
- weights_data + weights_offset, // B
552
- q_hidden_size + k_hidden_size + v_hidden_size,// ldb = NH1 + NH2 + NH3
553
- 1 .0f , // beta
554
- qkv_dest + qkv_offset, // C
555
- head_size, // ldc
556
- nullptr // use single-thread
577
+ CblasNoTrans, // TransA = no
578
+ CblasNoTrans, // TransB = no
579
+ sequence_length, // M = S
580
+ head_size, // N = H
581
+ input_hidden_size, // K = D
582
+ 1 .0f , // alpha
583
+ input_data + input_offset, // A
584
+ input_hidden_size, // lda = D
585
+ weights_data + weights_offset, // B
586
+ q_hidden_size + k_hidden_size + v_hidden_size, // ldb = NH1 + NH2 + NH3
587
+ 1 .0f , // beta
588
+ qkv_dest + qkv_offset, // C
589
+ head_size, // ldc
590
+ nullptr // use single-thread
557
591
);
558
592
}
559
593
}
0 commit comments