Skip to content

Commit b3589c5

Browse files
committed
feat(//core/partitioing): Adding ostream for Partition Info
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent fb1a299 commit b3589c5

File tree

4 files changed

+14
-1
lines changed

4 files changed

+14
-1
lines changed

Diff for: core/partitioning/PartitionInfo.cpp

Whitespace-only changes.

Diff for: core/partitioning/PartitionInfo.h

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ struct PartitionInfo {
1414
std::vector<std::string> forced_fallback_operators;
1515
};
1616

17+
std::ostream& operator<<(std::ostream& os, const PartitionInfo& s);
18+
1719
} // namespace partitioning
1820
} // namespace core
1921
} // namespace trtorch

Diff for: core/partitioning/partitioning.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,8 @@ std::vector<SegmentedBlock> Partition(
204204
std::shared_ptr<torch::jit::Graph> g,
205205
std::vector<ir::InputRange>& input_ranges,
206206
const PartitionInfo& partition_info) {
207+
208+
LOG_DEBUG(partition_info);
207209
// segment lowering global graph into blocks
208210
std::vector<SegmentedBlock> segmented_blocks = segment_graph(g, partition_info);
209211

Diff for: py/trtorch/csrc/tensorrt_classes.cpp

+10-1
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,10 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
108108
info.convert_info.engine_settings.device.gpu_id = device.gpu_id;
109109
info.convert_info.engine_settings.device.dla_core = device.dla_core;
110110
info.convert_info.engine_settings.device.allow_gpu_fallback = device.allow_gpu_fallback;
111-
info.convert_info.engine_settings.torch_fallback.enabled = torch_fallback.enabled;
112111
info.partition_info.enabled = torch_fallback.enabled;
113112
info.partition_info.min_block_size = torch_fallback.min_block_size;
114113
info.partition_info.forced_fallback_operators = torch_fallback.forced_fallback_operators;
114+
info.convert_info.engine_settings.truncate_long_and_double = truncate_long_and_double;
115115

116116
info.convert_info.engine_settings.capability = toTRTEngineCapability(capability);
117117
TRTORCH_CHECK(num_min_timing_iters >= 0, "num_min_timing_iters must be 0 or greater");
@@ -148,6 +148,15 @@ std::string CompileSpec::stringify() {
148148
ss << " \"Workspace Size\": " << workspace_size << std::endl;
149149
ss << " \"Max Batch Size\": " << max_batch_size << std::endl;
150150
ss << " \"Truncate long and double\": " << truncate_long_and_double << std::endl;
151+
ss << " \"Torch Fallback: {" << std::endl;
152+
ss << " \"enabled\": " << torch_fallback.enabled ? "True" : "False" << std::endl;
153+
ss << " \"min_block_size\": " << torch_fallback.min_block_size << std::endl;
154+
ss << " \"forced_fallback_operators\": [" << std::endl;
155+
for (auto i : torch_fallback.forced_fallback_operators) {
156+
ss << " " << i << ',' << std::endl;
157+
}
158+
ss << " ]" << std::endl;
159+
ss << " }" << std::endl;
151160
ss << "}";
152161
return ss.str();
153162
}

0 commit comments

Comments
 (0)