Skip to content

Commit 7762f3f

Browse files
YUNQIUGUO, rachguo, and edgchen1
authored
[NNAPI EP] Add NNAPI Split (microsoft#18702)
### Description

As the title says: add Split operator support to the NNAPI execution provider.

### Motivation and Context

The yolo-v8 model was missing operator support.

---------

Co-authored-by: rachguo <[email protected]>
Co-authored-by: Edward Chen <[email protected]>
1 parent c4b8120 commit 7762f3f

File tree

5 files changed

+167
-12
lines changed

5 files changed

+167
-12
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include <onnx/onnx_pb.h>
5+
#include <algorithm>
6+
7+
#include "core/common/logging/logging.h"
8+
#include "core/common/safeint.h"
9+
#include "core/framework/tensorprotoutils.h"
10+
#include "core/graph/graph_viewer.h"
11+
#include "core/providers/common.h"
12+
#include "core/optimizer/initializer.h"
13+
#include "core/providers/shared/utils/utils.h"
14+
#include "core/providers/nnapi/nnapi_builtin/builders/helper.h"
15+
#include "core/providers/nnapi/nnapi_builtin/builders/model_builder.h"
16+
#include "core/providers/nnapi/nnapi_builtin/builders/op_builder_factory.h"
17+
#include "core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h"
18+
#include "core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h"
19+
20+
using namespace android::nn::wrapper;
21+
22+
namespace onnxruntime {
23+
namespace nnapi {
24+
25+
using namespace op_builder_helpers;
26+
27+
class SplitOpBuilder : public BaseOpBuilder {
28+
// Add operator related
29+
public:
30+
void AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const override;
31+
32+
private:
33+
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override;
34+
35+
// Operator support related
36+
37+
private:
38+
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit,
39+
const OpSupportCheckParams& params) const override;
40+
41+
// Split opset 13- uses "split" as attribute. Currently it's not supported.
42+
int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override { return 13; }
43+
44+
// NNAPI Split is available since NNAPI feature level 3
45+
int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& /* node_unit */,
46+
const OpSupportCheckParams& /* params */) const override {
47+
return ANEURALNETWORKS_FEATURE_LEVEL_3;
48+
}
49+
};
50+
51+
// Add operator related
52+
53+
// Marks the optional second input "split" as an initializer to skip, if the node provides it.
void SplitOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const {
  const auto& inputs = node_unit.Inputs();

  // "split" is optional: it may be missing entirely or present as a non-existent (empty) arg.
  const bool has_split_input = inputs.size() > 1 && inputs[1].node_arg.Exists();
  if (!has_split_input) {
    return;
  }

  model_builder.AddInitializerToSkip(inputs[1].node_arg.Name());
}
60+
61+
// Adds an NNAPI split for this node: input 0 is split along "axis" (default 0) into one
// NNAPI output per ONNX output of the node.
//
// Returns the error from AddNnapiSplit if it fails, otherwise Status::OK().
Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const {
  const auto& input_name = node_unit.Inputs()[0].node_arg.Name();
  const auto& outputs = node_unit.Outputs();

  NodeAttrHelper helper(node_unit);
  const auto axis = helper.Get("axis", 0);

  // The number of splits is taken from the node's outputs rather than from the "num_outputs"
  // attribute. For Split-18 the attribute is absent whenever the "split" input is given, so
  // dereferencing it unconditionally would read an empty optional. When the attribute is used,
  // IsOpSupportedImpl already requires it to equal the number of outputs, so the result is the
  // same for every node that reaches this point.
  std::vector<std::string> output_names;
  output_names.reserve(outputs.size());
  for (const auto& output : outputs) {
    output_names.push_back(output.node_arg.Name());
  }

  ORT_RETURN_IF_ERROR(op_builder_helpers::AddNnapiSplit(model_builder, input_name, axis, output_names));

  return Status::OK();
}
85+
86+
// Operator support related
87+
88+
bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit,
89+
const OpSupportCheckParams& /* params */) const {
90+
Shape input_shape;
91+
if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape))
92+
return false;
93+
94+
const auto& input_defs = node_unit.Inputs();
95+
NodeAttrHelper helper(node_unit);
96+
const auto axis = helper.Get("axis", 0);
97+
98+
const auto split_dims_at_axis = input_shape[HandleNegativeAxis(axis, input_shape.size())];
99+
if (input_defs.size() > 1 && input_defs[1].node_arg.Exists()) {
100+
// if optional input `split` is provided
101+
auto split_initializer_it = initializers.find(input_defs[1].node_arg.Name());
102+
if (split_initializer_it == initializers.end()) {
103+
LOGS_DEFAULT(VERBOSE) << "Optional input 'split' must be initializer if provided.";
104+
return false;
105+
}
106+
const auto& splits_tensor = *split_initializer_it->second;
107+
Initializer unpacked_tensor(splits_tensor);
108+
auto splits_span = unpacked_tensor.DataAsSpan<int64_t>();
109+
uint32_t sum_of_splits = std::accumulate(splits_span.begin(), splits_span.end(), SafeInt<uint32_t>(0));
110+
if (sum_of_splits != split_dims_at_axis) {
111+
LOGS_DEFAULT(VERBOSE) << "Sum of the 'split' input values must equal to the dim value at 'axis' specified. "
112+
<< "dim value at 'axis' specified: "
113+
<< split_dims_at_axis
114+
<< ", sum of 'split' input values: "
115+
<< sum_of_splits;
116+
return false;
117+
}
118+
119+
auto it = std::adjacent_find(splits_span.begin(), splits_span.end(), [](const auto& a, const auto& b) {
120+
return a != b;
121+
});
122+
if (it != splits_span.end()) {
123+
LOGS_DEFAULT(VERBOSE) << "NNAPI only supports the case that number of splits evenly divides split axis size";
124+
return false;
125+
}
126+
} else {
127+
uint32_t num_outputs;
128+
if (node_unit.SinceVersion() >= 18) {
129+
auto num_outputs_attr = helper.GetInt("num_outputs");
130+
if (!num_outputs_attr.has_value()) {
131+
LOGS_DEFAULT(VERBOSE) << "No 'num_outputs' provided. For split 18+, num_outputs is a required attribute.";
132+
return false;
133+
}
134+
num_outputs = SafeInt<uint32_t>(*num_outputs_attr);
135+
if (num_outputs != SafeInt<uint32_t>(node_unit.Outputs().size()) || num_outputs > split_dims_at_axis) {
136+
LOGS_DEFAULT(VERBOSE) << "Invalid num_outputs provided. "
137+
<< "The value should be less than or equal to the size of dimension being split "
138+
<< "and align with the size of output nodes. Current num_outputs: "
139+
<< num_outputs;
140+
return false;
141+
}
142+
} else {
143+
num_outputs = SafeInt<uint32_t>(node_unit.Outputs().size());
144+
}
145+
// NNAPI only supports the case where axis can be evenly divided by num of splits
146+
if (split_dims_at_axis % num_outputs != 0) {
147+
LOGS_DEFAULT(VERBOSE) << "split count: " << num_outputs << " doesn't evenly divide split dimension: "
148+
<< split_dims_at_axis;
149+
return false;
150+
}
151+
}
152+
return true;
153+
}
154+
155+
// Creates a SplitOpBuilder, stores it in `op_registrations.builders`, and maps `op_type` to it
// in `op_registrations.op_builder_map`.
void CreateSplitOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
  auto& owned_builders = op_registrations.builders;
  owned_builders.push_back(std::make_unique<SplitOpBuilder>());

  // The map stores a plain pointer to the builder that was just appended.
  auto* const split_builder = owned_builders.back().get();
  op_registrations.op_builder_map.emplace(op_type, split_builder);
}
159+
160+
} // namespace nnapi
161+
} // namespace onnxruntime

onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_factory.cc

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
3232
CreateResizeOpBuilder("Resize", op_registrations);
3333
CreateSliceOpBuilder("Slice", op_registrations);
3434
CreateSoftMaxOpBuilder("Softmax", op_registrations);
35+
CreateSplitOpBuilder("Split", op_registrations);
3536
CreateSqueezeOpBuilder("Squeeze", op_registrations);
3637
CreateTransposeOpBuilder("Transpose", op_registrations);
3738
CreateUnsqueezeOpBuilder("Unsqueeze", op_registrations);

onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_factory.h

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ void CreateReluOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_
3333
void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
3434
void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
3535
void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
36+
void CreateSplitOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
3637
void CreateSoftMaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
3738
void CreateSqueezeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
3839
void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

onnxruntime/test/providers/cpu/tensor/split_op_test.cc

+3-12
Original file line numberDiff line numberDiff line change
@@ -706,9 +706,8 @@ TEST(SplitOperatorTest, Split18_NumOutputs_EvenSplit) {
706706
7.f, 8.f}});
707707

708708
int64_t num_outputs = 2;
709-
#ifdef USE_COREML
709+
710710
RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true, num_outputs, true);
711-
#endif
712711
RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true, num_outputs, false);
713712
}
714713

@@ -735,9 +734,8 @@ TEST(SplitOperatorTest, Split18_NumOutputs_UnevenSplit) {
735734
outputs.push_back({{1, 2}, {9.f, 10.f}});
736735

737736
int64_t num_outputs = 3;
738-
#ifdef USE_COREML
737+
739738
RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}, false, true, num_outputs, true);
740-
#endif
741739
RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}, false, true, num_outputs, false);
742740
}
743741

@@ -763,10 +761,8 @@ TEST(SplitOperatorTest, Split18_InvalidNumOutputs) {
763761
};
764762
RunTest<float>(axis, {}, input, outputs, excluded_providers, true, true, num_outputs, false,
765763
"Attribute `num_outputs` value cannot be lower than 1");
766-
#ifdef USE_COREML
767764
RunTest<float>(axis, {}, input, outputs, excluded_providers, true, true, num_outputs, true,
768765
"Attribute `num_outputs` value cannot be lower than 1");
769-
#endif
770766

771767
outputs.clear();
772768
outputs.push_back({{1, 2},
@@ -775,12 +771,11 @@ TEST(SplitOperatorTest, Split18_InvalidNumOutputs) {
775771
{0.f, 0.f}});
776772

777773
num_outputs = 3;
774+
778775
RunTest<float>(axis, {}, input, outputs, excluded_providers, true, true, num_outputs, false,
779776
"Invalid num_outputs value of 3. Size of dimension being split is 2");
780-
#ifdef USE_COREML
781777
RunTest<float>(axis, {}, input, outputs, excluded_providers, true, true, num_outputs, true,
782778
"Invalid num_outputs value of 3. Size of dimension being split is 2");
783-
#endif
784779
}
785780

786781
TEST(SplitOperatorTest, Split18_NumOutputsEvenSplitAxis1) {
@@ -798,9 +793,7 @@ TEST(SplitOperatorTest, Split18_NumOutputsEvenSplitAxis1) {
798793

799794
int64_t num_outputs = 3;
800795
RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true, num_outputs, false);
801-
#ifdef USE_COREML
802796
RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true, num_outputs);
803-
#endif
804797
}
805798

806799
TEST(SplitOperatorTest, Split18_NumOutputsUnevenSplitAxis1) {
@@ -818,9 +811,7 @@ TEST(SplitOperatorTest, Split18_NumOutputsUnevenSplitAxis1) {
818811
outputs.push_back({{2, 1}, {3.f, 6.f}});
819812

820813
int64_t num_outputs = 2;
821-
#ifdef USE_COREML
822814
RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}, false, true, num_outputs);
823-
#endif
824815
RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}, false, true, num_outputs, false);
825816
}
826817

tools/ci_build/github/android/nnapi_supported_ops.md

+1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ Keep in sync with doco generated from /docs/execution-providers/NNAPI-ExecutionP
4545
|ai.onnx:Sin||
4646
|ai.onnx:Slice||
4747
|ai.onnx:Softmax||
48+
|ai.onnx:Split|Number of splits must evenly divide split axis size. Input split should be constant if provided.|
4849
|ai.onnx:Sqrt||
4950
|ai.onnx:Squeeze|Input axes should be constant.|
5051
|ai.onnx:Sub||

0 commit comments

Comments
 (0)