@@ -940,11 +940,8 @@ Status Graph::BuildConnections(std::vector<std::string>& outer_scope_node_args_c
940
940
// and not explicitly listed in the ordered graph outputs (as that implies we should leave it as an output).
941
941
// If the Graph was loaded from a GraphProto, honor the explicit graph outputs and leave as is.
942
942
if (!loaded_from_model_file) {
943
- auto in_ordered_graph_outputs = find (graph_output_order_.cbegin (), graph_output_order_.cend (), node_arg);
944
- if (in_ordered_graph_outputs == graph_output_order_.cend ()) {
945
- graph_outputs_.erase (std::remove (graph_outputs_.begin (), graph_outputs_.end (), node_arg),
946
- graph_outputs_.end ());
947
- }
943
+ graph_outputs_.erase (std::remove (graph_outputs_.begin (), graph_outputs_.end (), node_arg),
944
+ graph_outputs_.end ());
948
945
}
949
946
}
950
947
}
@@ -2219,10 +2216,8 @@ void Graph::CleanUnusedInitializers() {
2219
2216
2220
2217
GSL_SUPPRESS (es .84 ) // warning about ignoring return value from insert(...)
2221
2218
Status Graph::SetGraphInputsOutputs() {
2222
- // Reset graph inputs/outputs/value info state .
2219
+ // Reset graph inputs excluding initializers/value_info .
2223
2220
graph_inputs_excluding_initializers_.clear ();
2224
- graph_inputs_including_initializers_.clear ();
2225
- graph_outputs_.clear ();
2226
2221
value_info_.clear ();
2227
2222
2228
2223
// Flag indicates that this graph is loaded from model file.
@@ -2231,10 +2226,11 @@ Status Graph::SetGraphInputsOutputs() {
2231
2226
// and outputs will be inferred.
2232
2227
const bool loaded_from_model_file = GraphLoadedFromModelFile (graph_proto_);
2233
2228
2234
- // if something is coming from outer scope, consider it already added
2235
- std::unordered_set<std::string> added_input_names{outer_scope_node_arg_names_};
2236
-
2237
2229
if (loaded_from_model_file) {
2230
+ // Reset graph inputs/outputs.
2231
+ graph_inputs_including_initializers_.clear ();
2232
+ graph_outputs_.clear ();
2233
+
2238
2234
// Name to NodeArg mapping of all graph initializers.
2239
2235
std::unordered_map<std::string, const NodeArg*> graph_initializers;
2240
2236
@@ -2302,49 +2298,31 @@ Status Graph::SetGraphInputsOutputs() {
2302
2298
}
2303
2299
2304
2300
} else {
2305
- std::unordered_map<std::string, const NodeArg*> output_name_to_node_arg ;
2306
- std::vector<std::string> ordered_output_names ;
2301
+ std::unordered_map<std::string, size_t > output_name_to_node_arg_index ;
2302
+ std::vector<const NodeArg*> output_node_args_in_order ;
2307
2303
2308
- // add any explicitly ordered inputs
2309
- for (auto * node_arg : graph_input_order_) {
2310
- if (!node_arg || !node_arg->Exists ()) {
2311
- return ORT_MAKE_STATUS (ONNXRUNTIME, FAIL, " Invalid entry in explicitly ordered inputs" );
2312
- }
2313
-
2314
- added_input_names.insert (node_arg->Name ());
2315
- graph_inputs_including_initializers_.push_back (node_arg);
2316
- if (name_to_initial_tensor_.find (node_arg->Name ()) == name_to_initial_tensor_.end ()) {
2317
- graph_inputs_excluding_initializers_.push_back (node_arg);
2318
- }
2304
+ // if something is coming from outer scope, consider it already added
2305
+ std::unordered_set<std::string> added_input_names{outer_scope_node_arg_names_};
2306
+ if (!graph_inputs_manually_set_) {
2307
+ graph_inputs_including_initializers_.clear ();
2319
2308
}
2320
2309
2321
- // add any explicitly ordered outputs
2322
- for (auto * node_arg : graph_output_order_) {
2323
- if (!node_arg || !node_arg->Exists ()) {
2324
- return ORT_MAKE_STATUS (ONNXRUNTIME, FAIL, " Invalid entry in explicitly ordered outputs" );
2325
- }
2326
- output_name_to_node_arg.insert ({node_arg->Name (), node_arg});
2327
- ordered_output_names.push_back (node_arg->Name ());
2310
+ if (!graph_outputs_manually_set_) {
2311
+ graph_outputs_.clear ();
2328
2312
}
2329
2313
2330
- // add all other outputs
2314
+ // Collect all nodes' outputs
2331
2315
for (const auto & node : Nodes ()) {
2332
2316
for (const auto * output_def : node.OutputDefs ()) {
2333
2317
if (output_def->Exists ()) {
2334
- auto & name = output_def->Name ();
2335
- // check it wasn't in the explicitly ordered outputs
2336
- if (output_name_to_node_arg.find (name) == output_name_to_node_arg.cend ()) {
2337
- output_name_to_node_arg.insert ({name, output_def});
2338
- ordered_output_names.push_back (name);
2339
- }
2318
+ output_node_args_in_order.push_back (output_def);
2319
+ output_name_to_node_arg_index.insert ({output_def->Name (), output_node_args_in_order.size () - 1 });
2340
2320
}
2341
2321
}
2342
2322
}
2343
2323
2344
2324
// Init graph output args with copy of all node output args.
2345
- auto graph_output_args = output_name_to_node_arg;
2346
- std::unordered_set<Node*> inner_nodes;
2347
-
2325
+ auto graph_output_args = output_name_to_node_arg_index;
2348
2326
for (const auto & node : Nodes ()) {
2349
2327
// Go thru all node's inputs.
2350
2328
for (const auto * input_arg : node.InputDefs ()) {
@@ -2353,15 +2331,28 @@ Status Graph::SetGraphInputsOutputs() {
2353
2331
continue ;
2354
2332
}
2355
2333
2356
- auto output_arg_iter = output_name_to_node_arg .find (input_arg->Name ());
2357
- if (output_name_to_node_arg .end () == output_arg_iter) {
2334
+ auto output_arg_iter = output_name_to_node_arg_index .find (input_arg->Name ());
2335
+ if (output_name_to_node_arg_index .end () == output_arg_iter) {
2358
2336
// This input arg should be fed when running evaluation.
2359
2337
// it should be a graph input.
2360
2338
const std::string& name = input_arg->Name ();
2361
2339
if (added_input_names.end () == added_input_names.find (name)) {
2362
2340
// This graph input has not been added into <graph_inputs_>.
2363
- graph_inputs_including_initializers_.push_back (input_arg);
2364
-
2341
+ if (!graph_inputs_manually_set_) {
2342
+ graph_inputs_including_initializers_.push_back (input_arg);
2343
+ } else {
2344
+ // Validation: the <input_arg> must be in graph inputs or initializers when it's manually set.
2345
+ auto & inputs = GetInputsIncludingInitializers ();
2346
+ auto iter = std::find (inputs.begin (), inputs.end (), input_arg);
2347
+ if (inputs.end () == iter) {
2348
+ // it's not in graph inputs.
2349
+ auto initializers = GetAllInitializedTensors ();
2350
+ if (initializers.end () == initializers.find (input_arg->Name ())) {
2351
+ // It's not in graph initializers.
2352
+ return Status (ONNXRUNTIME, FAIL, input_arg->Name () + " must be either specified in graph inputs or graph initailizers." );
2353
+ }
2354
+ }
2355
+ }
2365
2356
if (name_to_initial_tensor_.find (name) == name_to_initial_tensor_.end ()) {
2366
2357
graph_inputs_excluding_initializers_.push_back (input_arg);
2367
2358
}
@@ -2377,12 +2368,15 @@ Status Graph::SetGraphInputsOutputs() {
2377
2368
}
2378
2369
}
2379
2370
2380
- // Set graph outputs
2381
- auto end = graph_output_args.end ();
2382
- for (auto & name : ordered_output_names) {
2383
- auto graph_output = graph_output_args.find (name);
2384
- if (graph_output != end) {
2385
- graph_outputs_.push_back (graph_output->second );
2371
+ if (!graph_outputs_manually_set_) {
2372
+ // Set graph outputs in order.
2373
+ std::vector<size_t > graph_output_args_index;
2374
+ for (auto output_arg : graph_output_args) {
2375
+ graph_output_args_index.push_back (output_arg.second );
2376
+ }
2377
+ std::sort (graph_output_args_index.begin (), graph_output_args_index.end ());
2378
+ for (auto & output_arg_index : graph_output_args_index) {
2379
+ graph_outputs_.push_back (output_node_args_in_order[output_arg_index]);
2386
2380
}
2387
2381
}
2388
2382
}
@@ -2483,6 +2477,25 @@ Status Graph::InlineFunction(Node& node) {
2483
2477
return Status::OK ();
2484
2478
}
2485
2479
2480
+ void Graph::SetInputs (const std::vector<const NodeArg*> inputs) {
2481
+ if (GraphLoadedFromModelFile (graph_proto_)) {
2482
+ // TODO: add this support.
2483
+ ORT_THROW (" This API is not supported when model is loaded from proto file right now." );
2484
+ }
2485
+
2486
+ graph_inputs_including_initializers_ = inputs;
2487
+ graph_inputs_manually_set_ = true ;
2488
+ }
2489
+
2490
+ void Graph::SetOutputs (const std::vector<const NodeArg*> outputs) {
2491
+ if (GraphLoadedFromModelFile (graph_proto_)) {
2492
+ // TODO: add this support.
2493
+ ORT_THROW (" This API is not supported when model is loaded from proto file right now." );
2494
+ }
2495
+ graph_outputs_ = outputs;
2496
+ graph_outputs_manually_set_ = true ;
2497
+ }
2498
+
2486
2499
void Graph::AddFunction (const ONNX_NAMESPACE::FunctionProto* func_proto) {
2487
2500
this ->model_functions_ [func_proto->name ()] = func_proto;
2488
2501
}
0 commit comments