Skip to content

Commit ef0b713

Browse files
Authored by smk2007 (Sheil Kumar)
and
Sheil Kumar
Optimize KahnsTopologicalSort and PriorityNodeCompare (#19475)
**Description** 1) During SessionInitialization, KahnsTopologicalSort is a major cause of perf degradation. The main cause of the slowdown is that the TopologicalSort needs to keep track of nodes to visit in order, and reorder them based on priority (as informed by a comparator). The existing implementation uses a priority_queue that is backed by a std::vector container. However, vectors are not good for insertion and reordering. The appropriate data type for this operation is a linked list. However, linked lists like std::list are not usable as a container for std::priority_queue. This is because std::priority_queue requires random access, which linked lists do not have. However, for this simple implementation, we can leverage a std::list under the hood and perform insertions manually using std::upper_bound. This drastically reduces the time taken by the method; the previous approach caused numerous recopies and a lot of element movement inside the graph's nodes-to-visit list. 2) In the comparator, I hide forward and backward attribute checking behind the #ifdef ENABLE_TRAINING macro, as I believe it should only be valid in the training scenario. 3) In the NoopElimination transformer, I prevent the creation of Initializer (which unpacks tensorproto data) in every node and only create initializers when Add/Sub/Mul/Div op nodes are detected. **Motivation and Context** Session creation time of many models is quite slow. --------- Co-authored-by: Sheil Kumar <[email protected]>
1 parent 4bfa69d commit ef0b713

File tree

4 files changed

+85
-45
lines changed

4 files changed

+85
-45
lines changed

onnxruntime/core/graph/graph.cc

+29-8
Original file line numberDiff line numberDiff line change
@@ -1818,16 +1818,36 @@ void Graph::ReverseDFSFrom(gsl::span<const Node* const> from,
18181818
}
18191819
}
18201820

1821+
// Priority-queue replacement used by KahnsTopologicalSort.
//
// std::priority_queue requires a random-access container (e.g. std::vector),
// which makes every ordered insertion shuffle/copy elements. A std::list is
// not usable as the backing container for std::priority_queue, but ordered
// insertion into a list via std::upper_bound avoids the element movement
// entirely, which is what makes this faster during session initialization.
//
// Contract (matches std::priority_queue with the same comparator):
//   - push() keeps the list sorted ascending per `comparator_`;
//   - top() returns the greatest element (the list's back);
//   - pop() removes it.
template <typename T>
struct VisitorPriorityQueue {
  using ComparatorType = std::function<bool(T, T)>;

  std::list<T> list_;
  const ComparatorType comparator_ = nullptr;

  // explicit: a comparator should never silently convert into a queue.
  explicit VisitorPriorityQueue(const ComparatorType& comp) : comparator_(comp) {}

  // Insert `node` after the last element it does not compare less than,
  // preserving sorted order. O(n) comparator walk, but no recopies.
  void push(T node) {
    list_.insert(
        std::upper_bound(list_.begin(), list_.end(), node, comparator_),
        node);
  }
  // Observers are const-qualified: they do not mutate the queue.
  bool empty() const { return list_.empty(); }
  T top() const { return list_.back(); }
  void pop() { list_.pop_back(); }
};
1837+
18211838
#if !defined(ORT_MINIMAL_BUILD)
18221839
void Graph::KahnsTopologicalSort(const std::function<void(const Node*)>& enter,
18231840
const std::function<bool(const Node*, const Node*)>& comp) const {
1824-
std::unordered_map<NodeIndex, size_t> in_degree;
1825-
std::priority_queue<const Node*, std::vector<const Node*>, decltype(comp)> to_visit(comp);
1826-
std::vector<NodeIndex> topo_order;
1841+
InlinedVector<size_t> in_degree(MaxNodeIndex(), 0);
1842+
InlinedVector<NodeIndex> topo_order;
1843+
VisitorPriorityQueue<const Node*> to_visit(comp);
1844+
1845+
auto number_of_nodes = NumberOfNodes();
1846+
topo_order.reserve(number_of_nodes);
18271847

18281848
for (auto& node : Nodes()) {
18291849
size_t input_edge_count = node.GetInputEdgesCount();
1830-
in_degree.insert({node.Index(), input_edge_count});
1850+
in_degree[node.Index()] = input_edge_count;
18311851
if (input_edge_count == 0) {
18321852
to_visit.push(&node);
18331853
}
@@ -1844,16 +1864,17 @@ void Graph::KahnsTopologicalSort(const std::function<void(const Node*)>& enter,
18441864
}
18451865

18461866
for (auto node_it = current->OutputNodesBegin(); node_it != current->OutputNodesEnd(); ++node_it) {
1847-
in_degree[node_it->Index()]--;
1867+
auto& node_in_degree = in_degree[node_it->Index()];
1868+
node_in_degree--;
18481869

1849-
if (in_degree[node_it->Index()] == 0) {
1870+
if (node_in_degree == 0) {
18501871
to_visit.push(&*node_it);
18511872
}
18521873
}
18531874
topo_order.push_back(current->Index());
18541875
}
18551876

1856-
if (NumberOfNodes() != static_cast<int>(topo_order.size())) {
1877+
if (number_of_nodes != static_cast<int>(topo_order.size())) {
18571878
ORT_THROW("Some nodes are not included in the topological sort, graph have a cycle.");
18581879
}
18591880
}
@@ -2843,7 +2864,7 @@ void Graph::AddInitializedTensor(const TensorProto& tensor) {
28432864

28442865
const gsl::not_null<TensorProto*> tensor_added{graph_proto_->add_initializer()};
28452866
*(tensor_added) = tensor;
2846-
name_to_initial_tensor_[tensor.name()] = tensor_added;
2867+
name_to_initial_tensor_.emplace(tensor.name(), tensor_added);
28472868
SetGraphResolveNeeded();
28482869
if (!is_loaded_from_model_file_ && GetNodeArg(tensor.name()) == nullptr) {
28492870
// make sure there is a NodeArg for the initializer as SetGraphInputsOutputs may add it to the graph inputs.

onnxruntime/core/graph/graph_viewer.cc

+12-6
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ bool NodeCompare::operator()(const Node* n1, const Node* n2) const {
1414
struct PriorityNodeCompare {
1515
inline bool IsHighPri(const Node* n) const {
1616
// local statics so we can compare std::strings in the checks
17-
static const std::string shape_op("Shape");
18-
static const std::string size_op("Size");
17+
static constexpr std::string_view shape_op("Shape");
18+
static constexpr std::string_view size_op("Size");
1919

2020
const auto& op_type = n->OpType();
2121
return op_type == shape_op || op_type == size_op;
@@ -26,15 +26,20 @@ struct PriorityNodeCompare {
2626
// If return true, n2 will be output first
2727
bool operator()(const Node* n1, const Node* n2) const {
2828
// nodes in global high priority list will be output first
29-
if (IsHighPri(n1) != IsHighPri(n2)) {
30-
return IsHighPri(n2);
29+
const bool isN1HighPri = IsHighPri(n1);
30+
const bool isN2HighPri = IsHighPri(n2);
31+
if (isN1HighPri != isN2HighPri) {
32+
return isN2HighPri;
3133
}
3234

3335
// nodes with lower priority value will be output first
34-
if (n1->Priority() != n2->Priority()) {
35-
return n1->Priority() > n2->Priority();
36+
const auto n1_priority = n1->Priority();
37+
const auto n2_priority = n2->Priority();
38+
if (n1_priority != n2_priority) {
39+
return n1_priority > n2_priority;
3640
}
3741

42+
#ifdef ENABLE_TRAINING
3843
// nodes of forward pass will be output first
3944
auto n1_attrs = n1->GetAttributes();
4045
auto n2_attrs = n2->GetAttributes();
@@ -45,6 +50,7 @@ struct PriorityNodeCompare {
4550
if (n1_is_forward != n2_is_forward) {
4651
return n2_is_forward > n1_is_forward;
4752
}
53+
#endif
4854

4955
// otherwise, nodes with lower index will be output first
5056
return n1->Index() > n2->Index();

onnxruntime/core/optimizer/noop_elimination.cc

+43-30
Original file line numberDiff line numberDiff line change
@@ -42,49 +42,62 @@ bool NoopElimination::SatisfyCondition(const Graph& graph, const Node& node, con
4242

4343
// if initializer_rank is bigger, the output is expected to be initializer_rank per broadcasting rule,
4444
// but it won't happen if the case is accepted, thus reject it
45-
auto initializer_rank = initializer->dims().size();
45+
const auto& dims = initializer->dims();
46+
auto initializer_rank = dims.size();
4647
const auto* other_input_shape = node.InputDefs()[input0_is_initializer ? 1 : 0]->Shape();
4748
if (other_input_shape == nullptr || initializer_rank > other_input_shape->dim_size()) {
4849
return false;
4950
}
5051

51-
int32_t data_type = initializer->data_type();
52-
Initializer add_init(*initializer, graph.ModelPath());
53-
if (add_init.size() > 1) {
52+
int64_t tensor_size = 1;
53+
for (auto i : dims) {
54+
tensor_size *= i;
55+
}
56+
57+
if (tensor_size > 1) {
5458
return false;
5559
}
60+
5661
// handle edge case where the total size of the initializer is 0
57-
if (add_init.size() == 0) {
62+
if (tensor_size == 0) {
5863
return true;
5964
}
6065

61-
float value = 0.0f;
62-
switch (data_type) {
63-
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
64-
value = *add_init.data<float>();
65-
break;
66-
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
67-
value = math::halfToFloat(add_init.data<MLFloat16>()->val);
68-
break;
69-
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
70-
value = static_cast<float>(*add_init.data<double>());
71-
break;
72-
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
73-
value = static_cast<float>(*add_init.data<int32_t>());
74-
break;
75-
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
76-
value = static_cast<float>(*add_init.data<int64_t>());
77-
break;
78-
default:
66+
if (op_type == "Add" ||
67+
op_type == "Sub" ||
68+
op_type == "Mul" ||
69+
op_type == "Div") {
70+
int32_t data_type = initializer->data_type();
71+
Initializer add_init(*initializer, graph.ModelPath());
72+
73+
float value = 0.0f;
74+
switch (data_type) {
75+
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
76+
value = *add_init.data<float>();
77+
break;
78+
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
79+
value = math::halfToFloat(add_init.data<MLFloat16>()->val);
80+
break;
81+
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
82+
value = static_cast<float>(*add_init.data<double>());
83+
break;
84+
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
85+
value = static_cast<float>(*add_init.data<int32_t>());
86+
break;
87+
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
88+
value = static_cast<float>(*add_init.data<int64_t>());
89+
break;
90+
default:
91+
return false;
92+
}
93+
94+
if (value != 0.0f && (op_type == "Add" || op_type == "Sub")) {
7995
return false;
80-
}
96+
}
8197

82-
if ((op_type == "Add" || op_type == "Sub") && value != 0.0f) {
83-
return false;
84-
}
85-
86-
if ((op_type == "Mul" || op_type == "Div") && value != 1.0f) {
87-
return false;
98+
if (value != 1.0f && (op_type == "Mul" || op_type == "Div")) {
99+
return false;
100+
}
88101
}
89102

90103
// reject node output is graph output for now

onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ class ApiGraph final : public api::GraphRef {
115115
const auto& graph_outputs = graph_.GetOutputs();
116116
graph_outputs_.reserve(graph_outputs.size());
117117
for (const auto* output : graph_outputs) {
118-
graph_outputs_.insert(output->Name());
118+
graph_outputs_.emplace(output->Name());
119119
}
120120
}
121121

0 commit comments

Comments
 (0)