Skip to content

Commit 0bc3c05

Browse files
committed
feat: support truncate_long_and_double in fallback subgraph input type
Signed-off-by: inocsin <[email protected]>
1 parent 4778b2b commit 0bc3c05

File tree

6 files changed

+22
-8
lines changed

6 files changed

+22
-8
lines changed

Diff for: core/partitioning/PartitionInfo.h

+1
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@ struct PartitionInfo {
1212
bool enabled = false;
1313
uint64_t min_block_size = 1;
1414
std::vector<std::string> forced_fallback_operators;
15+
bool truncate_long_and_double;
1516
};
1617

1718
std::ostream& operator<<(std::ostream& os, const PartitionInfo& s);

Diff for: core/partitioning/partitioning.cpp

+1-1
Original file line number | Diff line number | Diff line change
@@ -358,7 +358,7 @@ std::vector<SegmentedBlock> Partition(
358358
registerSegmentsOutputs(segmented_blocks, block);
359359

360360
// run shape analysis on each segmented block
361-
runShapeAnalysis(segmented_blocks, input_ivalues_map);
361+
runShapeAnalysis(segmented_blocks, input_ivalues_map, partition_info);
362362

363363
return segmented_blocks;
364364
}

Diff for: core/partitioning/shape_analysis.cpp

+16-4
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,8 @@ std::unordered_map<torch::jit::Value*, torch::jit::IValue> generateRandomInputs(
2525

2626
void getSegmentsOutputByRunning(
2727
SegmentedBlock& seg_block,
28-
std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps) {
28+
std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
29+
const PartitionInfo& partition_info) {
2930
// create a module to run the graph
3031
auto g = seg_block.g();
3132
auto copy_g = g->copy();
@@ -99,10 +100,20 @@ void getSegmentsOutputByRunning(
99100
for (auto& i : seg_block.raw_inputs()) {
100101
if (ivalues_maps[i].isTensor()) {
101102
// set the input_shape and data_type
103+
at::ScalarType t = c10::optTypeMetaToScalarType(ivalues_maps[i].toTensor().dtype()).value();
104+
if (!partition_info.truncate_long_and_double &&
105+
(t == at::kLong || t == at::kDouble)) {
106+
TRTORCH_THROW_ERROR(
107+
"Unable to process subgraph input type of at::kLong/at::kDouble, try to compile model with truncate_long_and_double enabled");
108+
} else if(partition_info.truncate_long_and_double && t == at::kLong) {
109+
ivalues_maps[i] = ivalues_maps[i].toTensor().to(at::kInt);
110+
} else if(partition_info.truncate_long_and_double && t == at::kDouble) {
111+
ivalues_maps[i] = ivalues_maps[i].toTensor().to(at::kFloat);
112+
}
102113
c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType(ivalues_maps[i].toTensor().dtype());
103114
nvinfer1::DataType nv_dtype;
104115
if (dtype == c10::nullopt) {
105-
nv_dtype = nvinfer1::DataType::kFLOAT;
116+
TRTORCH_THROW_ERROR("Unsupported input data type " << ivalues_maps[i].toTensor().dtype());
106117
} else {
107118
nv_dtype = dtype.value();
108119
}
@@ -116,11 +127,12 @@ void getSegmentsOutputByRunning(
116127

117128
void runShapeAnalysis(
118129
std::vector<SegmentedBlock>& segmented_blocks,
119-
std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps) {
130+
std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
131+
const PartitionInfo& partition_info) {
120132
// register every segment's input shape, and its running output IValues
121133
for (auto& seg_block : segmented_blocks) {
122134
torch::jit::ConstantPooling(seg_block.g());
123-
getSegmentsOutputByRunning(seg_block, ivalues_maps);
135+
getSegmentsOutputByRunning(seg_block, ivalues_maps, partition_info);
124136
}
125137
return;
126138
}

Diff for: core/partitioning/shape_analysis.h

+2-1
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,8 @@ std::unordered_map<torch::jit::Value*, torch::jit::IValue> generateRandomInputs(
1111

1212
void runShapeAnalysis(
1313
std::vector<SegmentedBlock>& segmented_blocks,
14-
std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps);
14+
std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
15+
const PartitionInfo& partition_info);
1516

1617
} // namespace partitioning
1718
} // namespace core

Diff for: core/util/trt_util.cpp

+1-2
Original file line number | Diff line number | Diff line change
@@ -239,8 +239,7 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_ma
239239
{at::kHalf, nvinfer1::DataType::kHALF},
240240
{at::kInt, nvinfer1::DataType::kINT32},
241241
{at::kChar, nvinfer1::DataType::kINT8},
242-
{at::kBool, nvinfer1::DataType::kBOOL},
243-
{at::kLong, nvinfer1::DataType::kINT32},
242+
{at::kBool, nvinfer1::DataType::kBOOL}
244243
};
245244
return at_trt_type_map;
246245
}

Diff for: cpp/src/compile_spec.cpp

+1
Original file line number | Diff line number | Diff line change
@@ -344,6 +344,7 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
344344
internal.partition_info.enabled = external.torch_fallback.enabled;
345345
internal.partition_info.min_block_size = external.torch_fallback.min_block_size;
346346
internal.partition_info.forced_fallback_operators = external.torch_fallback.forced_fallback_ops;
347+
internal.partition_info.truncate_long_and_double = external.truncate_long_and_double;
347348
internal.lower_info.forced_fallback_modules = external.torch_fallback.forced_fallback_modules;
348349

349350
switch (external.device.device_type) {

0 commit comments

Comments
 (0)