|
3 | 3 |
|
4 | 4 | #include "core/providers/qnn/builder/qnn_def.h"
|
5 | 5 | #include "core/providers/qnn/builder/qnn_utils.h"
|
| 6 | +#include <functional> |
6 | 7 | #include <memory>
|
7 | 8 | #include <ostream>
|
8 | 9 | #include <cstring>
|
@@ -432,6 +433,15 @@ Status CompareQnnQuantParams(const Qnn_QuantizeParams_t& qparam0, const Qnn_Quan
|
432 | 433 | return Status::OK();
|
433 | 434 | }
|
434 | 435 |
|
| 436 | +uint32_t CalcQnnTensorNumElems(const Qnn_Tensor_t& qnn_tensor) { |
| 437 | + uint32_t* qnn_tensor_dims = GetQnnTensorDims(qnn_tensor); |
| 438 | + uint32_t qnn_tensor_rank = GetQnnTensorRank(qnn_tensor); |
| 439 | + return std::accumulate(qnn_tensor_dims, |
| 440 | + qnn_tensor_dims + qnn_tensor_rank, |
| 441 | + 1, |
| 442 | + std::multiplies<uint32_t>()); |
| 443 | +} |
| 444 | + |
435 | 445 | bool CreateTensorInQnnGraph(const QNN_INTERFACE_VER_TYPE& qnn_interface,
|
436 | 446 | const Qnn_GraphHandle_t& graph,
|
437 | 447 | const std::string& node_name,
|
@@ -466,12 +476,7 @@ bool CreateTensorInQnnGraph(const QNN_INTERFACE_VER_TYPE& qnn_interface,
|
466 | 476 | return false;
|
467 | 477 | }
|
468 | 478 | // verify size expressed by the dims matches the raw tensor size
|
469 |
| - auto qnn_tensor_dims = GetQnnTensorDims(qnn_tensor); |
470 |
| - auto qnn_tensor_rank = GetQnnTensorRank(qnn_tensor); |
471 |
| - uint32_t qnn_tensor_size = std::accumulate(qnn_tensor_dims, |
472 |
| - qnn_tensor_dims + qnn_tensor_rank, |
473 |
| - static_cast<uint32_t>(data_size), |
474 |
| - std::multiplies<uint32_t>()); |
| 479 | + uint32_t qnn_tensor_size = CalcQnnTensorNumElems(qnn_tensor) * gsl::narrow_cast<uint32_t>(data_size); |
475 | 480 | auto qnn_tensor_buf_size = GetQnnTensorClientBuf(qnn_tensor).dataSize;
|
476 | 481 | if (qnn_tensor_size != qnn_tensor_buf_size) {
|
477 | 482 | ss << "Data length mismatch for static tensor. node_name: " << node_name
|
|
0 commit comments