Skip to content

Commit 4cea990

Browse files
authored
Merge branch 'main' into dynamic_input_domain_phase_1
2 parents 7be7982 + 8adcacc commit 4cea990

File tree

159 files changed

+1330
-608
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

159 files changed

+1330
-608
lines changed

.circleci/config.yml

Lines changed: 185 additions & 175 deletions
Large diffs are not rendered by default.

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,10 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") # save the TRT embedd
114114
These are the following dependencies used to verify the testcases. Torch-TensorRT can work with other versions, but the tests are not guaranteed to pass.
115115

116116
- Bazel 5.2.0
117-
- Libtorch 1.12.1 (built with CUDA 11.6)
118-
- CUDA 11.6
119-
- cuDNN 8.4.1
120-
- TensorRT 8.4.3.1
117+
- Libtorch 2.0.0.dev20230103 (built with CUDA 11.7)
118+
- CUDA 11.7
119+
- cuDNN 8.5.0
120+
- TensorRT 8.5.1.7
121121

122122
## Prebuilt Binaries and Wheel files
123123

WORKSPACE

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ local_repository(
4141
new_local_repository(
4242
name = "cuda",
4343
build_file = "@//third_party/cuda:BUILD",
44-
path = "/usr/local/cuda-11.6/",
44+
path = "/usr/local/cuda-11.7/",
4545
)
4646

4747
new_local_repository(
@@ -56,17 +56,17 @@ new_local_repository(
5656
http_archive(
5757
name = "libtorch",
5858
build_file = "@//third_party/libtorch:BUILD",
59-
sha256 = "b565c662435fd58ec295fa0791388ea52ad0f5fd33517b2d7c0fdcc91b6db531",
59+
sha256 = "59b8b5e1954a86d50b79c13f06398d385b200da13e37a08ecf31d3c62e5ca127",
6060
strip_prefix = "libtorch",
61-
urls = ["https://download.pytorch.org/libtorch/nightly/cu116/libtorch-cxx11-abi-shared-with-deps-1.14.0.dev20221114%2Bcu116.zip"],
61+
urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-cxx11-abi-shared-with-deps-2.0.0.dev20230103%2Bcu117.zip"],
6262
)
6363

6464
http_archive(
6565
name = "libtorch_pre_cxx11_abi",
6666
build_file = "@//third_party/libtorch:BUILD",
67-
sha256 = "fbb37446c33b05c1e26256c09f6ffb46cea1f6ff9ee2ad5b79b146d09023b0c1",
67+
sha256 = "e260fc7476be89d1650953e8643e9f7363845f5a52de4bab87ac0e619c1f6ad4",
6868
strip_prefix = "libtorch",
69-
urls = ["https://download.pytorch.org/libtorch/nightly/cu116/libtorch-shared-with-deps-1.14.0.dev20221114%2Bcu116.zip"],
69+
urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-shared-with-deps-2.0.0.dev20230103%2Bcu117.zip"],
7070
)
7171

7272
# Download these tarballs manually from the NVIDIA website
@@ -76,20 +76,20 @@ http_archive(
7676
http_archive(
7777
name = "cudnn",
7878
build_file = "@//third_party/cudnn/archive:BUILD",
79-
sha256 = "ec96d2376d81fca42bdd3d4c3d705a99b29a065bab57f920561c763e29c67d01",
80-
strip_prefix = "cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive",
79+
sha256 = "5454a6fd94f008728caae9adad993c4e85ef36302e26bce43bea7d458a5e7b6d",
80+
strip_prefix = "cudnn-linux-x86_64-8.5.0.96_cuda11-archive",
8181
urls = [
82-
"https://developer.nvidia.com/compute/cudnn/secure/8.4.1/local_installers/11.6/cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive.tar.xz",
82+
"https://developer.nvidia.com/compute/cudnn/secure/8.5.0/local_installers/11.7/cudnn-linux-x86_64-8.5.0.96_cuda11-archive.tar.xz",
8383
],
8484
)
8585

8686
http_archive(
8787
name = "tensorrt",
8888
build_file = "@//third_party/tensorrt/archive:BUILD",
89-
sha256 = "8d7c2085c1639dcc73875048c23598a8526ce3089136876e31d90258e49e4f61",
90-
strip_prefix = "TensorRT-8.4.3.1",
89+
sha256 = "39cc7f077057d1363794e8ff51c4cf21a5dbeccf1116b0020ba0dae0f3063076",
90+
strip_prefix = "TensorRT-8.5.1.7",
9191
urls = [
92-
"https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/8.4.3/tars/tensorrt-8.4.3.1.linux.x86_64-gnu.cuda-11.6.cudnn8.4.tar.gz",
92+
"https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/8.5.1/tars/TensorRT-8.5.1.7.Linux.x86_64-gnu.cuda-11.8.cudnn8.6.tar.gz",
9393
],
9494
)
9595

core/compiler.cpp

Lines changed: 53 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ partitioning::GraphAndMapping BuildHybridGraph(
187187
return partitioning::stitch(&partitioning_ctx, block);
188188
}
189189

190-
void MapInputsAndDetermineDTypes(
190+
ir::TypeMap MapInputsAndDetermineDTypes(
191191
CompileSpec& cfg,
192192
std::shared_ptr<torch::jit::Graph>& g,
193193
ir::StaticParams& static_params,
@@ -197,6 +197,7 @@ void MapInputsAndDetermineDTypes(
197197
cfg.partitioning_info.collection_input_spec_map =
198198
ir::CollectionInputSpecMap(cfg.convert_info.collection_input_spec_map);
199199

200+
ir::TypeMap inferred_dtypes;
200201
auto collection_inputs = ir::get_collection_inputs(g, static_params);
201202
LOG_DEBUG(
202203
"In MapInputsAndDetermineDTypes, the g->inputs() size is "
@@ -218,13 +219,13 @@ void MapInputsAndDetermineDTypes(
218219
LOG_INFO(
219220
"Since input type is not explicitly defined, infering using first tensor calculation\n Inferred input "
220221
<< in->debugName() << " has type " << est_type_opt[i].value());
221-
spec[i].dtype = util::ScalarTypeToTRTDataType(est_type_opt[i].value());
222+
spec[i].dtype = est_type_opt[i].value();
222223
} else if (!est_type_opt[i] && !spec[i].dtype_is_user_defined) {
223224
// If we cannot calculate the type and the user did not define the type, then default to FP32
224225
LOG_WARNING(
225226
"Cannot infer input type from calcuations in graph for input "
226227
<< in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
227-
spec[i].dtype = nvinfer1::DataType::kFLOAT;
228+
spec[i].dtype = at::kFloat;
228229
} else if (spec[i].dtype_is_user_defined && cfg.partitioning_info.enabled) {
229230
if (!est_type_opt[i]) {
230231
LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
@@ -236,37 +237,35 @@ void MapInputsAndDetermineDTypes(
236237
auto warn_str = ss.str();
237238
LOG_WARNING(warn_str);
238239
// Overwrite type map with user settings
239-
first_use_type_map[in][i] = {
240-
util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
241-
242-
} else {
243-
if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) !=
244-
est_type_opt[i].value()) {
245-
std::stringstream ss;
246-
ss << "For input " << in->debugName() << ", found user specified input dtype as ";
247-
ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
248-
ss << ", however when inspecting the graph, the input type expected was inferred to be ";
249-
ss << est_type_opt[i].value() << std::endl;
250-
ss << "The compiler is going to use the user setting "
251-
<< cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
252-
ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
253-
ss << "compatibility with PyTorch's data type convention is required.\n";
254-
ss << "If you do indeed see errors at runtime either:\n";
255-
ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
256-
ss << "- Disable partial compilation by setting require_full_compilation to True";
257-
auto warn_str = ss.str();
258-
LOG_WARNING(warn_str);
259-
// Overwrite type map with user settings
260-
first_use_type_map[in][i] = {
261-
util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
262-
}
240+
first_use_type_map[in][i] = {cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype};
241+
242+
} else if (cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype != est_type_opt[i].value()) {
243+
std::stringstream ss;
244+
ss << "For input " << in->debugName() << ", found user specified input dtype as ";
245+
ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
246+
ss << ", however when inspecting the graph, the input type expected was inferred to be ";
247+
ss << est_type_opt[i].value() << std::endl;
248+
ss << "The compiler is going to use the user setting "
249+
<< cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
250+
ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
251+
ss << "compatibility with PyTorch's data type convention is required.\n";
252+
ss << "If you do indeed see errors at runtime either:\n";
253+
ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
254+
ss << "- Disable partial compilation by setting require_full_compilation to True";
255+
auto warn_str = ss.str();
256+
LOG_WARNING(warn_str);
257+
// Overwrite type map with user settings
258+
first_use_type_map[in][i] = {cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype};
263259
}
264260
} else {
265261
// The user defined the type so no changes are necessary
266262
}
263+
264+
// Insert entry for Value pointer and determined ScalarType
265+
inferred_dtypes.insert({in, {spec[i].dtype}});
267266
}
268267
}
269-
// }
268+
return inferred_dtypes;
270269
}
271270

272271
std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
@@ -284,6 +283,15 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
284283

285284
MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
286285

286+
// Ensure none of the specified types are of acceptable input types incompatible with TRT
287+
// Currently, only at::kLong is an acceptable, though TRT-incompatible type
288+
for (auto value_to_dtypes : first_use_types) {
289+
for (auto dtype : value_to_dtypes.second) {
290+
TORCHTRT_CHECK(
291+
!dtype || dtype.value() != at::kLong, "Cannot specify Int64 input for a model fully compiled in TRT");
292+
}
293+
}
294+
287295
auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
288296

289297
return engine;
@@ -307,10 +315,24 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
307315
// Infer the type of an input from the weights of the calculation
308316
auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());
309317

310-
MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
318+
// Extract map of IValue to DType
319+
auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
320+
321+
// Check whether any of the input types are Long
322+
bool user_requested_long = false;
323+
for (auto dtype : type_map) {
324+
user_requested_long |= dtype.second && (dtype.second.value() == at::kLong);
325+
}
326+
327+
// Use dtype map to autocast Tensor-type inputs to Long dtype as necessary
328+
if (cfg.partitioning_info.enabled && cfg.partitioning_info.truncate_long_and_double && user_requested_long) {
329+
auto casts_inserted = lowering::AutocastLongInputs(g, type_map, cfg.lower_info.getGPUDeviceString());
330+
user_requested_long &= (casts_inserted > 0);
331+
}
332+
311333
auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
312334
auto outputIsCollection = conversion::OutputIsCollection(g->block());
313-
if (cfg.partitioning_info.enabled &&
335+
if (cfg.partitioning_info.enabled && !user_requested_long &&
314336
(cfg.lower_info.forced_fallback_modules.size() == 0 &&
315337
cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) &&
316338
!outputIsCollection) {
@@ -320,7 +342,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
320342
if (cfg.partitioning_info.enabled &&
321343
(!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
322344
cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
323-
outputIsCollection)) {
345+
outputIsCollection || user_requested_long)) {
324346
auto graph_and_mapping = BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types);
325347
new_g = graph_and_mapping.first;
326348
// renaming the input name of graph after fallback to ensure pytorch deserialize it correctly

core/conversion/conversion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ void AddInputs(ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> input
183183
"Adding Input " << in->debugName() << " (named: " << name << "): " << spec
184184
<< " in engine (conversion.AddInputs)");
185185

186-
auto trt_in = ctx->net->addInput(name.c_str(), spec.dtype, spec.input_shape);
186+
auto trt_in = ctx->net->addInput(name.c_str(), util::ScalarTypeToTRTDataType(spec.dtype), spec.input_shape);
187187
TORCHTRT_CHECK(trt_in, "Failed to add input node: " << in->debugName() << " (conversion.AddInputs)");
188188
trt_in->setAllowedFormats(1U << static_cast<int>(spec.format));
189189

core/conversion/converters/impl/conv_deconv.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,20 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
102102
}
103103

104104
auto w = Weights(ctx, args[1].unwrapToTensor());
105+
// TODO: Remove this when conv3d with kernel size=1 bug is fixed.
106+
// Github issue: https://github.com/pytorch/TensorRT/issues/1445
107+
bool is_kernel_size_one = true;
108+
bool is_3d_kernel = w.kernel_shape.nbDims == 3;
109+
for (int64_t i = 0; i < w.kernel_shape.nbDims; i++) {
110+
if (w.kernel_shape.d[i] != 1.0f) {
111+
is_kernel_size_one = false;
112+
}
113+
}
114+
if (is_kernel_size_one && is_3d_kernel) {
115+
LOG_WARNING(
116+
"Conv3d layer with kernel size = 1 configuration incurs a failure with TensorRT tactic optimizer in some cases. \
117+
Github issue: https://github.com/pytorch/TensorRT/issues/1445. Other conv variants do not have this issue.");
118+
}
105119
auto dims = in->getDimensions();
106120
auto orig_dims = dims;
107121
LOG_DEBUG("Input dims: " << orig_dims);

core/conversion/converters/impl/max.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ bool min_max_dim(ConversionCtx* ctx, const torch::jit::Node* n, args& args, nvin
2222
if (dim < 0) {
2323
dim = selfDim.size() + dim;
2424
}
25+
bool int_input = self->getType() == nvinfer1::DataType::kINT32;
26+
if (int_input) {
27+
LOG_DEBUG("topk layer does not support int32 inputs, adding cast to float");
28+
self = castITensor(ctx, self, nvinfer1::DataType::kFLOAT, util::node_info(n) + "_input");
29+
}
2530
uint32_t reduce_axes_mask = 1 << dim;
2631
auto topk_layer = ctx->net->addTopK(*self, topKOperation, 1, reduce_axes_mask);
2732
TORCHTRT_CHECK(topk_layer, "Unable to create topk layer from node: " << *n);
@@ -44,7 +49,10 @@ bool min_max_dim(ConversionCtx* ctx, const torch::jit::Node* n, args& args, nvin
4449
out0 = ctx->AssociateValueAndTensor(n->outputs()[0], topk_layer->getOutput(0));
4550
out1 = ctx->AssociateValueAndTensor(n->outputs()[1], topk_layer->getOutput(1));
4651
}
47-
52+
if (int_input) {
53+
LOG_DEBUG("Adding cast of topK layer output back to int32");
54+
out0 = castITensor(ctx, out0, nvinfer1::DataType::kINT32, util::node_info(n) + "_output");
55+
}
4856
LOG_DEBUG("Output tensor(0) shape: " << out0->getDimensions());
4957
LOG_DEBUG("Output tensor(1) shape: " << out1->getDimensions());
5058

@@ -59,6 +67,10 @@ bool arg_min_max(ConversionCtx* ctx, const torch::jit::Node* n, args& args, nvin
5967
if (dim < 0) {
6068
dim = selfDim.size() + dim;
6169
}
70+
if (self->getType() == nvinfer1::DataType::kINT32) {
71+
LOG_DEBUG("topk layer does not support int32 inputs, adding cast to float");
72+
self = castITensor(ctx, self, nvinfer1::DataType::kFLOAT, util::node_info(n) + "_input");
73+
}
6274
uint32_t reduce_axes_mask = 1 << dim;
6375
auto topk_layer = ctx->net->addTopK(*self, topKOperation, 1, reduce_axes_mask);
6476
TORCHTRT_CHECK(topk_layer, "Unable to create topk layer from node: " << *n);

core/conversion/converters/impl/select.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ namespace impl {
1616
namespace {
1717

1818
bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool split_list, bool unbind) {
19-
auto in = args[0].ITensor();
19+
auto in = args[0].ITensorOrFreeze(ctx);
2020
auto numOutputs = 1, numRemainder = 0;
2121
std::vector<int64_t> sizes;
2222

@@ -736,8 +736,22 @@ auto select_registrations TORCHTRT_UNUSED =
736736
{"aten::where.self(Tensor condition, Tensor self, Tensor other) -> (Tensor)",
737737
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
738738
auto condition = args[0].ITensorOrFreeze(ctx);
739+
auto condition_nbDims = condition->getDimensions().nbDims;
739740
auto x = args[1].ITensorOrFreeze(ctx);
741+
auto x_nbDims = x->getDimensions().nbDims;
740742
auto y = args[2].ITensorOrFreeze(ctx);
743+
auto y_nbDims = y->getDimensions().nbDims;
744+
745+
// Get maximum rank of all input tensors
746+
auto max_nbDims = std::max(condition_nbDims, std::max(x_nbDims, y_nbDims));
747+
748+
// TensorRT requires all inputs to Select layers to have the same rank, so for each
749+
// tensor input, ensure that its rank is equal to the maximum number of dimensions
750+
// If not, left-pad the tensor dimension with 1s until the max rank is achieved
751+
condition =
752+
addPadding(ctx, n, condition, max_nbDims, /*bool trailing =*/false, /*bool use_zeros =*/false);
753+
x = addPadding(ctx, n, x, max_nbDims, /*bool trailing =*/false, /*bool use_zeros =*/false);
754+
y = addPadding(ctx, n, y, max_nbDims, /*bool trailing =*/false, /*bool use_zeros =*/false);
741755

742756
auto layer = ctx->net->addSelect(*condition, *x, *y);
743757

core/ir/Input.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ bool valid_input_domain(std::vector<int64_t> domain) {
7575

7676
Input::Input(
7777
std::vector<int64_t> shape,
78-
nvinfer1::DataType dtype,
78+
at::ScalarType dtype,
7979
nvinfer1::TensorFormat format,
8080
bool dtype_is_user_defined,
8181
std::vector<int64_t> tensor_domain) {
@@ -89,10 +89,10 @@ Input::Input(
8989
input_shape = util::toDims(shape);
9090
input_is_dynamic = false;
9191

92-
TORCHTRT_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype);
92+
TORCHTRT_CHECK(valid_input_dtype(util::ScalarTypeToTRTDataType(dtype)), "Unsupported input data type: " << dtype);
9393
this->dtype = dtype;
9494
TORCHTRT_CHECK(
95-
valid_dtype_format_combo(dtype, format),
95+
valid_dtype_format_combo(util::ScalarTypeToTRTDataType(dtype), format),
9696
"Unsupported combination of dtype and tensor format: ("
9797
<< dtype << ", " << format
9898
<< "), Torch-TensorRT only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
@@ -109,7 +109,7 @@ Input::Input(
109109
std::vector<int64_t> min_shape,
110110
std::vector<int64_t> opt_shape,
111111
std::vector<int64_t> max_shape,
112-
nvinfer1::DataType dtype,
112+
at::ScalarType dtype,
113113
nvinfer1::TensorFormat format,
114114
bool dtype_is_user_defined,
115115
std::vector<int64_t> tensor_domain) {
@@ -148,10 +148,10 @@ Input::Input(
148148

149149
input_shape = util::toDims(dyn_shape);
150150

151-
TORCHTRT_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype);
151+
TORCHTRT_CHECK(valid_input_dtype(util::ScalarTypeToTRTDataType(dtype)), "Unsupported input data type: " << dtype);
152152
this->dtype = dtype;
153153
TORCHTRT_CHECK(
154-
valid_dtype_format_combo(dtype, format),
154+
valid_dtype_format_combo(util::ScalarTypeToTRTDataType(dtype), format),
155155
"Unsupported combination of dtype and tensor format: ("
156156
<< dtype << ", " << format
157157
<< "), Torch-TensorRT only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");

core/ir/ir.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,19 @@ struct Input : torch::CustomClassHolder {
2929
Input(){};
3030
Input(
3131
std::vector<int64_t> shape,
32-
nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT,
32+
at::ScalarType dtype = at::kFloat,
3333
nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR,
3434
bool dtype_is_user_defined = false,
3535
std::vector<int64_t> tensor_domain = std::vector<int64_t>{0, 2});
3636
Input(
3737
std::vector<int64_t> min_shape,
3838
std::vector<int64_t> opt_shape,
3939
std::vector<int64_t> max_shape,
40-
nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT,
40+
at::ScalarType dtype = at::kFloat,
4141
nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR,
4242
bool dtype_is_user_defined = false,
4343
std::vector<int64_t> tensor_domain = std::vector<int64_t>{0, 2});
44+
4445
friend std::ostream& operator<<(std::ostream& os, const Input& input);
4546

4647
bool input_is_dynamic = false;
@@ -50,7 +51,7 @@ struct Input : torch::CustomClassHolder {
5051
nvinfer1::Dims min;
5152
nvinfer1::Dims max;
5253
nvinfer1::Dims opt;
53-
nvinfer1::DataType dtype;
54+
at::ScalarType dtype;
5455
nvinfer1::TensorFormat format;
5556
int id;
5657
};

0 commit comments

Comments
 (0)