Skip to content

Commit f39a8d1

Browse files
authored
allow users to set graph inputs and outputs fully. (#905)
* allow users to set graph inputs and outputs fully. * update * update the comments of the APIs * update * remove commented-out codes. * fix test failures. * fix comments. * adding more check to throw not support exception right now.
1 parent bb58806 commit f39a8d1

File tree

5 files changed

+90
-102
lines changed

5 files changed

+90
-102
lines changed

include/onnxruntime/core/graph/graph.h

+8-24
Original file line numberDiff line numberDiff line change
@@ -727,27 +727,15 @@ class Graph {
727727
ORT_IGNORE_RETURN_VALUE(outer_scope_node_arg_names_.insert(name));
728728
}
729729

730-
/** When programmatically constructing a Graph, explicitly set the order to use for graph inputs when the graph is
731-
resolved.
732-
This will determine the graph input order when the Graph is converted to a GraphProto by Graph::ToGraphProto.
733-
@param inputs NodeArgs that represent graph inputs which need to be explicitly ordered.
734-
Any graph inputs not in this list will be appended to the ordered graph input list, in the order that they were first
735-
used by Nodes (i.e. the order of Node creation implicitly determines the ordering).
730+
/** When programmatically constructing a Graph, explicitly set graph inputs.
731+
@param inputs NodeArgs that represent complete graph inputs which need to be explicitly ordered.
736732
@remarks If the Graph was loaded from a GraphProto this has no effect.*/
737-
void SetInputOrder(const std::vector<const NodeArg*> inputs) {
738-
graph_input_order_ = inputs;
739-
}
733+
void SetInputs(const std::vector<const NodeArg*> inputs);
740734

741-
/** When programmatically constructing a Graph, explicitly set the order to use for graph outputs when the graph is
742-
resolved.
743-
This will determine the graph output order when the Graph is converted to a GraphProto by Graph::ToGraphProto.
744-
@param outputs NodeArgs that represent graph outputs which need to be explicitly ordered.
745-
Any graph outputs not in this list will be appended to the ordered graph output list, in the order that they were first
746-
produced by Nodes (i.e. the order of Node creation implicitly determines the ordering).
735+
/** When programmatically constructing a Graph, explicitly set graph outputs.
736+
@param outputs NodeArgs that represent complete graph outputs which need to be explicitly ordered.
747737
@remarks If the Graph was loaded from a GraphProto this has no effect.*/
748-
void SetOutputOrder(const std::vector<const NodeArg*> outputs) {
749-
graph_output_order_ = outputs;
750-
}
738+
void SetOutputs(const std::vector<const NodeArg*> outputs);
751739

752740
/** Returns true if this is a subgraph or fase if it is a high-level graph. */
753741
bool IsSubgraph() const { return parent_graph_ != nullptr; }
@@ -945,12 +933,14 @@ class Graph {
945933

946934
// Full list of graph inputs. Matches number and order of inputs in the GraphProto.
947935
std::vector<const NodeArg*> graph_inputs_including_initializers_;
936+
bool graph_inputs_manually_set_ = false;
948937

949938
// Graph inputs excluding initializers.
950939
std::vector<const NodeArg*> graph_inputs_excluding_initializers_;
951940

952941
// Graph outputs.
953942
std::vector<const NodeArg*> graph_outputs_;
943+
bool graph_outputs_manually_set_ = false;
954944

955945
// Graph value_info.
956946
std::vector<const NodeArg*> value_info_;
@@ -975,12 +965,6 @@ class Graph {
975965
// NodeArgs that come from outer scope. Used when building a graph so that
976966
// these don't get recorded as graph inputs in the GraphProto.
977967
std::unordered_set<std::string> outer_scope_node_arg_names_;
978-
979-
// Explicit graph input order to be used when constructing a Graph manually.
980-
std::vector<const NodeArg*> graph_input_order_;
981-
982-
// Explicit graph output order to be used when constructing a Graph manually.
983-
std::vector<const NodeArg*> graph_output_order_;
984968
};
985969

986970
} // namespace onnxruntime

onnxruntime/core/graph/graph.cc

+64-51
Original file line numberDiff line numberDiff line change
@@ -940,11 +940,8 @@ Status Graph::BuildConnections(std::vector<std::string>& outer_scope_node_args_c
940940
// and not explicitly listed in the ordered graph outputs (as that implies we should leave it as an output).
941941
// If the Graph was loaded from a GraphProto, honor the explicit graph outputs and leave as is.
942942
if (!loaded_from_model_file) {
943-
auto in_ordered_graph_outputs = find(graph_output_order_.cbegin(), graph_output_order_.cend(), node_arg);
944-
if (in_ordered_graph_outputs == graph_output_order_.cend()) {
945-
graph_outputs_.erase(std::remove(graph_outputs_.begin(), graph_outputs_.end(), node_arg),
946-
graph_outputs_.end());
947-
}
943+
graph_outputs_.erase(std::remove(graph_outputs_.begin(), graph_outputs_.end(), node_arg),
944+
graph_outputs_.end());
948945
}
949946
}
950947
}
@@ -2219,10 +2216,8 @@ void Graph::CleanUnusedInitializers() {
22192216

22202217
GSL_SUPPRESS(es .84) // warning about ignoring return value from insert(...)
22212218
Status Graph::SetGraphInputsOutputs() {
2222-
// Reset graph inputs/outputs/value info state.
2219+
// Reset graph inputs excluding initializers/value_info.
22232220
graph_inputs_excluding_initializers_.clear();
2224-
graph_inputs_including_initializers_.clear();
2225-
graph_outputs_.clear();
22262221
value_info_.clear();
22272222

22282223
// Flag indicates that this graph is loaded from model file.
@@ -2231,10 +2226,11 @@ Status Graph::SetGraphInputsOutputs() {
22312226
// and outputs will be inferred.
22322227
const bool loaded_from_model_file = GraphLoadedFromModelFile(graph_proto_);
22332228

2234-
// if something is coming from outer scope, consider it already added
2235-
std::unordered_set<std::string> added_input_names{outer_scope_node_arg_names_};
2236-
22372229
if (loaded_from_model_file) {
2230+
// Reset graph inputs/outputs.
2231+
graph_inputs_including_initializers_.clear();
2232+
graph_outputs_.clear();
2233+
22382234
// Name to NodeArg mapping of all graph initializers.
22392235
std::unordered_map<std::string, const NodeArg*> graph_initializers;
22402236

@@ -2302,49 +2298,31 @@ Status Graph::SetGraphInputsOutputs() {
23022298
}
23032299

23042300
} else {
2305-
std::unordered_map<std::string, const NodeArg*> output_name_to_node_arg;
2306-
std::vector<std::string> ordered_output_names;
2301+
std::unordered_map<std::string, size_t> output_name_to_node_arg_index;
2302+
std::vector<const NodeArg*> output_node_args_in_order;
23072303

2308-
// add any explicitly ordered inputs
2309-
for (auto* node_arg : graph_input_order_) {
2310-
if (!node_arg || !node_arg->Exists()) {
2311-
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Invalid entry in explicitly ordered inputs");
2312-
}
2313-
2314-
added_input_names.insert(node_arg->Name());
2315-
graph_inputs_including_initializers_.push_back(node_arg);
2316-
if (name_to_initial_tensor_.find(node_arg->Name()) == name_to_initial_tensor_.end()) {
2317-
graph_inputs_excluding_initializers_.push_back(node_arg);
2318-
}
2304+
// if something is coming from outer scope, consider it already added
2305+
std::unordered_set<std::string> added_input_names{outer_scope_node_arg_names_};
2306+
if (!graph_inputs_manually_set_) {
2307+
graph_inputs_including_initializers_.clear();
23192308
}
23202309

2321-
// add any explicitly ordered outputs
2322-
for (auto* node_arg : graph_output_order_) {
2323-
if (!node_arg || !node_arg->Exists()) {
2324-
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Invalid entry in explicitly ordered outputs");
2325-
}
2326-
output_name_to_node_arg.insert({node_arg->Name(), node_arg});
2327-
ordered_output_names.push_back(node_arg->Name());
2310+
if (!graph_outputs_manually_set_) {
2311+
graph_outputs_.clear();
23282312
}
23292313

2330-
// add all other outputs
2314+
// Collect all nodes' outputs
23312315
for (const auto& node : Nodes()) {
23322316
for (const auto* output_def : node.OutputDefs()) {
23332317
if (output_def->Exists()) {
2334-
auto& name = output_def->Name();
2335-
// check it wasn't in the explicitly ordered outputs
2336-
if (output_name_to_node_arg.find(name) == output_name_to_node_arg.cend()) {
2337-
output_name_to_node_arg.insert({name, output_def});
2338-
ordered_output_names.push_back(name);
2339-
}
2318+
output_node_args_in_order.push_back(output_def);
2319+
output_name_to_node_arg_index.insert({output_def->Name(), output_node_args_in_order.size() - 1});
23402320
}
23412321
}
23422322
}
23432323

23442324
// Init graph output args with copy of all node output args.
2345-
auto graph_output_args = output_name_to_node_arg;
2346-
std::unordered_set<Node*> inner_nodes;
2347-
2325+
auto graph_output_args = output_name_to_node_arg_index;
23482326
for (const auto& node : Nodes()) {
23492327
// Go thru all node's inputs.
23502328
for (const auto* input_arg : node.InputDefs()) {
@@ -2353,15 +2331,28 @@ Status Graph::SetGraphInputsOutputs() {
23532331
continue;
23542332
}
23552333

2356-
auto output_arg_iter = output_name_to_node_arg.find(input_arg->Name());
2357-
if (output_name_to_node_arg.end() == output_arg_iter) {
2334+
auto output_arg_iter = output_name_to_node_arg_index.find(input_arg->Name());
2335+
if (output_name_to_node_arg_index.end() == output_arg_iter) {
23582336
// This input arg should be fed when running evaluation.
23592337
// it should be a graph input.
23602338
const std::string& name = input_arg->Name();
23612339
if (added_input_names.end() == added_input_names.find(name)) {
23622340
// This graph input has not been added into <graph_inputs_>.
2363-
graph_inputs_including_initializers_.push_back(input_arg);
2364-
2341+
if (!graph_inputs_manually_set_) {
2342+
graph_inputs_including_initializers_.push_back(input_arg);
2343+
} else {
2344+
// Validation: the <input_arg> must be in graph inputs or initializers when it's manually set.
2345+
auto& inputs = GetInputsIncludingInitializers();
2346+
auto iter = std::find(inputs.begin(), inputs.end(), input_arg);
2347+
if (inputs.end() == iter) {
2348+
// it's not in graph inputs.
2349+
auto initializers = GetAllInitializedTensors();
2350+
if (initializers.end() == initializers.find(input_arg->Name())) {
2351+
// It's not in graph initializers.
2352+
return Status(ONNXRUNTIME, FAIL, input_arg->Name() + " must be either specified in graph inputs or graph initailizers.");
2353+
}
2354+
}
2355+
}
23652356
if (name_to_initial_tensor_.find(name) == name_to_initial_tensor_.end()) {
23662357
graph_inputs_excluding_initializers_.push_back(input_arg);
23672358
}
@@ -2377,12 +2368,15 @@ Status Graph::SetGraphInputsOutputs() {
23772368
}
23782369
}
23792370

2380-
// Set graph outputs
2381-
auto end = graph_output_args.end();
2382-
for (auto& name : ordered_output_names) {
2383-
auto graph_output = graph_output_args.find(name);
2384-
if (graph_output != end) {
2385-
graph_outputs_.push_back(graph_output->second);
2371+
if (!graph_outputs_manually_set_) {
2372+
// Set graph outputs in order.
2373+
std::vector<size_t> graph_output_args_index;
2374+
for (auto output_arg : graph_output_args) {
2375+
graph_output_args_index.push_back(output_arg.second);
2376+
}
2377+
std::sort(graph_output_args_index.begin(), graph_output_args_index.end());
2378+
for (auto& output_arg_index : graph_output_args_index) {
2379+
graph_outputs_.push_back(output_node_args_in_order[output_arg_index]);
23862380
}
23872381
}
23882382
}
@@ -2483,6 +2477,25 @@ Status Graph::InlineFunction(Node& node) {
24832477
return Status::OK();
24842478
}
24852479

2480+
void Graph::SetInputs(const std::vector<const NodeArg*> inputs) {
2481+
if (GraphLoadedFromModelFile(graph_proto_)) {
2482+
// TODO: add this support.
2483+
ORT_THROW("This API is not supported when model is loaded from proto file right now.");
2484+
}
2485+
2486+
graph_inputs_including_initializers_ = inputs;
2487+
graph_inputs_manually_set_ = true;
2488+
}
2489+
2490+
void Graph::SetOutputs(const std::vector<const NodeArg*> outputs) {
2491+
if (GraphLoadedFromModelFile(graph_proto_)) {
2492+
// TODO: add this support.
2493+
ORT_THROW("This API is not supported when model is loaded from proto file right now.");
2494+
}
2495+
graph_outputs_ = outputs;
2496+
graph_outputs_manually_set_ = true;
2497+
}
2498+
24862499
void Graph::AddFunction(const ONNX_NAMESPACE::FunctionProto* func_proto) {
24872500
this->model_functions_[func_proto->name()] = func_proto;
24882501
}

onnxruntime/test/ir/graph_test.cc

+10-19
Original file line numberDiff line numberDiff line change
@@ -443,12 +443,6 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained)
443443

444444
for (auto i = 0; i < 20; ++i) {
445445
map.insert({std::to_string(i), i});
446-
447-
std::cout << "Insert " << i << "\n";
448-
for (auto pair : map) {
449-
std::cout << pair.first << ":" << pair.second << " ";
450-
}
451-
std::cout << "\n";
452446
}
453447

454448
// | |
@@ -458,10 +452,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained)
458452
// |
459453
// d (Split)
460454
// / \
461-
// 1 .. 10
462-
std::unordered_map<std::string, std::pair<std::vector<NodeArg*>, std::vector<NodeArg*>>>
463-
expected_node_name_to_input_output_args;
464-
455+
// 1 .. 10
465456
TypeProto tensor_int32;
466457
tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32);
467458
tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
@@ -475,37 +466,36 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained)
475466
auto& output_arg_c = graph.GetOrCreateNodeArg("node_c_out_1", &tensor_int32);
476467

477468
std::vector<NodeArg*> split_outputs;
469+
std::vector<const NodeArg*> graph_outputs;
478470
for (int i = 0; i < 10; ++i) {
479-
split_outputs.push_back(&graph.GetOrCreateNodeArg("node_d_out_" + std::to_string(i + 1), &tensor_int32));
471+
auto arg = &graph.GetOrCreateNodeArg("node_d_out_" + std::to_string(i + 1), &tensor_int32);
472+
split_outputs.push_back(arg);
473+
graph_outputs.push_back(arg);
480474
}
481-
475+
std::reverse(graph_outputs.begin(), graph_outputs.end());
482476
std::vector<NodeArg*> inputs;
483477
std::vector<NodeArg*> outputs;
484478

485479
inputs.push_back(&input_arg_a);
486480
outputs.push_back(&output_arg_a);
487-
expected_node_name_to_input_output_args["a"] = {inputs, outputs};
488481
graph.AddNode("a", "Identity_Fake", "a", inputs, outputs);
489482

490483
inputs.resize(2);
491484
inputs[0] = &output_arg_b;
492485
inputs[1] = &output_arg_a;
493486
outputs[0] = &output_arg_c;
494-
expected_node_name_to_input_output_args["c"] = {inputs, outputs};
495487
graph.AddNode("c", "Merge_Fake", "c", inputs, outputs);
496488

497489
// deliberately add 'b' after 'c' to mix up the inputs as well
498490
inputs.resize(1);
499491
inputs[0] = &input_arg_b;
500492
outputs[0] = &output_arg_b;
501-
expected_node_name_to_input_output_args["b"] = {inputs, outputs};
502493
graph.AddNode("b", "Identity_Fake", "b", inputs, outputs);
503494

504495
inputs[0] = &output_arg_c;
505-
expected_node_name_to_input_output_args["d"] = {inputs, split_outputs};
506496
graph.AddNode("d", "Split_Fake", "d", inputs, split_outputs);
507497

508-
auto validate_inputs_outputs = [&split_outputs](const Graph& graph) {
498+
auto validate_inputs_outputs = [&graph_outputs](const Graph& graph) {
509499
auto inputs = graph.GetInputs();
510500
auto outputs = graph.GetOutputs();
511501

@@ -516,10 +506,11 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained)
516506

517507
ASSERT_TRUE(outputs.size() == 10);
518508
for (int i = 0; i < 10; ++i) {
519-
EXPECT_TRUE(split_outputs[i]->Name() == outputs[i]->Name());
509+
EXPECT_TRUE(graph_outputs[i]->Name() == outputs[i]->Name());
520510
}
521511
};
522-
512+
graph.SetInputs({&input_arg_a, &input_arg_b});
513+
graph.SetOutputs(graph_outputs);
523514
auto status = graph.Resolve();
524515
ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
525516

onnxruntime/test/providers/cpu/controlflow/loop_test.cc

+4-4
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,8 @@ static const ONNX_NAMESPACE::GraphProto CreateSubgraph(const RunOptions& options
268268
}
269269
}
270270

271-
graph.SetInputOrder({&iter_num_in, &cond_in, &loop_var_0_in, &loop_var_1_in});
272-
graph.SetOutputOrder({cond_out, loop_var_0_out, loop_var_1_out, loop_out_0});
271+
graph.SetInputs({&iter_num_in, &cond_in, &loop_var_0_in, &loop_var_1_in});
272+
graph.SetOutputs({cond_out, loop_var_0_out, loop_var_1_out, loop_out_0});
273273

274274
// optional input backed by an initializer to make sure that's handled too.
275275
// we expect that Graph::InferAndVerifySubgraphTypes will be able to ignore the optional input if not provided
@@ -447,8 +447,8 @@ TEST(Loop, InfiniteLoopTermination) {
447447
graph.AddNode("loop_var_out", "Identity", "Forward outer_scope_0 to loop_var_0_out", inputs, outputs);
448448
}
449449

450-
graph.SetInputOrder({&iter_num_in, &cond_in, &outer_scope_0});
451-
graph.SetOutputOrder({&cond_out, &loop_var_0_out});
450+
graph.SetInputs({&iter_num_in, &cond_in, &outer_scope_0});
451+
graph.SetOutputs({&cond_out, &loop_var_0_out});
452452

453453
auto status = graph.Resolve();
454454
EXPECT_EQ(status, Status::OK());

onnxruntime/test/providers/cpu/controlflow/scan_test.cc

+4-4
Original file line numberDiff line numberDiff line change
@@ -1046,8 +1046,8 @@ void MixedTypeInputs(bool is_v8) {
10461046
graph.AddNode("node3", "Identity", "Copy scan_in_1 to state_out_1", {&scan_in_1}, {&state_out_1});
10471047
graph.AddNode("node4", "Identity", "Copy scan_in_2 to state_out_2", {&scan_in_2}, {&state_out_2});
10481048

1049-
graph.SetInputOrder({&state_in_1, &state_in_2, &scan_in_1, &scan_in_2});
1050-
graph.SetOutputOrder({&state_out_1, &state_out_2, &scan_out_1, &scan_out_2});
1049+
graph.SetInputs({&state_in_1, &state_in_2, &scan_in_1, &scan_in_2});
1050+
graph.SetOutputs({&state_out_1, &state_out_2, &scan_out_1, &scan_out_2});
10511051

10521052
auto status = graph.Resolve();
10531053
EXPECT_EQ(status, Status::OK());
@@ -1108,8 +1108,8 @@ void UnknownDimInSubgraphOutput(bool is_v8) {
11081108
graph.AddNode("node1", "Identity", "Copy state_in_1 to scan_out_1", {&state_in_1}, {&scan_out_1});
11091109
graph.AddNode("node2", "Identity", "Copy scan_in_1 to state_out_1", {&scan_in_1}, {&state_out_1});
11101110

1111-
graph.SetInputOrder({&state_in_1, &scan_in_1});
1112-
graph.SetOutputOrder({&state_out_1, &scan_out_1});
1111+
graph.SetInputs({&state_in_1, &scan_in_1});
1112+
graph.SetOutputs({&state_out_1, &scan_out_1});
11131113

11141114
auto status = graph.Resolve();
11151115
EXPECT_EQ(status, Status::OK());

0 commit comments

Comments
 (0)