Skip to content

Commit 7be7982

Browse files
committed
feat: Add optional tensor domain arg to Input class
- Add support for specifying one, both, or neither integer bound (in Python) on Tensor domain used for generating a sample tensor in shape analysis - Add bounds checking to avoid int32_t overflow issues - Add validation for inputs to ensure high bound exceeds low bound and inputs are strictly integers - Use [0, 2) as the default bound, which works well as a random-number selection range in cases of Bool, Int, and Float types - Add validation in C++ to ensure provided domain arguments have the correct number of elements and ordering - Add functionality to print domains as part of both partitioning and initialization, for debugging purposes - Add hooks in pybind to capture specified input domain argument and parse values for mirrored internal representation - Add new Input constructors with defaults to accommodate new `tensor_domain` argument - Add pybind and torchbind get/set fields for C++ and Python API compatibility - Add C++ and Python collections test cases to ensure domain specification is functional across APIs
1 parent 063be0d commit 7be7982

File tree

13 files changed

+488
-14
lines changed

13 files changed

+488
-14
lines changed

core/ir/Input.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,16 @@ bool valid_input_dtype(nvinfer1::DataType dtype) {
6969
}
7070
}
7171

72+
bool valid_input_domain(std::vector<int64_t> domain) {
73+
return (domain.size() == 2) && (domain[0] < domain[1]);
74+
}
75+
7276
Input::Input(
7377
std::vector<int64_t> shape,
7478
nvinfer1::DataType dtype,
7579
nvinfer1::TensorFormat format,
76-
bool dtype_is_user_defined) {
80+
bool dtype_is_user_defined,
81+
std::vector<int64_t> tensor_domain) {
7782
if (shape.size() > 5) {
7883
LOG_WARNING("Verify that this dim size is accepted");
7984
}
@@ -93,6 +98,11 @@ Input::Input(
9398
<< "), Torch-TensorRT only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
9499
this->format = format;
95100
this->dtype_is_user_defined = dtype_is_user_defined;
101+
102+
TORCHTRT_CHECK(
103+
valid_input_domain(tensor_domain),
104+
"Unsupported tensor domain: [" << tensor_domain[0] << ", " << tensor_domain[1] << ")");
105+
this->tensor_domain = tensor_domain;
96106
}
97107

98108
Input::Input(
@@ -101,7 +111,8 @@ Input::Input(
101111
std::vector<int64_t> max_shape,
102112
nvinfer1::DataType dtype,
103113
nvinfer1::TensorFormat format,
104-
bool dtype_is_user_defined) {
114+
bool dtype_is_user_defined,
115+
std::vector<int64_t> tensor_domain) {
105116
if (min_shape.size() > 5 || opt_shape.size() > 5 || max_shape.size() > 5) {
106117
LOG_WARNING("Verify that this dim size is accepted");
107118
}
@@ -146,6 +157,10 @@ Input::Input(
146157
<< "), Torch-TensorRT only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
147158
this->format = format;
148159
this->dtype_is_user_defined = dtype_is_user_defined;
160+
TORCHTRT_CHECK(
161+
valid_input_domain(tensor_domain),
162+
"Unsupported tensor domain: [" << tensor_domain[0] << ", " << tensor_domain[1] << ")");
163+
this->tensor_domain = tensor_domain;
149164
}
150165

151166
std::ostream& operator<<(std::ostream& os, const Input& input) {

core/ir/ir.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,21 @@ struct Input : torch::CustomClassHolder {
3131
std::vector<int64_t> shape,
3232
nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT,
3333
nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR,
34-
bool dtype_is_user_defined = false);
34+
bool dtype_is_user_defined = false,
35+
std::vector<int64_t> tensor_domain = std::vector<int64_t>{0, 2});
3536
Input(
3637
std::vector<int64_t> min_shape,
3738
std::vector<int64_t> opt_shape,
3839
std::vector<int64_t> max_shape,
3940
nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT,
4041
nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR,
41-
bool dtype_is_used_defined = false);
42+
bool dtype_is_user_defined = false,
43+
std::vector<int64_t> tensor_domain = std::vector<int64_t>{0, 2});
4244
friend std::ostream& operator<<(std::ostream& os, const Input& input);
4345

4446
bool input_is_dynamic = false;
4547
bool dtype_is_user_defined = false;
48+
std::vector<int64_t> tensor_domain;
4649
nvinfer1::Dims input_shape;
4750
nvinfer1::Dims min;
4851
nvinfer1::Dims max;

core/partitioning/shape_analysis.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ at::Tensor generateSingleInput(
2626
}
2727

2828
// Initialize min and max ranges for random number selection
29-
int LoValIncl = 0;
30-
int HiValExcl = 2;
29+
int LoValIncl = input.tensor_domain[0];
30+
int HiValExcl = input.tensor_domain[1];
3131

3232
auto type = at::kFloat;
3333
if (type_opt) {
@@ -36,6 +36,10 @@ at::Tensor generateSingleInput(
3636
LOG_WARNING("Input type for doing shape analysis could not be determined, defaulting to F32");
3737
}
3838

39+
LOG_DEBUG(
40+
"Using the Range: [" << LoValIncl << ", " << HiValExcl
41+
<< ") as a random range for shape analysis on input with data type " << type);
42+
3943
// Make the value range for input tensor a uniform (float) distribution
4044
// over [LoValIncl, HiValExcl), then cast to the desired dtype
4145
auto in = ((HiValExcl - LoValIncl) * at::rand(util::toVec(input_shape), {at::kCUDA}) + LoValIncl).to(type);

cpp/include/torch_tensorrt/torch_tensorrt.h

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,8 @@ struct Input : torch::CustomClassHolder {
379379
DataType dtype;
380380
/// Expected tensor format for the input
381381
TensorFormat format;
382+
/// Expected allowed domain for tensor input
383+
std::vector<int64_t> tensor_domain;
382384

383385
Input() {}
384386
/**
@@ -392,6 +394,22 @@ struct Input : torch::CustomClassHolder {
392394
*/
393395
TORCHTRT_API Input(std::vector<int64_t> shape, TensorFormat format = TensorFormat::kContiguous);
394396

397+
/**
398+
* @brief Construct a new Input spec object for static input size from
399+
* c10::ArrayRef (the type produced by tensor.sizes()), vector, optional arguments
400+
* allow the user to configure expected input shape tensor format
401+
* dtype (Expected data type for the input) defaults to PyTorch
402+
* / traditional TRT convection (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8)
403+
*
404+
* @param shape Input tensor shape
405+
* @param tensor_domain Allowed range for tensor inputs [low, high)
406+
* @param format Expected tensor format for the input (Defaults to contiguous)
407+
*/
408+
TORCHTRT_API Input(
409+
std::vector<int64_t> shape,
410+
std::vector<int64_t> tensor_domain,
411+
TensorFormat format = TensorFormat::kContiguous);
412+
395413
/**
396414
* @brief Construct a new Input spec object for static input size from
397415
* vector, optional arguments allow the user to configure expected input shape
@@ -404,6 +422,23 @@ struct Input : torch::CustomClassHolder {
404422
*/
405423
TORCHTRT_API Input(std::vector<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);
406424

425+
/**
426+
* @brief Construct a new Input spec object for static input size from
427+
* vector, optional arguments allow the user to configure expected input shape
428+
* tensor format
429+
*
430+
* @param shape Input tensor shape
431+
* @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor
432+
* calculation if detectable else Float32)
433+
* @param tensor_domain Allowed range for tensor inputs [low, high)
434+
* @param format Expected tensor format for the input (Defaults to contiguous)
435+
*/
436+
TORCHTRT_API Input(
437+
std::vector<int64_t> shape,
438+
DataType dtype,
439+
std::vector<int64_t> tensor_domain,
440+
TensorFormat format = TensorFormat::kContiguous);
441+
407442
/**
408443
* @brief Construct a new Input spec object for static input size from
409444
* c10::ArrayRef (the type produced by tensor.sizes()), vector, optional arguments
@@ -416,6 +451,22 @@ struct Input : torch::CustomClassHolder {
416451
*/
417452
TORCHTRT_API Input(c10::ArrayRef<int64_t> shape, TensorFormat format = TensorFormat::kContiguous);
418453

454+
/**
455+
* @brief Construct a new Input spec object for static input size from
456+
* c10::ArrayRef (the type produced by tensor.sizes()), vector, optional arguments
457+
* allow the user to configure expected input shape tensor format
458+
* dtype (Expected data type for the input) defaults to PyTorch
459+
* / traditional TRT convection (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8)
460+
*
461+
* @param shape Input tensor shape
462+
* @param tensor_domain Allowed range for tensor inputs [low, high)
463+
* @param format Expected tensor format for the input (Defaults to contiguous)
464+
*/
465+
TORCHTRT_API Input(
466+
c10::ArrayRef<int64_t> shape,
467+
std::vector<int64_t> tensor_domain,
468+
TensorFormat format = TensorFormat::kContiguous);
469+
419470
/**
420471
* @brief Construct a new Input spec object for static input size from
421472
* c10::ArrayRef (the type produced by tensor.sizes()), vector, optional arguments
@@ -428,6 +479,23 @@ struct Input : torch::CustomClassHolder {
428479
*/
429480
TORCHTRT_API Input(c10::ArrayRef<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);
430481

482+
/**
483+
* @brief Construct a new Input spec object for static input size from
484+
* c10::ArrayRef (the type produced by tensor.sizes()), vector, optional arguments
485+
* allow the user to configure expected input shape tensor format
486+
*
487+
* @param shape Input tensor shape
488+
* @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor
489+
* calculation if detectable else Float32)
490+
* @param tensor_domain Allowed range for tensor inputs [low, high)
491+
* @param format Expected tensor format for the input (Defaults to contiguous)
492+
*/
493+
TORCHTRT_API Input(
494+
c10::ArrayRef<int64_t> shape,
495+
DataType dtype,
496+
std::vector<int64_t> tensor_domain,
497+
TensorFormat format = TensorFormat::kContiguous);
498+
431499
/**
432500
* @brief Construct a new Input spec object dynamic input size from
433501
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
@@ -444,6 +512,24 @@ struct Input : torch::CustomClassHolder {
444512
std::vector<int64_t> opt_shape,
445513
std::vector<int64_t> max_shape,
446514
TensorFormat format = TensorFormat::kContiguous);
515+
/**
516+
* @brief Construct a new Input spec object dynamic input size from
517+
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
518+
* supported sizes. dtype (Expected data type for the input) defaults to PyTorch
519+
* / traditional TRT convection (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8)
520+
*
521+
* @param min_shape Minimum shape for input tensor
522+
* @param opt_shape Target optimization shape for input tensor
523+
* @param max_shape Maximum acceptible shape for input tensor
524+
* @param tensor_domain Allowed range for tensor inputs [low, high)
525+
* @param format Expected tensor format for the input (Defaults to contiguous)
526+
*/
527+
TORCHTRT_API Input(
528+
std::vector<int64_t> min_shape,
529+
std::vector<int64_t> opt_shape,
530+
std::vector<int64_t> max_shape,
531+
std::vector<int64_t> tensor_domain,
532+
TensorFormat format = TensorFormat::kContiguous);
447533

448534
/**
449535
* @brief Construct a new Input spec object for a dynamic input size from vectors
@@ -464,6 +550,44 @@ struct Input : torch::CustomClassHolder {
464550
DataType dtype,
465551
TensorFormat format = TensorFormat::kContiguous);
466552

553+
/**
554+
* @brief Construct a new Input spec object for a dynamic input size from vectors
555+
* for minimum shape, optimal shape, and max shape supported sizes optional arguments
556+
* allow the user to configure expected input shape tensor format
557+
*
558+
* @param min_shape Minimum shape for input tensor
559+
* @param opt_shape Target optimization shape for input tensor
560+
* @param max_shape Maximum acceptible shape for input tensor
561+
* @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor
562+
* calculation if detectable else Float32)
563+
* @param tensor_domain Allowed range for tensor inputs [low, high)
564+
* @param format Expected tensor format for the input (Defaults to contiguous)
565+
*/
566+
TORCHTRT_API Input(
567+
std::vector<int64_t> min_shape,
568+
std::vector<int64_t> opt_shape,
569+
std::vector<int64_t> max_shape,
570+
DataType dtype,
571+
std::vector<int64_t> tensor_domain,
572+
TensorFormat format = TensorFormat::kContiguous);
573+
574+
/**
575+
* @brief Construct a new Input spec object dynamic input size from
576+
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
577+
* supported sizes. dtype (Expected data type for the input) defaults to PyTorch
578+
* / traditional TRT convection (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8)
579+
*
580+
* @param min_shape Minimum shape for input tensor
581+
* @param opt_shape Target optimization shape for input tensor
582+
* @param max_shape Maximum acceptible shape for input tensor
583+
* @param format Expected tensor format for the input (Defaults to contiguous)
584+
*/
585+
TORCHTRT_API Input(
586+
c10::ArrayRef<int64_t> min_shape,
587+
c10::ArrayRef<int64_t> opt_shape,
588+
c10::ArrayRef<int64_t> max_shape,
589+
TensorFormat format = TensorFormat::kContiguous);
590+
467591
/**
468592
* @brief Construct a new Input spec object dynamic input size from
469593
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
@@ -473,12 +597,33 @@ struct Input : torch::CustomClassHolder {
473597
* @param min_shape Minimum shape for input tensor
474598
* @param opt_shape Target optimization shape for input tensor
475599
* @param max_shape Maximum acceptible shape for input tensor
600+
* @param tensor_domain Allowed range for tensor inputs [low, high)
476601
* @param format Expected tensor format for the input (Defaults to contiguous)
477602
*/
478603
TORCHTRT_API Input(
479604
c10::ArrayRef<int64_t> min_shape,
480605
c10::ArrayRef<int64_t> opt_shape,
481606
c10::ArrayRef<int64_t> max_shape,
607+
std::vector<int64_t> tensor_domain,
608+
TensorFormat format = TensorFormat::kContiguous);
609+
610+
/**
611+
* @brief Construct a new Input spec object dynamic input size from
612+
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
613+
* supported sizes
614+
*
615+
* @param min_shape Minimum shape for input tensor
616+
* @param opt_shape Target optimization shape for input tensor
617+
* @param max_shape Maximum acceptible shape for input tensor
618+
* @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor
619+
* calculation if detectable else Float32)
620+
* @param format Expected tensor format for the input (Defaults to contiguous)
621+
*/
622+
TORCHTRT_API Input(
623+
c10::ArrayRef<int64_t> min_shape,
624+
c10::ArrayRef<int64_t> opt_shape,
625+
c10::ArrayRef<int64_t> max_shape,
626+
DataType dtype,
482627
TensorFormat format = TensorFormat::kContiguous);
483628

484629
/**
@@ -491,13 +636,15 @@ struct Input : torch::CustomClassHolder {
491636
* @param max_shape Maximum acceptible shape for input tensor
492637
* @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor
493638
* calculation if detectable else Float32)
639+
* @param tensor_domain Allowed range for tensor inputs [low, high)
494640
* @param format Expected tensor format for the input (Defaults to contiguous)
495641
*/
496642
TORCHTRT_API Input(
497643
c10::ArrayRef<int64_t> min_shape,
498644
c10::ArrayRef<int64_t> opt_shape,
499645
c10::ArrayRef<int64_t> max_shape,
500646
DataType dtype,
647+
std::vector<int64_t> tensor_domain,
501648
TensorFormat format = TensorFormat::kContiguous);
502649

503650
/**

0 commit comments

Comments
 (0)