Skip to content

feat: Add option to specify int64 as an Input dtype #1551

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 9, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 22 additions & 10 deletions core/lowering/lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,26 @@ int AutocastLongInputs(
auto dtype = dtype_input->second.value();
// Currently, we do not autocast inputs for which the determined type is not long
if (dtype != at::kLong) {
LOG_DEBUG(
"Skipping autocast for tensor " << input->debugName() << ", since its dtype is " << dtype
<< " and not at::kLong");
continue;
}

LOG_DEBUG("Inserting aten::to casting " << input->debugName() << " to dtype " << dtype);

// Generate cast node sending input tensors to the inferred or specified datatype (long)
torch::jit::Value *const_false, *cuda, *none_val;
if (num_autocasts == 0) {
// Only generate constants once and reuse for all autocasts
const_false = g->insertConstant(0);
const_false->setType(torch::jit::BoolType::get());
cuda = g->insertConstant(target_device_name);
cuda->setType(torch::jit::DeviceObjType::get());
none_val = g->insertNode(g->createNone())->output();
}

auto const_type = g->insertConstant(dtype);
auto const_false = g->insertConstant(0);
const_false->setType(torch::jit::BoolType::get());
auto cuda = g->insertConstant(target_device_name);
cuda->setType(torch::jit::DeviceObjType::get());
auto none_val = g->insertNode(g->createNone())->output();
auto cast_node = g->create(torch::jit::aten::to, {input, cuda, const_type, const_false, const_false, none_val});

// Replace all uses of the original tensor with that of the casted tensor
Expand All @@ -73,12 +81,16 @@ int AutocastLongInputs(
}
}

LOG_WARNING(
"Input tensors to this Torch-TRT engine may have their data types in-place modified "
<< "if the type does not match the determined required type for TRT. To disable this "
<< "automatic casting, specify an Input dtype other than Long");
LOG_GRAPH("Inserted " << num_autocasts << " autocasts");

LOG_GRAPH("Graph after Autocast: " << *g);
if (num_autocasts > 0) {
LOG_WARNING(
"Data types for input tensors have been modified by inserting "
<< "aten::to operations which cast INT64 inputs to INT32. "
<< "To disable this, please recompile using INT32 inputs");

LOG_GRAPH("Graph after Autocast: " << *g);
}

return num_autocasts;
}
Expand Down
9 changes: 6 additions & 3 deletions cpp/src/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ nvinfer1::DataType toTRTDataType(DataType value) {
}
}

at::ScalarType toAtDataType(DataType value) {
at::ScalarType toAtenDataType(DataType value) {
switch (value) {
case DataType::kChar:
return at::kChar;
Expand Down Expand Up @@ -119,7 +119,7 @@ nvinfer1::TensorFormat toTRTTensorFormat(TensorFormat value) {

DataType::DataType(c10::ScalarType t) {
TORCHTRT_CHECK(
t == at::kHalf || t == at::kFloat || t == at::kChar || t == at::kInt || t == at::kBool,
t == at::kHalf || t == at::kFloat || t == at::kChar || t == at::kLong || t == at::kInt || t == at::kBool,
"Data type is unsupported (" << t << ")");
switch (t) {
case at::kHalf:
Expand All @@ -131,6 +131,9 @@ DataType::DataType(c10::ScalarType t) {
case at::kInt:
value = DataType::kInt;
break;
case at::kLong:
value = DataType::kLong;
break;
case at::kBool:
value = DataType::kBool;
break;
Expand Down Expand Up @@ -286,7 +289,7 @@ torch_tensorrt::core::ir::Input to_internal_input(Input& i) {
i.min_shape,
i.opt_shape,
i.max_shape,
toAtDataType(i.dtype),
toAtenDataType(i.dtype),
toTRTTensorFormat(i.format),
!(i.dtype == DataType::kUnknown));
}
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/_Input.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def _parse_dtype(dtype: Any) -> _enums.dtype:
+ str(type(dtype))
)

def is_TRT_dtype(self) -> bool:
def is_trt_dtype(self) -> bool:
return self.dtype != _enums.dtype.long

@staticmethod
Expand Down
6 changes: 3 additions & 3 deletions py/torch_tensorrt/csrc/tensorrt_classes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ nvinfer1::DataType toTRTDataType(DataType value) {
}
}

at::ScalarType toAtDataType(DataType value) {
at::ScalarType toAtenDataType(DataType value) {
switch (value) {
case DataType::kChar:
return at::kChar;
Expand Down Expand Up @@ -95,9 +95,9 @@ std::string to_str(TensorFormat value) {

core::ir::Input Input::toInternalInput() {
if (!input_is_dynamic) {
return core::ir::Input(opt, toAtDataType(dtype), toTRTTensorFormat(format), explicit_set_dtype);
return core::ir::Input(opt, toAtenDataType(dtype), toTRTTensorFormat(format), explicit_set_dtype);
} else {
return core::ir::Input(min, opt, max, toAtDataType(dtype), toTRTTensorFormat(format), explicit_set_dtype);
return core::ir::Input(min, opt, max, toAtenDataType(dtype), toTRTTensorFormat(format), explicit_set_dtype);
}
}

Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/csrc/tensorrt_classes.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace pyapi {
enum class DataType : int8_t { kLong, kFloat, kHalf, kChar, kInt32, kBool, kUnknown };
std::string to_str(DataType value);
nvinfer1::DataType toTRTDataType(DataType value);
at::ScalarType toAtDataType(DataType value);
at::ScalarType toAtenDataType(DataType value);

enum class TensorFormat : int8_t { kContiguous, kChannelsLast };
std::string to_str(TensorFormat value);
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/ts/_compile_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def _parse_input_signature(input_signature: Any):
else input_signature
)

if not i.is_TRT_dtype():
if not i.is_trt_dtype():
raise TypeError(
"Using non-TRT input types with input_signature is not currently "
+ "supported. Please specify inputs individually to use "
Expand Down
44 changes: 44 additions & 0 deletions tests/cpp/test_collections.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,50 @@ TEST(CppAPITests, TestCollectionStandardTensorInput) {
ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(out.toTensor(), trt_out.toTensor()));
}

TEST(CppAPITests, TestCollectionStandardTensorInputLongDtype) {
  // Exercise compilation of a standard-tensor-input module when the user
  // explicitly declares the inputs as Long (int64) in the CompileSpec.
  std::string module_path = "tests/modules/standard_tensor_input_scripted.jit.pt";

  // Build a Long-typed CUDA tensor and feed it to both input slots.
  torch::Tensor long_input = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kLong);
  std::vector<at::Tensor> raw_inputs{long_input, long_input};

  // Load the scripted module; on failure we just log and let the later
  // calls surface the problem through the test assertions.
  torch::jit::Module mod;
  try {
    mod = torch::jit::load(module_path);
  } catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
  }
  mod.eval();
  mod.to(torch::kCUDA);

  // Clone each tensor into an IValue so Torch and TRT runs see identical data.
  std::vector<torch::jit::IValue> ivalue_inputs;
  for (auto tensor : raw_inputs) {
    ivalue_inputs.push_back(torch::jit::IValue(tensor.clone()));
  }

  // Reference result from eager TorchScript execution.
  auto out = mod.forward(ivalue_inputs);

  // Declare both inputs as Long so the lowering pass must autocast them.
  std::vector<torch_tensorrt::Input> input_specs;
  input_specs.push_back({long_input.sizes(), torch::kLong});
  input_specs.push_back({long_input.sizes(), torch::kLong});

  torch_tensorrt::ts::CompileSpec compile_settings(input_specs);
  compile_settings.min_block_size = 1;

  // FP32 engine with long/double truncation enabled.
  compile_settings.enabled_precisions = {torch::kFloat};
  compile_settings.truncate_long_and_double = true;

  // Compile and run through TensorRT, then compare against the reference.
  auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);
  auto trt_out = trt_mod.forward(ivalue_inputs);

  ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(
      out.toTensor().to(torch::kFloat), trt_out.toTensor().to(torch::kFloat)));
}

TEST(CppAPITests, TestCollectionTupleInput) {
std::string path = "tests/modules/tuple_input_scripted.jit.pt";
torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
Expand Down