Skip to content

Commit 4282c06

Browse files
committed
fix: Improve autocast graph, provide CPP support
- Address review comments
- Add cpp API testing and support
- Improve length and efficiency of autocast graph
- Improve messages displayed to user
1 parent 1e34332 commit 4282c06

File tree

7 files changed

+78
-19
lines changed

7 files changed

+78
-19
lines changed

core/lowering/lowering.cpp

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,18 +48,26 @@ int AutocastLongInputs(
4848
auto dtype = dtype_input->second.value();
4949
// Currently, we do not autocast inputs for which the determined type is not long
5050
if (dtype != at::kLong) {
51+
LOG_DEBUG(
52+
"Skipping autocast for tensor " << input->debugName() << ", since its dtype is " << dtype
53+
<< " and not at::kLong");
5154
continue;
5255
}
5356

5457
LOG_DEBUG("Inserting aten::to casting " << input->debugName() << " to dtype " << dtype);
5558

5659
// Generate cast node sending input tensors to the inferred or specified datatype (long)
60+
torch::jit::Value *const_false, *cuda, *none_val;
61+
if (num_autocasts == 0) {
62+
// Only generate constants once and reuse for all autocasts
63+
const_false = g->insertConstant(0);
64+
const_false->setType(torch::jit::BoolType::get());
65+
cuda = g->insertConstant(target_device_name);
66+
cuda->setType(torch::jit::DeviceObjType::get());
67+
none_val = g->insertNode(g->createNone())->output();
68+
}
69+
5770
auto const_type = g->insertConstant(dtype);
58-
auto const_false = g->insertConstant(0);
59-
const_false->setType(torch::jit::BoolType::get());
60-
auto cuda = g->insertConstant(target_device_name);
61-
cuda->setType(torch::jit::DeviceObjType::get());
62-
auto none_val = g->insertNode(g->createNone())->output();
6371
auto cast_node = g->create(torch::jit::aten::to, {input, cuda, const_type, const_false, const_false, none_val});
6472

6573
// Replace all uses of the original tensor with that of the casted tensor
@@ -73,12 +81,16 @@ int AutocastLongInputs(
7381
}
7482
}
7583

76-
LOG_WARNING(
77-
"Input tensors to this Torch-TRT engine may have their data types in-place modified "
78-
<< "if the type does not match the determined required type for TRT. To disable this "
79-
<< "automatic casting, specify an Input dtype other than Long");
84+
LOG_GRAPH("Inserted " << num_autocasts << " autocasts");
8085

81-
LOG_GRAPH("Graph after Autocast: " << *g);
86+
if (num_autocasts > 0) {
87+
LOG_WARNING(
88+
"Data types for input tensors have been modified by inserting "
89+
<< "aten::to operations which cast INT64 inputs to INT32. "
90+
<< "To disable this, please recompile using INT32 inputs");
91+
92+
LOG_GRAPH("Graph after Autocast: " << *g);
93+
}
8294

8395
return num_autocasts;
8496
}

cpp/src/types.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ nvinfer1::DataType toTRTDataType(DataType value) {
8787
}
8888
}
8989

90-
at::ScalarType toAtDataType(DataType value) {
90+
at::ScalarType toAtenDataType(DataType value) {
9191
switch (value) {
9292
case DataType::kChar:
9393
return at::kChar;
@@ -119,7 +119,7 @@ nvinfer1::TensorFormat toTRTTensorFormat(TensorFormat value) {
119119

120120
DataType::DataType(c10::ScalarType t) {
121121
TORCHTRT_CHECK(
122-
t == at::kHalf || t == at::kFloat || t == at::kChar || t == at::kInt || t == at::kBool,
122+
t == at::kHalf || t == at::kFloat || t == at::kChar || t == at::kLong || t == at::kInt || t == at::kBool,
123123
"Data type is unsupported (" << t << ")");
124124
switch (t) {
125125
case at::kHalf:
@@ -131,6 +131,9 @@ DataType::DataType(c10::ScalarType t) {
131131
case at::kInt:
132132
value = DataType::kInt;
133133
break;
134+
case at::kLong:
135+
value = DataType::kLong;
136+
break;
134137
case at::kBool:
135138
value = DataType::kBool;
136139
break;
@@ -286,7 +289,7 @@ torch_tensorrt::core::ir::Input to_internal_input(Input& i) {
286289
i.min_shape,
287290
i.opt_shape,
288291
i.max_shape,
289-
toAtDataType(i.dtype),
292+
toAtenDataType(i.dtype),
290293
toTRTTensorFormat(i.format),
291294
!(i.dtype == DataType::kUnknown));
292295
}

py/torch_tensorrt/_Input.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def _parse_dtype(dtype: Any) -> _enums.dtype:
244244
+ str(type(dtype))
245245
)
246246

247-
def is_TRT_dtype(self) -> bool:
247+
def is_trt_dtype(self) -> bool:
248248
return self.dtype != _enums.dtype.long
249249

250250
@staticmethod

py/torch_tensorrt/csrc/tensorrt_classes.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ nvinfer1::DataType toTRTDataType(DataType value) {
4444
}
4545
}
4646

47-
at::ScalarType toAtDataType(DataType value) {
47+
at::ScalarType toAtenDataType(DataType value) {
4848
switch (value) {
4949
case DataType::kChar:
5050
return at::kChar;
@@ -95,9 +95,9 @@ std::string to_str(TensorFormat value) {
9595

9696
core::ir::Input Input::toInternalInput() {
9797
if (!input_is_dynamic) {
98-
return core::ir::Input(opt, toAtDataType(dtype), toTRTTensorFormat(format), explicit_set_dtype);
98+
return core::ir::Input(opt, toAtenDataType(dtype), toTRTTensorFormat(format), explicit_set_dtype);
9999
} else {
100-
return core::ir::Input(min, opt, max, toAtDataType(dtype), toTRTTensorFormat(format), explicit_set_dtype);
100+
return core::ir::Input(min, opt, max, toAtenDataType(dtype), toTRTTensorFormat(format), explicit_set_dtype);
101101
}
102102
}
103103

py/torch_tensorrt/csrc/tensorrt_classes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ namespace pyapi {
3030
enum class DataType : int8_t { kLong, kFloat, kHalf, kChar, kInt32, kBool, kUnknown };
3131
std::string to_str(DataType value);
3232
nvinfer1::DataType toTRTDataType(DataType value);
33-
at::ScalarType toAtDataType(DataType value);
33+
at::ScalarType toAtenDataType(DataType value);
3434

3535
enum class TensorFormat : int8_t { kContiguous, kChannelsLast };
3636
std::string to_str(TensorFormat value);

py/torch_tensorrt/ts/_compile_spec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def _parse_input_signature(input_signature: Any):
215215
else input_signature
216216
)
217217

218-
if not i.is_TRT_dtype():
218+
if not i.is_trt_dtype():
219219
raise TypeError(
220220
"Using non-TRT input types with input_signature is not currently "
221221
+ "supported. Please specify inputs individually to use "

tests/cpp/test_collections.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,50 @@ TEST(CppAPITests, TestCollectionStandardTensorInput) {
4545
ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(out.toTensor(), trt_out.toTensor()));
4646
}
4747

48+
TEST(CppAPITests, TestCollectionStandardTensorInputLongDtype) {
  // Verify that a module compiled with Long (INT64) inputs produces outputs
  // consistent with eager TorchScript execution, exercising the
  // truncate_long_and_double autocast path added for the C++ API.
  std::string path = "tests/modules/standard_tensor_input_scripted.jit.pt";
  torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kLong);
  std::vector<at::Tensor> inputs;
  inputs.push_back(in0);
  inputs.push_back(in0);

  torch::jit::Module mod;
  try {
    // Deserialize the ScriptModule from a file using torch::jit::load().
    mod = torch::jit::load(path);
  } catch (const c10::Error& e) {
    // Fail fast: continuing with an empty module would crash later in
    // mod.eval()/forward with an error that hides the real cause.
    FAIL() << "error loading the model from " << path;
  }
  mod.eval();
  mod.to(torch::kCUDA);

  std::vector<torch::jit::IValue> inputs_;
  inputs_.reserve(inputs.size());
  for (auto in : inputs) {
    inputs_.push_back(torch::jit::IValue(in.clone()));
  }

  // Reference result from eager TorchScript execution.
  auto out = mod.forward(inputs_);

  std::vector<torch_tensorrt::Input> input_range;

  // Specify Long input tensor type so compilation must autocast to INT32
  input_range.push_back({in0.sizes(), torch::kLong});
  input_range.push_back({in0.sizes(), torch::kLong});
  torch_tensorrt::ts::CompileSpec compile_settings(input_range);
  compile_settings.min_block_size = 1;

  // FP32 execution with long and double truncation
  compile_settings.enabled_precisions = {torch::kFloat};
  compile_settings.truncate_long_and_double = true;

  // Compile module and compare against the eager result; both sides are cast
  // to float since the TRT output dtype may differ from the Long input dtype.
  auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);
  auto trt_out = trt_mod.forward(inputs_);

  ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(
      out.toTensor().to(torch::kFloat), trt_out.toTensor().to(torch::kFloat)));
}
91+
4892
TEST(CppAPITests, TestCollectionTupleInput) {
4993
std::string path = "tests/modules/tuple_input_scripted.jit.pt";
5094
torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);

0 commit comments

Comments
 (0)