@@ -9,19 +9,46 @@ namespace trtorch {
namespace core {
namespace partitioning {

- inline bool isTensorOrTensorList(torch::jit::Value* val) {
-   return val->type()->isSubtypeOf(torch::jit::TensorType::get()) ||
-       val->type()->isSubtypeOf(torch::jit::ListType::ofTensors());
- }
-
struct usage_info {
  int produce_id = -1;
  std::vector<int> torch_use_id;
  std::vector<int> tensorrt_use_id;
};

+ inline bool isTensorOrTensorList(torch::jit::Value* val) {
+   return val->type()->isSubtypeOf(torch::jit::TensorType::get()) ||
+       val->type()->isSubtypeOf(torch::jit::ListType::ofTensors());
+ }
+
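+ // check if all nodes in the list are supported for TensorRT conversion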
+ bool isAllNodesSupported(const std::vector<torch::jit::Node*>& nodes) {
+   for (auto node : nodes) {
+     if (!conversion::OpSupported(node)) {
+       return false;
+     }
+   }
+   return true;
+ }
+
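+ // check if node n consumes a nonTensor value that belongs to the given target set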
+ bool containNonTensorInputs(torch::jit::Node* n, const std::unordered_set<torch::jit::Value*>& target_inputs) {
+   for (auto input : n->inputs()) {
+     if (!isTensorOrTensorList(input) && target_inputs.count(input)) {
+       return true;
+     }
+   }
+   return false;
+ }
+
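+ // check if node n produces at least one nonTensor output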
+ bool containNonTensorOutputs(torch::jit::Node* n) {
+   for (auto output : n->outputs()) {
+     if (!isTensorOrTensorList(output)) {
+       return true;
+     }
+   }
+   return false;
+ }
+
std::vector<torch::jit::Node*> getDependencyNodes(std::vector<torch::jit::Value*>& vals) {
-   // using bfs to get the DAG dependency nodes for input value
+   // use BFS to get the DAG dependency nodes of the given input values
  std::queue<torch::jit::Value*, std::deque<torch::jit::Value*>> q(
      std::deque<torch::jit::Value*>(vals.begin(), vals.end()));
  std::unordered_set<torch::jit::Node*> visited;
@@ -43,17 +70,50 @@ std::vector<torch::jit::Node*> getDependencyNodes(std::vector<torch::jit::Value*
  return stk;
}

- SegmentedBlock injectNodesForNonTensorInputs(SegmentedBlock& seg_block) {
+ std::vector<SegmentedBlock> injectNodesForNonTensorInputs(SegmentedBlock& seg_block) {
  // reconstruct the segmented_block if this block requires nonTensor inputs
  std::vector<torch::jit::Value*> nontensor_inputs;
  for (auto input : seg_block.raw_inputs()) {
    if (!isTensorOrTensorList(input)) {
      nontensor_inputs.push_back(input);
    }
  }
-   std::vector<torch::jit::Node*> new_block_nodes = getDependencyNodes(nontensor_inputs);
-   new_block_nodes.insert(new_block_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
-   return std::move(SegmentedBlock(seg_block.target(), new_block_nodes));
+   std::vector<torch::jit::Node*> dependency_nodes = getDependencyNodes(nontensor_inputs);
+
+   std::vector<SegmentedBlock> new_seg_blocks;
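+   // (illustrative) a nonTensor input here is typically an int or int[] value, e.g. one produced by an
+   // aten::size call feeding a prim::ListConstruct; its producer nodes were collected just above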
+   // if the current block is kTorch, or it is kTensorRT and all of its dependency nodes are also
+   // supported, construct only one new block
+   if (seg_block.target() == SegmentedBlock::kTorch || isAllNodesSupported(dependency_nodes)) {
+     dependency_nodes.insert(dependency_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
+     new_seg_blocks.emplace_back(seg_block.target(), dependency_nodes);
+   } else {
+     // if the current block is kTensorRT but the dependency nodes contain an unsupported node, then we
+     // have to segment again
+     std::unordered_set<torch::jit::Value*> nontensor_inputs_set(nontensor_inputs.begin(), nontensor_inputs.end());
+     new_seg_blocks.emplace_back(SegmentedBlock::kTorch, dependency_nodes);
+     std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;
+     bool prev_non_tensor_outputs = false;
+     for (auto n : seg_block.raw_nodes()) {
+       // a node goes to a kTorch block if it uses a nonTensor input produced in a kTorch block, or if
+       // the previous node produced nonTensor outputs
+       if (containNonTensorInputs(n, nontensor_inputs_set) || prev_non_tensor_outputs) {
+         if (!tensorrt_nodes.empty()) {
+           new_seg_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
+           tensorrt_nodes.clear();
+         }
+         pytorch_nodes.push_back(n);
+         prev_non_tensor_outputs = containNonTensorOutputs(n);
+       } else {
+         if (!pytorch_nodes.empty()) {
+           new_seg_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
+           pytorch_nodes.clear();
+         }
+         tensorrt_nodes.push_back(n);
+       }
+     }
+     if (!tensorrt_nodes.empty()) {
+       new_seg_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
+     } else {
+       new_seg_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
+     }
+   }
+   return std::move(new_seg_blocks);
}

void resolveNonTensorInputs(PartitionedGraph& segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
@@ -80,16 +140,17 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks, std::shared_ptr<
    if (segmented_blocks[use_info.produce_id].target() == SegmentedBlock::kTensorRT && !use_info.torch_use_id.empty()) {
      int first_torch_id = use_info.torch_use_id.front();
      if (!updated_segments.count(first_torch_id)) {
-         auto new_torch_block = injectNodesForNonTensorInputs(segmented_blocks[first_torch_id]);
+         auto new_torch_block = injectNodesForNonTensorInputs(segmented_blocks[first_torch_id]).front();
        segmented_blocks[first_torch_id] = new_torch_block;
        updated_segments.insert(first_torch_id);
      }
    } else {
      // kTensorRT segments always need to inject nodes for the nonTensor inputs
      for (int i : use_info.tensorrt_use_id) {
        if (!updated_segments.count(i)) {
-           auto new_seg_block = injectNodesForNonTensorInputs(segmented_blocks[i]);
-           segmented_blocks[i] = new_seg_block;
+           auto to_inject_blocks = injectNodesForNonTensorInputs(segmented_blocks[i]);
+           segmented_blocks.erase(segmented_blocks.begin() + i);
+           segmented_blocks.insert(segmented_blocks.begin() + i, to_inject_blocks.begin(), to_inject_blocks.end());
          updated_segments.insert(i);
        }
      }