Skip to content

Commit 3d59933

Browse files
authored
feat: Add optional tensor domain argument to Input class (#1537)
1 parent 6ce3a44 commit 3d59933

14 files changed

+486
-14
lines changed

core/ir/Input.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,16 @@ bool valid_input_dtype(nvinfer1::DataType dtype) {
6969
}
7070
}
7171

72+
// Returns true only when `domain` holds exactly two values and the first is
// strictly less than the second, i.e. it describes a non-empty [low, high) range.
bool valid_input_domain(std::vector<double> domain) {
73+
return (domain.size() == 2) && (domain[0] < domain[1]);
74+
}
75+
7276
Input::Input(
7377
std::vector<int64_t> shape,
7478
at::ScalarType dtype,
7579
nvinfer1::TensorFormat format,
76-
bool dtype_is_user_defined) {
80+
bool dtype_is_user_defined,
81+
std::vector<double> tensor_domain) {
7782
if (shape.size() > 5) {
7883
LOG_WARNING("Verify that this dim size is accepted");
7984
}
@@ -93,6 +98,11 @@ Input::Input(
9398
<< "), Torch-TensorRT only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
9499
this->format = format;
95100
this->dtype_is_user_defined = dtype_is_user_defined;
101+
102+
TORCHTRT_CHECK(
103+
valid_input_domain(tensor_domain),
104+
"Unsupported tensor domain: [" << tensor_domain[0] << ", " << tensor_domain[1] << ")");
105+
this->tensor_domain = tensor_domain;
96106
}
97107

98108
Input::Input(
@@ -101,7 +111,8 @@ Input::Input(
101111
std::vector<int64_t> max_shape,
102112
at::ScalarType dtype,
103113
nvinfer1::TensorFormat format,
104-
bool dtype_is_user_defined) {
114+
bool dtype_is_user_defined,
115+
std::vector<double> tensor_domain) {
105116
if (min_shape.size() > 5 || opt_shape.size() > 5 || max_shape.size() > 5) {
106117
LOG_WARNING("Verify that this dim size is accepted");
107118
}
@@ -146,6 +157,10 @@ Input::Input(
146157
<< "), Torch-TensorRT only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
147158
this->format = format;
148159
this->dtype_is_user_defined = dtype_is_user_defined;
160+
TORCHTRT_CHECK(
161+
valid_input_domain(tensor_domain),
162+
"Unsupported tensor domain: [" << tensor_domain[0] << ", " << tensor_domain[1] << ")");
163+
this->tensor_domain = tensor_domain;
149164
}
150165

151166
std::ostream& operator<<(std::ostream& os, const Input& input) {

core/ir/ir.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,19 +31,22 @@ struct Input : torch::CustomClassHolder {
3131
std::vector<int64_t> shape,
3232
at::ScalarType dtype = at::kFloat,
3333
nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR,
34-
bool dtype_is_user_defined = false);
34+
bool dtype_is_user_defined = false,
35+
std::vector<double> tensor_domain = std::vector<double>{0, 2});
3536
Input(
3637
std::vector<int64_t> min_shape,
3738
std::vector<int64_t> opt_shape,
3839
std::vector<int64_t> max_shape,
3940
at::ScalarType dtype = at::kFloat,
4041
nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR,
41-
bool dtype_is_used_defined = false);
42+
bool dtype_is_user_defined = false,
43+
std::vector<double> tensor_domain = std::vector<double>{0, 2});
4244

4345
friend std::ostream& operator<<(std::ostream& os, const Input& input);
4446

4547
bool input_is_dynamic = false;
4648
bool dtype_is_user_defined = false;
49+
std::vector<double> tensor_domain;
4750
nvinfer1::Dims input_shape;
4851
nvinfer1::Dims min;
4952
nvinfer1::Dims max;

core/partitioning/shape_analysis.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ at::Tensor generateSingleInput(
2626
}
2727

2828
// Initialize min and max ranges for random number selection
29-
int LoValIncl = 0;
30-
int HiValExcl = 2;
29+
double LoValIncl = input.tensor_domain[0];
30+
double HiValExcl = input.tensor_domain[1];
3131

3232
auto type = at::kFloat;
3333
if (type_opt) {
@@ -36,6 +36,10 @@ at::Tensor generateSingleInput(
3636
LOG_WARNING("Input type for doing shape analysis could not be determined, defaulting to F32");
3737
}
3838

39+
LOG_DEBUG(
40+
"Using the Range: [" << LoValIncl << ", " << HiValExcl
41+
<< ") as a random range for shape analysis on input with data type " << type);
42+
3943
// Make the value range for input tensor a uniform (float) distribution
4044
// over [LoValIncl, HiValExcl), then cast to the desired dtype
4145
auto in = ((HiValExcl - LoValIncl) * at::rand(util::toVec(input_shape), {at::kCUDA}) + LoValIncl).to(type);

cpp/include/torch_tensorrt/torch_tensorrt.h

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,8 @@ struct Input : torch::CustomClassHolder {
381381
DataType dtype;
382382
/// Expected tensor format for the input
383383
TensorFormat format;
384+
/// Expected allowed domain for tensor input
385+
std::vector<double> tensor_domain;
384386

385387
Input() {}
386388
/**
@@ -394,6 +396,22 @@ struct Input : torch::CustomClassHolder {
394396
*/
395397
TORCHTRT_API Input(std::vector<int64_t> shape, TensorFormat format = TensorFormat::kContiguous);
396398

399+
/**
400+
* @brief Construct a new Input spec object for static input size from
401+
* vector, optional arguments
402+
* allow the user to configure expected input shape tensor format
403+
* dtype (Expected data type for the input) defaults to PyTorch
404+
* / traditional TRT convention (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8)
405+
*
406+
* @param shape Input tensor shape
407+
* @param tensor_domain Allowed range for tensor inputs [low, high)
408+
* @param format Expected tensor format for the input (Defaults to contiguous)
409+
*/
410+
TORCHTRT_API Input(
411+
std::vector<int64_t> shape,
412+
std::vector<double> tensor_domain,
413+
TensorFormat format = TensorFormat::kContiguous);
414+
397415
/**
398416
* @brief Construct a new Input spec object for static input size from
399417
* vector, optional arguments allow the user to configure expected input shape
@@ -406,6 +424,23 @@ struct Input : torch::CustomClassHolder {
406424
*/
407425
TORCHTRT_API Input(std::vector<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);
408426

427+
/**
428+
* @brief Construct a new Input spec object for static input size from
429+
* vector, optional arguments allow the user to configure expected input shape
430+
* tensor format
431+
*
432+
* @param shape Input tensor shape
433+
* @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor
434+
* calculation if detectable else Float32)
435+
* @param tensor_domain Allowed range for tensor inputs [low, high)
436+
* @param format Expected tensor format for the input (Defaults to contiguous)
437+
*/
438+
TORCHTRT_API Input(
439+
std::vector<int64_t> shape,
440+
DataType dtype,
441+
std::vector<double> tensor_domain,
442+
TensorFormat format = TensorFormat::kContiguous);
443+
409444
/**
410445
* @brief Construct a new Input spec object for static input size from
411446
* c10::ArrayRef (the type produced by tensor.sizes()), vector, optional arguments
@@ -418,6 +453,22 @@ struct Input : torch::CustomClassHolder {
418453
*/
419454
TORCHTRT_API Input(c10::ArrayRef<int64_t> shape, TensorFormat format = TensorFormat::kContiguous);
420455

456+
/**
457+
* @brief Construct a new Input spec object for static input size from
458+
* c10::ArrayRef (the type produced by tensor.sizes()), vector, optional arguments
459+
* allow the user to configure expected input shape tensor format
460+
* dtype (Expected data type for the input) defaults to PyTorch
461+
* / traditional TRT convention (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8)
462+
*
463+
* @param shape Input tensor shape
464+
* @param tensor_domain Allowed range for tensor inputs [low, high)
465+
* @param format Expected tensor format for the input (Defaults to contiguous)
466+
*/
467+
TORCHTRT_API Input(
468+
c10::ArrayRef<int64_t> shape,
469+
std::vector<double> tensor_domain,
470+
TensorFormat format = TensorFormat::kContiguous);
471+
421472
/**
422473
* @brief Construct a new Input spec object for static input size from
423474
* c10::ArrayRef (the type produced by tensor.sizes()), vector, optional arguments
@@ -430,6 +481,23 @@ struct Input : torch::CustomClassHolder {
430481
*/
431482
TORCHTRT_API Input(c10::ArrayRef<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);
432483

484+
/**
485+
* @brief Construct a new Input spec object for static input size from
486+
* c10::ArrayRef (the type produced by tensor.sizes()), vector, optional arguments
487+
* allow the user to configure expected input shape tensor format
488+
*
489+
* @param shape Input tensor shape
490+
* @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor
491+
* calculation if detectable else Float32)
492+
* @param tensor_domain Allowed range for tensor inputs [low, high)
493+
* @param format Expected tensor format for the input (Defaults to contiguous)
494+
*/
495+
TORCHTRT_API Input(
496+
c10::ArrayRef<int64_t> shape,
497+
DataType dtype,
498+
std::vector<double> tensor_domain,
499+
TensorFormat format = TensorFormat::kContiguous);
500+
433501
/**
434502
* @brief Construct a new Input spec object dynamic input size from
435503
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
@@ -446,6 +514,24 @@ struct Input : torch::CustomClassHolder {
446514
std::vector<int64_t> opt_shape,
447515
std::vector<int64_t> max_shape,
448516
TensorFormat format = TensorFormat::kContiguous);
517+
/**
518+
* @brief Construct a new Input spec object dynamic input size from
519+
* vectors for min, opt, and max
520+
* supported sizes. dtype (Expected data type for the input) defaults to PyTorch
521+
* / traditional TRT convention (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8)
522+
*
523+
* @param min_shape Minimum shape for input tensor
524+
* @param opt_shape Target optimization shape for input tensor
525+
* @param max_shape Maximum acceptable shape for input tensor
526+
* @param tensor_domain Allowed range for tensor inputs [low, high)
527+
* @param format Expected tensor format for the input (Defaults to contiguous)
528+
*/
529+
TORCHTRT_API Input(
530+
std::vector<int64_t> min_shape,
531+
std::vector<int64_t> opt_shape,
532+
std::vector<int64_t> max_shape,
533+
std::vector<double> tensor_domain,
534+
TensorFormat format = TensorFormat::kContiguous);
449535

450536
/**
451537
* @brief Construct a new Input spec object for a dynamic input size from vectors
@@ -466,6 +552,44 @@ struct Input : torch::CustomClassHolder {
466552
DataType dtype,
467553
TensorFormat format = TensorFormat::kContiguous);
468554

555+
/**
556+
* @brief Construct a new Input spec object for a dynamic input size from vectors
557+
* for minimum shape, optimal shape, and max shape supported sizes optional arguments
558+
* allow the user to configure expected input shape tensor format
559+
*
560+
* @param min_shape Minimum shape for input tensor
561+
* @param opt_shape Target optimization shape for input tensor
562+
* @param max_shape Maximum acceptable shape for input tensor
563+
* @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor
564+
* calculation if detectable else Float32)
565+
* @param tensor_domain Allowed range for tensor inputs [low, high)
566+
* @param format Expected tensor format for the input (Defaults to contiguous)
567+
*/
568+
TORCHTRT_API Input(
569+
std::vector<int64_t> min_shape,
570+
std::vector<int64_t> opt_shape,
571+
std::vector<int64_t> max_shape,
572+
DataType dtype,
573+
std::vector<double> tensor_domain,
574+
TensorFormat format = TensorFormat::kContiguous);
575+
576+
/**
577+
* @brief Construct a new Input spec object dynamic input size from
578+
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
579+
* supported sizes. dtype (Expected data type for the input) defaults to PyTorch
580+
* / traditional TRT convention (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8)
581+
*
582+
* @param min_shape Minimum shape for input tensor
583+
* @param opt_shape Target optimization shape for input tensor
584+
* @param max_shape Maximum acceptable shape for input tensor
585+
* @param format Expected tensor format for the input (Defaults to contiguous)
586+
*/
587+
TORCHTRT_API Input(
588+
c10::ArrayRef<int64_t> min_shape,
589+
c10::ArrayRef<int64_t> opt_shape,
590+
c10::ArrayRef<int64_t> max_shape,
591+
TensorFormat format = TensorFormat::kContiguous);
592+
469593
/**
470594
* @brief Construct a new Input spec object dynamic input size from
471595
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
@@ -475,12 +599,33 @@ struct Input : torch::CustomClassHolder {
475599
* @param min_shape Minimum shape for input tensor
476600
* @param opt_shape Target optimization shape for input tensor
477601
* @param max_shape Maximum acceptable shape for input tensor
602+
* @param tensor_domain Allowed range for tensor inputs [low, high)
478603
* @param format Expected tensor format for the input (Defaults to contiguous)
479604
*/
480605
TORCHTRT_API Input(
481606
c10::ArrayRef<int64_t> min_shape,
482607
c10::ArrayRef<int64_t> opt_shape,
483608
c10::ArrayRef<int64_t> max_shape,
609+
std::vector<double> tensor_domain,
610+
TensorFormat format = TensorFormat::kContiguous);
611+
612+
/**
613+
* @brief Construct a new Input spec object dynamic input size from
614+
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
615+
* supported sizes
616+
*
617+
* @param min_shape Minimum shape for input tensor
618+
* @param opt_shape Target optimization shape for input tensor
619+
* @param max_shape Maximum acceptable shape for input tensor
620+
* @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor
621+
* calculation if detectable else Float32)
622+
* @param format Expected tensor format for the input (Defaults to contiguous)
623+
*/
624+
TORCHTRT_API Input(
625+
c10::ArrayRef<int64_t> min_shape,
626+
c10::ArrayRef<int64_t> opt_shape,
627+
c10::ArrayRef<int64_t> max_shape,
628+
DataType dtype,
484629
TensorFormat format = TensorFormat::kContiguous);
485630

486631
/**
@@ -493,13 +638,15 @@ struct Input : torch::CustomClassHolder {
493638
* @param max_shape Maximum acceptable shape for input tensor
494639
* @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor
495640
* calculation if detectable else Float32)
641+
* @param tensor_domain Allowed range for tensor inputs [low, high)
496642
* @param format Expected tensor format for the input (Defaults to contiguous)
497643
*/
498644
TORCHTRT_API Input(
499645
c10::ArrayRef<int64_t> min_shape,
500646
c10::ArrayRef<int64_t> opt_shape,
501647
c10::ArrayRef<int64_t> max_shape,
502648
DataType dtype,
649+
std::vector<double> tensor_domain,
503650
TensorFormat format = TensorFormat::kContiguous);
504651

505652
/**

0 commit comments

Comments
 (0)