
Commit 52abece (parent: f9059e0)

chore: Address review comments and remove commented code

Signed-off-by: Dheeraj Peri <[email protected]>

3 files changed (+24, -25 lines)


core/compiler.cpp (-15 lines)

@@ -372,15 +372,6 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
   // Infer the type of an input from the weights of the calculation
   auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
 
-  // // GPU default WS size : 1 GB
-  // // Set WS = 256 Mb for Jetson nano/TX1 like platforms whose compute capability is 5.X.
-  // auto workspace_size = cfg.convert_info.engine_settings.workspace_size;
-  // auto device_spec = cfg.convert_info.engine_settings.device;
-  // auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
-  // if (workspace_size == 0) {
-  //   cfg.convert_info.engine_settings.workspace_size = GetRecommendedWorkspaceSize(cuda_device);
-  // }
-
   MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
 
   auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);

@@ -391,14 +382,8 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
 torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) {
   torch::jit::Module new_mod(mod._ivalue()->name() + "_trt");
 
-  // // GPU default WS size : 1 GB
-  // // Set WS = 256 Mb for Jetson nano/TX1 like platforms whose compute capability is 5.X.
-  // auto workspace_size = cfg.convert_info.engine_settings.workspace_size;
   auto device_spec = cfg.convert_info.engine_settings.device;
   auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
-  // if (workspace_size == 0) {
-  //   cfg.convert_info.engine_settings.workspace_size = GetRecommendedWorkspaceSize(cuda_device);
-  // }
 
   for (const torch::jit::Method& method : mod.get_methods()) {
     if (method.name().compare("forward") == 0) {
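Note: the deleted block was an already commented-out workspace-size heuristic (1 GB default, 256 MB for compute-capability-5.x boards such as Jetson Nano/TX1) calling a GetRecommendedWorkspaceSize helper. For context, a minimal self-contained sketch of that heuristic follows; the struct fields and helper signature are assumptions reconstructed from the deleted comments, not the actual Torch-TensorRT API:

// Hypothetical reconstruction of the removed heuristic; field names and the
// helper signature are assumed from the deleted comments, not real API.
#include <cstdint>

struct CudaDevice {
  int gpu_id;
  int major;  // assumed: compute capability major version (5 on Nano/TX1)
};

uint64_t GetRecommendedWorkspaceSize(const CudaDevice& device) {
  if (device.major == 5) {
    return 256ULL << 20;  // 256 MB for Jetson Nano/TX1-like platforms
  }
  return 1ULL << 30;  // 1 GB GPU default
}

// The deleted call site only overrode the setting when the user left
// workspace_size at its 0 default:
//   if (workspace_size == 0) {
//     workspace_size = GetRecommendedWorkspaceSize(cuda_device);
//   }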

core/partitioning/partitioning.cpp (+10, -10 lines)

@@ -115,7 +115,7 @@ void find_all_fallback_nodes(
     // for every node that produces this fallback node's NonTensor input, they should fallback too
     for (auto input : cur_node->inputs()) {
       if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant &&
-          global_fallback_nodes.insert({input->node(), 4}).second) {
+          global_fallback_nodes.insert({input->node(), FallbackNodeType::kNON_TENSOR}).second) {
         q.push(input->node());
       }
     }

@@ -124,7 +124,7 @@ void find_all_fallback_nodes(
     if (!isTensor(output)) {
       for (auto use : output->uses()) {
         auto node = use.user;
-        if (node->kind() != torch::jit::prim::Constant && global_fallback_nodes.insert({node, 4}).second) {
+        if (node->kind() != torch::jit::prim::Constant && global_fallback_nodes.insert({node, FallbackNodeType::kNON_TENSOR}).second) {
           q.push(node);
         }
       }

@@ -229,13 +229,13 @@ bool checkLoopEvaluatable(torch::jit::Node* n) {
 
 bool check_node_fallback(torch::jit::Node* n, const std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
   if (fallback_nodes.count(n)) {
-    if (fallback_nodes.at(n) == 0) {
+    if (fallback_nodes.at(n) == FallbackNodeType::kUNSUPPORTED) {
       LOG_GRAPH("Node not supported by conversion: " << util::node_info(n));
-    } else if (fallback_nodes.at(n) == 1) {
+    } else if (fallback_nodes.at(n) == FallbackNodeType::kOPERATOR_FALLBACK) {
       LOG_GRAPH("Node explicitly set to run in torch: " << util::node_info(n));
-    } else if (fallback_nodes.at(n) == 2) {
+    } else if (fallback_nodes.at(n) == FallbackNodeType::kMODULE_FALLBACK) {
       LOG_GRAPH("Node is within a module set to run in torch: " << util::node_info(n));
-    } else if (fallback_nodes.at(n) == 3) {
+    } else if (fallback_nodes.at(n) == FallbackNodeType::kMIN_BLOCK_FALLBACK) {
       LOG_GRAPH("Node fallback to Torch because of min_block_size" << util::node_info(n));
     } else {
       LOG_GRAPH(

@@ -273,18 +273,18 @@ void get_fallback_nodes(
 
     // If the op is not supported by the conversion phase it should run in PyTorch
     if (!conversion::OpSupported(n)) {
-      fallback_nodes.insert({n, 0});
+      fallback_nodes.insert({n, FallbackNodeType::kUNSUPPORTED});
     }
 
     // If the user specifies the op to run in Torch it should run in PyTorch
    if (forced_fallback_ops.find(n->kind().toQualString()) != forced_fallback_ops.end()) {
-      fallback_nodes.insert({n, 1});
+      fallback_nodes.insert({n, FallbackNodeType::kOPERATOR_FALLBACK});
     }
 
     // If the user specifies the module containing this op to run in torch it should run in PyTorch
     const auto to_compile_sym = c10::Symbol::attr("to_compile");
     if (n->hasAttribute(to_compile_sym) && n->i(to_compile_sym) == (int64_t) false) {
-      fallback_nodes.insert({n, 2});
+      fallback_nodes.insert({n, FallbackNodeType::kMODULE_FALLBACK});
     }
   }
   return;

@@ -329,7 +329,7 @@ void find_min_block_size_fallback_nodes(
   // keep fallback until all segments meet the min_block_size requirement
   while (!min_block_fallback_nodes.empty()) {
     for (const auto i : min_block_fallback_nodes) {
-      initial_fallback_nodes.insert({i, 3});
+      initial_fallback_nodes.insert({i, FallbackNodeType::kMIN_BLOCK_FALLBACK});
     }
     global_fallback_nodes.insert(initial_fallback_nodes.begin(), initial_fallback_nodes.end());
     // find the fallback nodes because of dependency with min_block_size caused fallback nodes
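Note on the pattern used throughout these hunks: std::unordered_map::insert returns a std::pair<iterator, bool> whose .second member is true only when the key was newly inserted, so the traversal in find_all_fallback_nodes marks and enqueues each node at most once. A self-contained sketch of the idiom follows; the integer node ids and the inputs graph are made up for illustration:

// Standalone illustration of the insert-and-test idiom: insert(...).second
// is true only for a new key, so each node enters the queue at most once.
#include <iostream>
#include <queue>
#include <unordered_map>
#include <vector>

enum FallbackNodeType { kUNSUPPORTED, kOPERATOR_FALLBACK, kMODULE_FALLBACK, kMIN_BLOCK_FALLBACK, kNON_TENSOR };

void propagate_fallback(int start, const std::vector<std::vector<int>>& inputs, std::unordered_map<int, int>& fallback_nodes) {
  std::queue<int> q;
  q.push(start);
  while (!q.empty()) {
    int n = q.front();
    q.pop();
    for (int producer : inputs[n]) {
      // Only newly marked producers are pushed, preventing revisits.
      if (fallback_nodes.insert({producer, kNON_TENSOR}).second) {
        q.push(producer);
      }
    }
  }
}

int main() {
  // Node 2 feeds node 1, which feeds node 0; marking 0 propagates upstream.
  std::vector<std::vector<int>> inputs = {{1}, {2}, {}};
  std::unordered_map<int, int> fallback_nodes = {{0, kNON_TENSOR}};
  propagate_fallback(0, inputs, fallback_nodes);
  std::cout << fallback_nodes.size() << " nodes fall back\n";  // prints "3 nodes fall back"
}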

core/partitioning/partitioning.h (+14 lines)

@@ -16,6 +16,20 @@ namespace partitioning {
 
 typedef std::vector<SegmentedBlock> PartitionedGraph;
 
+enum FallbackNodeType {
+  /// Node is not supported by TensorRT
+  kUNSUPPORTED,
+  /// Node is explicitly forced to fallback to Pytorch due to operator fallback
+  kOPERATOR_FALLBACK,
+  /// Node is explicitly forced to fallback to Pytorch due to module fallback
+  kMODULE_FALLBACK,
+  /// This node is in a TRT segment which does not satisfy min_block_size
+  /// and hence is forced to fallback.
+  kMIN_BLOCK_FALLBACK,
+  /// This node produces/consumes non-tensor inputs
+  kNON_TENSOR,
+};
+
 PartitionedGraph segment_graph(
     torch::jit::Block* block,
     const PartitionInfo& partition_info,
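Note: FallbackNodeType is declared as a plain (unscoped) enum rather than an enum class, so its enumerators convert implicitly to int. That is what lets the partitioning.cpp changes store and compare the named constants against the existing std::unordered_map<torch::jit::Node*, int> without changing the map's value type. A minimal demonstration, with an int key standing in for torch::jit::Node*:

// Unscoped enumerators convert implicitly to int, so named constants slot
// into the existing int-valued map; an enum class would need explicit casts.
#include <unordered_map>

enum FallbackNodeType { kUNSUPPORTED, kOPERATOR_FALLBACK, kMODULE_FALLBACK, kMIN_BLOCK_FALLBACK, kNON_TENSOR };

int main() {
  std::unordered_map<int, int> fallback_nodes;  // int key stands in for Node*
  fallback_nodes.insert({7, FallbackNodeType::kNON_TENSOR});  // enum -> int
  return fallback_nodes.at(7) == FallbackNodeType::kNON_TENSOR ? 0 : 1;  // exits 0
}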
