@@ -81,8 +81,13 @@ void getSegmentsOutputByRunning(
       jit_inputs_ivalues.push_back(ivalues_maps[input].toList());
     } else if (input->type()->kind() == torch::jit::TypeKind::TupleType) {
       jit_inputs_ivalues.push_back(ivalues_maps[input].toTuple());
+    } else if (input->type()->kind() == torch::jit::TypeKind::NumberType) {
+      jit_inputs_ivalues.push_back(ivalues_maps[input].toScalar());
     } else {
-      TORCHTRT_THROW_ERROR("Unable to find type for value: " << input->debugName() << " to get the ivalues.\n");
+      TORCHTRT_THROW_ERROR(
+          "Unable to find type for value: " << input->debugName()
+                                            << " to get the ivalues. The type for this value should be "
+                                            << input->type()->str() << "\n");
     }
   }

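The hunk above adds a NumberType branch so scalar graph inputs are collected as at::Scalar IValues instead of falling through to the error, and the fallthrough message now reports the value's actual type. Below is a minimal sketch, assuming only plain ATen (the variable names are illustrative, not from the patch), of the IValue-to-Scalar conversion this branch relies on:

// Minimal sketch, assuming plain ATen: a NumberType value lives in an IValue
// as a Scalar payload, so toScalar() is the matching accessor; calling
// toTensor() on the same IValue would throw.
#include <vector>
#include <ATen/core/ivalue.h>

int main() {
  std::vector<c10::IValue> jit_inputs_ivalues;
  c10::IValue number_input(3.0); // stand-in for ivalues_maps[input] of NumberType
  jit_inputs_ivalues.push_back(number_input.toScalar()); // at::Scalar converts back to an IValue
  return 0;
}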
@@ -110,28 +115,31 @@ void getSegmentsOutputByRunning(
   for (auto& i : seg_block.raw_inputs()) {
     if (ivalues_maps[i].isTensor()) {
       // set the input_shape and data_type
-      at::ScalarType t = ivalues_maps[i].toTensor().scalar_type();
+      // we can use a temp value here instead of replacing the values in ivalues_map since we only use ivalues_map for
+      // shape inference
+      auto cur_ivalue = ivalues_maps[i];
+      at::ScalarType t = cur_ivalue.toTensor().scalar_type();
       if (!partition_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble)) {
         TORCHTRT_THROW_ERROR(
             "Unable to process subgraph input type of at::kLong/at::kDouble, try to compile model with truncate_long_and_double enabled");
       } else if (partition_info.truncate_long_and_double && t == at::kLong) {
-        ivalues_maps[i] = ivalues_maps[i].toTensor().to(at::kInt);
+        cur_ivalue = cur_ivalue.toTensor().to(at::kInt);
         LOG_WARNING("Truncating graph input type from at::kLong to at::kInt");
       } else if (partition_info.truncate_long_and_double && t == at::kDouble) {
-        ivalues_maps[i] = ivalues_maps[i].toTensor().to(at::kFloat);
+        cur_ivalue = cur_ivalue.toTensor().to(at::kFloat);
         LOG_WARNING("Truncating graph input type from at::kDouble to at::kFloat");
       }
-      c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType(ivalues_maps[i].toTensor().dtype());
+      c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType(cur_ivalue.toTensor().dtype());
       if (dtype == c10::nullopt) {
-        TORCHTRT_THROW_ERROR("Unsupported input data type " << ivalues_maps[i].toTensor().dtype());
+        TORCHTRT_THROW_ERROR("Unsupported input data type " << cur_ivalue.toTensor().dtype());
       }
-      if (ivalues_maps[i].toTensor().sizes().size() == 0) {
+      if (cur_ivalue.toTensor().sizes().size() == 0) {
         // handle Scalar types, which has sizes of []
         input_shapes.push_back(util::toVec(util::toDims(c10::List<long int>({1}))));
       } else {
-        input_shapes.push_back(util::toVec(util::toDims(ivalues_maps[i].toTensor().sizes())));
+        input_shapes.push_back(util::toVec(util::toDims(cur_ivalue.toTensor().sizes())));
       }
-      input_types.push_back(ivalues_maps[i].toTensor().scalar_type());
+      input_types.push_back(cur_ivalue.toTensor().scalar_type());
     }
   }

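This hunk stops overwriting the entry in ivalues_maps when truncating kLong/kDouble inputs: the truncated tensor is only needed for the shape/type inference of this segment, so it is kept in a local copy. A minimal sketch, assuming plain ATen types (the map and key below are illustrative, not Torch-TensorRT internals), of why the local copy is sufficient:

// Minimal sketch, assuming plain ATen: reassigning the copied IValue does not
// touch the value stored in the map, which is the behavior the cur_ivalue
// change relies on.
#include <iostream>
#include <unordered_map>
#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>

int main() {
  std::unordered_map<int, c10::IValue> ivalues; // stand-in for ivalues_maps
  ivalues[0] = at::ones({2, 2}, at::kLong);

  auto cur_ivalue = ivalues[0];                    // copy of the stored IValue
  cur_ivalue = cur_ivalue.toTensor().to(at::kInt); // truncate only the local copy

  std::cout << cur_ivalue.toTensor().scalar_type() << " vs "
            << ivalues[0].toTensor().scalar_type() << std::endl; // Int vs Long
  return 0;
}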