@@ -99,13 +99,18 @@ torch::jit::Node* getUpstreamCastNode(torch::jit::Value* val) {
   return nullptr;
 }
 
-torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool is_input, std::string device) {
+torch::jit::Node* createCastNode(
+    SegmentedBlock& seg_block,
+    size_t index,
+    bool is_input,
+    std::string device,
+    bool force_create_node = false) {
   auto cast_raw_value = is_input ? seg_block.raw_inputs()[index] : seg_block.raw_outputs()[index];
   auto cast_subgraph_value = is_input ? seg_block.inputs()[index] : seg_block.outputs()[index];
   torch::jit::Node* cast_node = getUpstreamCastNode(cast_raw_value);
   auto g = seg_block.g();
   // if we can find an upstream aten::to node, we use its parameters for creating the new cast node
-  if (cast_node) {
+  if (cast_node && !force_create_node) {
     std::unordered_map<torch::jit::Value*, torch::jit::Value*> value_map;
     value_map.insert({cast_node->inputs()[0], cast_subgraph_value});
     if (!is_input) {
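Aside (not part of the diff): the new force_create_node parameter only controls whether an existing upstream aten::to node may be reused or a brand-new cast node must be built. A minimal standalone sketch of that decision, using hypothetical names purely for illustration:

#include <iostream>

enum class CastSource { ReusedUpstream, FreshlyCreated };

// Mirrors the condition `cast_node && !force_create_node` above: reuse an upstream
// aten::to only when one exists and the caller did not force a fresh cast node.
CastSource chooseCastSource(bool has_upstream_cast, bool force_create_node) {
  if (has_upstream_cast && !force_create_node) {
    return CastSource::ReusedUpstream;
  }
  return CastSource::FreshlyCreated;
}

int main() {
  std::cout << (chooseCastSource(true, false) == CastSource::ReusedUpstream) << "\n"; // 1: reuse upstream cast
  std::cout << (chooseCastSource(true, true) == CastSource::ReusedUpstream) << "\n";  // 0: forced to create a new node
  return 0;
}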
@@ -222,29 +227,39 @@ void getSegmentsOutputByRunning(
 
   auto target_device = partitioning_info.getGPUDeviceString();
 
-  // auto int64 <=> int32 conversion
-  if (seg_block.target() == SegmentedBlock::kTorch && partitioning_info.truncate_long_and_double) {
+  // auto int64 <=> int32 conversion + int8 <=> int32 conversion for non-quantized models
+  if (seg_block.target() == SegmentedBlock::kTorch) {
     // First, check if there is Int64 input
-    for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
-      if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
-        auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
-        at::ScalarType t = cur_ivalue.toTensor().scalar_type();
-        if (t == at::kLong) {
-          // we add a cast operation to cast the type to Int64
-          auto cast_node = createCastNode(seg_block, i, true, target_device);
-          seg_block.g()->prependNode(cast_node);
-          seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
+    if (partitioning_info.truncate_long_and_double) {
+      for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
+        if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
+          auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
+          at::ScalarType t = cur_ivalue.toTensor().scalar_type();
+          if (t == at::kLong) {
+            // we add a cast operation to cast the type to Int64
+            auto cast_node = createCastNode(seg_block, i, true, target_device);
+            seg_block.g()->prependNode(cast_node);
+            seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
+          }
         }
       }
     }
+
     for (size_t i = 0; i < seg_block.outputs().size(); ++i) {
       if (ivalues_maps[seg_block.raw_outputs()[i]].isTensor()) {
         auto cur_ivalue = ivalues_maps[seg_block.raw_outputs()[i]];
         at::ScalarType t = cur_ivalue.toTensor().scalar_type();
-        if (t == at::kLong) {
+
+        // If the output has type Long and truncation was requested, insert a truncating cast
+        if (t == at::kLong && partitioning_info.truncate_long_and_double) {
           auto cast_node = createCastNode(seg_block, i, false, target_device);
           seg_block.g()->appendNode(cast_node);
           seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
+        } else if (t == at::kByte && partitioning_info.cast_int8_inputs) {
+          // If the output has type Byte and int8 casting was requested, insert an Integer cast
+          auto cast_node = createCastNode(seg_block, i, false, target_device, /*force_create_node=*/true);
+          seg_block.g()->appendNode(cast_node);
+          seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
         }
       }
     }
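Aside (not part of the diff): at the tensor level, the boundary casts inserted above amount to the conversions below. This is a sketch assuming a linked libtorch; it is not code from the PR.

#include <torch/torch.h>
#include <iostream>

int main() {
  // Long output of a Torch segment: truncated to Int32 when truncate_long_and_double is set.
  auto long_out = torch::ones({2, 2}, torch::dtype(torch::kLong));
  auto truncated = long_out.to(at::kInt);

  // Byte (UInt8) output of a Torch segment: widened to Int32 when cast_int8_inputs is set,
  // so a non-quantized TensorRT engine never receives Int8 inputs.
  auto byte_out = torch::ones({2, 2}, torch::dtype(torch::kByte));
  auto widened = byte_out.to(at::kInt);

  std::cout << truncated.scalar_type() << " " << widened.scalar_type() << std::endl; // Int Int
  return 0;
}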
@@ -254,11 +269,13 @@ void getSegmentsOutputByRunning(
   std::vector<std::vector<int64_t>> input_shapes;
   std::vector<at::ScalarType> input_types;
   for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
-    if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
+    auto current_input = seg_block.raw_inputs()[i];
+
+    if (ivalues_maps[current_input].isTensor()) {
       // set the input_shape and data_type
       // we can use a temp value here instead of replacing the values in ivalues_map since we only use ivalues_map for
       // shape inference
-      auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
+      auto cur_ivalue = ivalues_maps[current_input];
       at::ScalarType t = cur_ivalue.toTensor().scalar_type();
 
       if (!partitioning_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble)) {
@@ -271,10 +288,16 @@ void getSegmentsOutputByRunning(
         cur_ivalue = cur_ivalue.toTensor().to(at::kFloat);
         LOG_WARNING("Truncating graph input type from at::kDouble to at::kFloat");
       }
+
       c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType(cur_ivalue.toTensor().dtype());
       if (dtype == c10::nullopt) {
         TORCHTRT_THROW_ERROR("Unsupported input data type " << cur_ivalue.toTensor().dtype());
+      } else if (dtype && dtype.value() == nvinfer1::DataType::kINT8 && partitioning_info.cast_int8_inputs) {
+        // Special case to ensure input IValues to the TensorRT engine are not Int8 type if the
+        // model itself is not quantized
+        cur_ivalue = cur_ivalue.toTensor().to(at::kInt);
       }
+
       if (cur_ivalue.toTensor().sizes().size() == 0) {
         // handle Scalar types, which have sizes of []
         input_shapes.push_back(util::toVec(util::toDims(c10::List<int64_t>({1}))));
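Aside (not part of the diff): a rough, standalone approximation of the input record built here, covering the Int8 -> Int32 promotion for non-quantized engines and the scalar ([] -> {1}) shape special case. It assumes only libtorch; the struct and function names are illustrative, not the PR's API.

#include <torch/torch.h>
#include <iostream>
#include <vector>

struct InputRecord {
  std::vector<int64_t> shape;
  at::ScalarType type;
};

InputRecord recordInput(at::Tensor t, bool cast_int8_inputs) {
  // Approximation of the kINT8 special case above: promote Byte tensors so the
  // recorded engine input type is Int32 rather than Int8.
  if (cast_int8_inputs && t.scalar_type() == at::kByte) {
    t = t.to(at::kInt);
  }
  // Scalars report sizes() == [], but shape analysis records them with shape {1}.
  std::vector<int64_t> shape = t.dim() == 0 ? std::vector<int64_t>{1} : t.sizes().vec();
  return {shape, t.scalar_type()};
}

int main() {
  auto rec = recordInput(torch::tensor(3, torch::dtype(torch::kByte)), /*cast_int8_inputs=*/true);
  std::cout << rec.shape[0] << " " << rec.type << std::endl; // 1 Int
  return 0;
}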
@@ -297,6 +320,7 @@ void runShapeAnalysis(
     const ir::ShapeMode& shape_mode) {
   // register every segment's input shape and its running output IValues
   for (auto& seg_block : ctx->partitioned_blocks[block]) {
+    LOG_GRAPH("Running shape analysis on block " << seg_block);
     torch::jit::ConstantPooling(seg_block.g());
     getSegmentsOutputByRunning(seg_block, example_tensor_map, ctx->settings, shape_mode);
   }