@@ -25,7 +25,8 @@ std::unordered_map<torch::jit::Value*, torch::jit::IValue> generateRandomInputs(
25
25
26
26
void getSegmentsOutputByRunning (
27
27
SegmentedBlock& seg_block,
28
- std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps) {
28
+ std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
29
+ const PartitionInfo& partition_info) {
29
30
// create a module to run the graph
30
31
auto g = seg_block.g ();
31
32
auto copy_g = g->copy ();
@@ -99,10 +100,20 @@ void getSegmentsOutputByRunning(
99
100
for (auto & i : seg_block.raw_inputs ()) {
100
101
if (ivalues_maps[i].isTensor ()) {
101
102
// set the input_shape and data_type
103
+ at::ScalarType t = c10::optTypeMetaToScalarType (ivalues_maps[i].toTensor ().dtype ()).value ();
104
+ if (!partition_info.truncate_long_and_double &&
105
+ (t == at::kLong || t == at::kDouble )) {
106
+ TRTORCH_THROW_ERROR (
107
+ " Unable to process subgraph input type of at::kLong/at::kDouble, try to compile model with truncate_long_and_double enabled" );
108
+ } else if (partition_info.truncate_long_and_double && t == at::kLong ) {
109
+ ivalues_maps[i] = ivalues_maps[i].toTensor ().to (at::kInt );
110
+ } else if (partition_info.truncate_long_and_double && t == at::kDouble ) {
111
+ ivalues_maps[i] = ivalues_maps[i].toTensor ().to (at::kFloat );
112
+ }
102
113
c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType (ivalues_maps[i].toTensor ().dtype ());
103
114
nvinfer1::DataType nv_dtype;
104
115
if (dtype == c10::nullopt) {
105
- nv_dtype = nvinfer1::DataType:: kFLOAT ;
116
+ TRTORCH_THROW_ERROR ( " Unsupported input data type " << ivalues_maps[i]. toTensor (). dtype ()) ;
106
117
} else {
107
118
nv_dtype = dtype.value ();
108
119
}
@@ -116,11 +127,12 @@ void getSegmentsOutputByRunning(
116
127
117
128
void runShapeAnalysis (
118
129
std::vector<SegmentedBlock>& segmented_blocks,
119
- std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps) {
130
+ std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
131
+ const PartitionInfo& partition_info) {
120
132
// register every segment's input shape, and it's running output IValues
121
133
for (auto & seg_block : segmented_blocks) {
122
134
torch::jit::ConstantPooling (seg_block.g ());
123
- getSegmentsOutputByRunning (seg_block, ivalues_maps);
135
+ getSegmentsOutputByRunning (seg_block, ivalues_maps, partition_info );
124
136
}
125
137
return ;
126
138
}
0 commit comments