Skip to content

Commit ce84566

Browse files
committed
fix: Refactor domain to take float Inputs
- Update Python API to accept a full domain definition or None, without intermediates - Update cpp and Python APIs to include float support - Update data structure for tensor domain object to _Float64 - Add tests to ensure tensor domain usage in shape analysis is covered via test cases
1 parent 4cea990 commit ce84566

File tree

11 files changed

+90
-92
lines changed

11 files changed

+90
-92
lines changed

core/ir/Input.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ bool valid_input_dtype(nvinfer1::DataType dtype) {
6969
}
7070
}
7171

72-
bool valid_input_domain(std::vector<int64_t> domain) {
72+
bool valid_input_domain(std::vector<double> domain) {
7373
return (domain.size() == 2) && (domain[0] < domain[1]);
7474
}
7575

@@ -78,7 +78,7 @@ Input::Input(
7878
at::ScalarType dtype,
7979
nvinfer1::TensorFormat format,
8080
bool dtype_is_user_defined,
81-
std::vector<int64_t> tensor_domain) {
81+
std::vector<double> tensor_domain) {
8282
if (shape.size() > 5) {
8383
LOG_WARNING("Verify that this dim size is accepted");
8484
}
@@ -112,7 +112,7 @@ Input::Input(
112112
at::ScalarType dtype,
113113
nvinfer1::TensorFormat format,
114114
bool dtype_is_user_defined,
115-
std::vector<int64_t> tensor_domain) {
115+
std::vector<double> tensor_domain) {
116116
if (min_shape.size() > 5 || opt_shape.size() > 5 || max_shape.size() > 5) {
117117
LOG_WARNING("Verify that this dim size is accepted");
118118
}

core/ir/ir.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,21 +32,21 @@ struct Input : torch::CustomClassHolder {
3232
at::ScalarType dtype = at::kFloat,
3333
nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR,
3434
bool dtype_is_user_defined = false,
35-
std::vector<int64_t> tensor_domain = std::vector<int64_t>{0, 2});
35+
std::vector<double> tensor_domain = std::vector<double>{0, 2});
3636
Input(
3737
std::vector<int64_t> min_shape,
3838
std::vector<int64_t> opt_shape,
3939
std::vector<int64_t> max_shape,
4040
at::ScalarType dtype = at::kFloat,
4141
nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR,
4242
bool dtype_is_user_defined = false,
43-
std::vector<int64_t> tensor_domain = std::vector<int64_t>{0, 2});
43+
std::vector<double> tensor_domain = std::vector<double>{0, 2});
4444

4545
friend std::ostream& operator<<(std::ostream& os, const Input& input);
4646

4747
bool input_is_dynamic = false;
4848
bool dtype_is_user_defined = false;
49-
std::vector<int64_t> tensor_domain;
49+
std::vector<double> tensor_domain;
5050
nvinfer1::Dims input_shape;
5151
nvinfer1::Dims min;
5252
nvinfer1::Dims max;

core/partitioning/shape_analysis.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ at::Tensor generateSingleInput(
2626
}
2727

2828
// Initialize min and max ranges for random number selection
29-
int LoValIncl = input.tensor_domain[0];
30-
int HiValExcl = input.tensor_domain[1];
29+
double LoValIncl = input.tensor_domain[0];
30+
double HiValExcl = input.tensor_domain[1];
3131

3232
auto type = at::kFloat;
3333
if (type_opt) {

cpp/include/torch_tensorrt/torch_tensorrt.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ struct Input : torch::CustomClassHolder {
382382
/// Expected tensor format for the input
383383
TensorFormat format;
384384
/// Expected allowed domain for tensor input
385-
std::vector<int64_t> tensor_domain;
385+
std::vector<double> tensor_domain;
386386

387387
Input() {}
388388
/**
@@ -409,7 +409,7 @@ struct Input : torch::CustomClassHolder {
409409
*/
410410
TORCHTRT_API Input(
411411
std::vector<int64_t> shape,
412-
std::vector<int64_t> tensor_domain,
412+
std::vector<double> tensor_domain,
413413
TensorFormat format = TensorFormat::kContiguous);
414414

415415
/**
@@ -438,7 +438,7 @@ struct Input : torch::CustomClassHolder {
438438
TORCHTRT_API Input(
439439
std::vector<int64_t> shape,
440440
DataType dtype,
441-
std::vector<int64_t> tensor_domain,
441+
std::vector<double> tensor_domain,
442442
TensorFormat format = TensorFormat::kContiguous);
443443

444444
/**
@@ -466,7 +466,7 @@ struct Input : torch::CustomClassHolder {
466466
*/
467467
TORCHTRT_API Input(
468468
c10::ArrayRef<int64_t> shape,
469-
std::vector<int64_t> tensor_domain,
469+
std::vector<double> tensor_domain,
470470
TensorFormat format = TensorFormat::kContiguous);
471471

472472
/**
@@ -495,7 +495,7 @@ struct Input : torch::CustomClassHolder {
495495
TORCHTRT_API Input(
496496
c10::ArrayRef<int64_t> shape,
497497
DataType dtype,
498-
std::vector<int64_t> tensor_domain,
498+
std::vector<double> tensor_domain,
499499
TensorFormat format = TensorFormat::kContiguous);
500500

501501
/**
@@ -530,7 +530,7 @@ struct Input : torch::CustomClassHolder {
530530
std::vector<int64_t> min_shape,
531531
std::vector<int64_t> opt_shape,
532532
std::vector<int64_t> max_shape,
533-
std::vector<int64_t> tensor_domain,
533+
std::vector<double> tensor_domain,
534534
TensorFormat format = TensorFormat::kContiguous);
535535

536536
/**
@@ -570,7 +570,7 @@ struct Input : torch::CustomClassHolder {
570570
std::vector<int64_t> opt_shape,
571571
std::vector<int64_t> max_shape,
572572
DataType dtype,
573-
std::vector<int64_t> tensor_domain,
573+
std::vector<double> tensor_domain,
574574
TensorFormat format = TensorFormat::kContiguous);
575575

576576
/**
@@ -606,7 +606,7 @@ struct Input : torch::CustomClassHolder {
606606
c10::ArrayRef<int64_t> min_shape,
607607
c10::ArrayRef<int64_t> opt_shape,
608608
c10::ArrayRef<int64_t> max_shape,
609-
std::vector<int64_t> tensor_domain,
609+
std::vector<double> tensor_domain,
610610
TensorFormat format = TensorFormat::kContiguous);
611611

612612
/**
@@ -646,7 +646,7 @@ struct Input : torch::CustomClassHolder {
646646
c10::ArrayRef<int64_t> opt_shape,
647647
c10::ArrayRef<int64_t> max_shape,
648648
DataType dtype,
649-
std::vector<int64_t> tensor_domain,
649+
std::vector<double> tensor_domain,
650650
TensorFormat format = TensorFormat::kContiguous);
651651

652652
/**

cpp/src/types.cpp

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,10 @@ Input::Input(std::vector<int64_t> shape, TensorFormat format) {
173173
this->dtype = DataType::kUnknown;
174174
this->format = format;
175175
this->input_is_dynamic = false;
176-
this->tensor_domain = std::vector<int64_t>{0, 2};
176+
this->tensor_domain = std::vector<double>{0, 2};
177177
}
178178

179-
Input::Input(std::vector<int64_t> shape, std::vector<int64_t> tensor_domain, TensorFormat format) {
179+
Input::Input(std::vector<int64_t> shape, std::vector<double> tensor_domain, TensorFormat format) {
180180
this->opt_shape = shape;
181181
this->min_shape = shape;
182182
this->max_shape = shape;
@@ -195,10 +195,10 @@ Input::Input(std::vector<int64_t> shape, DataType dtype, TensorFormat format) {
195195
this->dtype = dtype;
196196
this->format = format;
197197
this->input_is_dynamic = false;
198-
this->tensor_domain = std::vector<int64_t>{0, 2};
198+
this->tensor_domain = std::vector<double>{0, 2};
199199
}
200200

201-
Input::Input(std::vector<int64_t> shape, DataType dtype, std::vector<int64_t> tensor_domain, TensorFormat format) {
201+
Input::Input(std::vector<int64_t> shape, DataType dtype, std::vector<double> tensor_domain, TensorFormat format) {
202202
this->opt_shape = shape;
203203
this->min_shape = shape;
204204
this->max_shape = shape;
@@ -217,10 +217,10 @@ Input::Input(c10::IntArrayRef shape, TensorFormat format) {
217217
this->dtype = DataType::kUnknown;
218218
this->format = format;
219219
this->input_is_dynamic = false;
220-
this->tensor_domain = std::vector<int64_t>{0, 2};
220+
this->tensor_domain = std::vector<double>{0, 2};
221221
}
222222

223-
Input::Input(c10::IntArrayRef shape, std::vector<int64_t> tensor_domain, TensorFormat format) {
223+
Input::Input(c10::IntArrayRef shape, std::vector<double> tensor_domain, TensorFormat format) {
224224
this->opt_shape = torch_tensorrt::core::util::toVec(shape);
225225
this->min_shape = torch_tensorrt::core::util::toVec(shape);
226226
this->max_shape = torch_tensorrt::core::util::toVec(shape);
@@ -239,10 +239,10 @@ Input::Input(c10::IntArrayRef shape, DataType dtype, TensorFormat format) {
239239
this->dtype = dtype;
240240
this->format = format;
241241
this->input_is_dynamic = false;
242-
this->tensor_domain = std::vector<int64_t>{0, 2};
242+
this->tensor_domain = std::vector<double>{0, 2};
243243
}
244244

245-
Input::Input(c10::IntArrayRef shape, DataType dtype, std::vector<int64_t> tensor_domain, TensorFormat format) {
245+
Input::Input(c10::IntArrayRef shape, DataType dtype, std::vector<double> tensor_domain, TensorFormat format) {
246246
this->opt_shape = torch_tensorrt::core::util::toVec(shape);
247247
this->min_shape = torch_tensorrt::core::util::toVec(shape);
248248
this->max_shape = torch_tensorrt::core::util::toVec(shape);
@@ -266,14 +266,14 @@ Input::Input(
266266
this->dtype = DataType::kUnknown;
267267
this->format = format;
268268
this->input_is_dynamic = true;
269-
this->tensor_domain = std::vector<int64_t>{0, 2};
269+
this->tensor_domain = std::vector<double>{0, 2};
270270
}
271271

272272
Input::Input(
273273
std::vector<int64_t> min_shape,
274274
std::vector<int64_t> opt_shape,
275275
std::vector<int64_t> max_shape,
276-
std::vector<int64_t> tensor_domain,
276+
std::vector<double> tensor_domain,
277277
TensorFormat format) {
278278
this->opt_shape = opt_shape;
279279
this->min_shape = min_shape;
@@ -300,15 +300,15 @@ Input::Input(
300300
this->dtype = dtype;
301301
this->format = format;
302302
this->input_is_dynamic = true;
303-
this->tensor_domain = std::vector<int64_t>{0, 2};
303+
this->tensor_domain = std::vector<double>{0, 2};
304304
}
305305

306306
Input::Input(
307307
std::vector<int64_t> min_shape,
308308
std::vector<int64_t> opt_shape,
309309
std::vector<int64_t> max_shape,
310310
DataType dtype,
311-
std::vector<int64_t> tensor_domain,
311+
std::vector<double> tensor_domain,
312312
TensorFormat format) {
313313
this->opt_shape = opt_shape;
314314
this->min_shape = min_shape;
@@ -330,14 +330,14 @@ Input::Input(c10::IntArrayRef min_shape, c10::IntArrayRef opt_shape, c10::IntArr
330330
this->dtype = DataType::kUnknown;
331331
this->format = format;
332332
this->input_is_dynamic = true;
333-
this->tensor_domain = std::vector<int64_t>{0, 2};
333+
this->tensor_domain = std::vector<double>{0, 2};
334334
}
335335

336336
Input::Input(
337337
c10::IntArrayRef min_shape,
338338
c10::IntArrayRef opt_shape,
339339
c10::IntArrayRef max_shape,
340-
std::vector<int64_t> tensor_domain,
340+
std::vector<double> tensor_domain,
341341
TensorFormat format) {
342342
this->opt_shape = torch_tensorrt::core::util::toVec(opt_shape);
343343
this->min_shape = torch_tensorrt::core::util::toVec(min_shape);
@@ -364,15 +364,15 @@ Input::Input(
364364
this->dtype = dtype;
365365
this->format = format;
366366
this->input_is_dynamic = true;
367-
this->tensor_domain = std::vector<int64_t>{0, 2};
367+
this->tensor_domain = std::vector<double>{0, 2};
368368
}
369369

370370
Input::Input(
371371
c10::IntArrayRef min_shape,
372372
c10::IntArrayRef opt_shape,
373373
c10::IntArrayRef max_shape,
374374
DataType dtype,
375-
std::vector<int64_t> tensor_domain,
375+
std::vector<double> tensor_domain,
376376
TensorFormat format) {
377377
this->opt_shape = torch_tensorrt::core::util::toVec(opt_shape);
378378
this->min_shape = torch_tensorrt::core::util::toVec(min_shape);
@@ -402,7 +402,7 @@ Input::Input(at::Tensor tensor) {
402402
}
403403
this->format = frmt;
404404
this->input_is_dynamic = false;
405-
this->tensor_domain = std::vector<int64_t>{0, 2};
405+
this->tensor_domain = std::vector<double>{0, 2};
406406
}
407407

408408
/* ==========================================*/

py/torch_tensorrt/_Input.py

Lines changed: 24 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from enum import Enum
2-
from typing import List, Dict, Any, Tuple
2+
from typing import List, Dict, Any, Tuple, Optional
33

44
import torch
55

@@ -38,8 +38,8 @@ class _ShapeMode(Enum):
3838
_enums.TensorFormat.contiguous
3939
) #: The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
4040

41-
DOMAIN_OFFSET = 2
42-
low_tensor_domain_incl = 0
41+
DOMAIN_OFFSET = 2.0
42+
low_tensor_domain_incl = 0.0
4343
high_tensor_domain_excl = low_tensor_domain_incl + DOMAIN_OFFSET
4444

4545
def __init__(self, *args, **kwargs):
@@ -60,8 +60,8 @@ def __init__(self, *args, **kwargs):
6060
Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implictly this sets Input's shape_mode to DYNAMIC
6161
dtype (torch.dtype or torch_tensorrt.dtype): Expected data type for input tensor (default: torch_tensorrt.dtype.float32)
6262
format (torch.memory_format or torch_tensorrt.TensorFormat): The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
63-
tensor_domain (Tuple(int, int), optional): The domain of allowed integer values for the tensor, as interval notation: [tensor_domain[0], tensor_domain[1]).
64-
Note: Entering one NoneType will set the other bound to the provided value +/- 2; entering two NoneTypes (or not specifying) will set the bound to [0, 2)
63+
tensor_domain (Tuple(float, float), optional): The domain of allowed values for the tensor, as interval notation: [tensor_domain[0], tensor_domain[1]).
64+
Note: Entering "None" (or not specifying) will set the bound to [0, 2)
6565
6666
Examples:
6767
- Input([1,3,32,32], dtype=torch.float32, format=torch.channel_last)
@@ -289,7 +289,7 @@ def _parse_format(format: Any) -> _enums.TensorFormat:
289289
)
290290

291291
@staticmethod
292-
def _parse_tensor_domain(domain: Tuple[int, int]) -> Tuple:
292+
def _parse_tensor_domain(domain: Optional[Tuple[float, float]]) -> Tuple:
293293
"""
294294
Produce a tuple of integers which specifies a tensor domain in the interval format: [lo, hi)
295295
@@ -300,63 +300,35 @@ def _parse_tensor_domain(domain: Tuple[int, int]) -> Tuple:
300300
A tuple of two int32_t-valid integers
301301
"""
302302
if domain is None:
303-
domain_lo = None
304-
domain_hi = None
303+
result_domain = (
304+
Input.low_tensor_domain_incl,
305+
Input.high_tensor_domain_excl,
306+
)
305307
elif len(domain) == 2:
306308
domain_lo, domain_hi = domain
307-
else:
308-
raise ValueError(
309-
f"Expected 2 values for domain, got {len(domain)}: {domain}"
310-
)
311309

312-
lo_domain_missing = domain_lo is None
313-
hi_domain_missing = domain_hi is None
314-
valid_type_lo = (domain_lo is None) or isinstance(domain_lo, int)
315-
valid_type_hi = (domain_hi is None) or isinstance(domain_hi, int)
310+
# Validate type and provided values for domain
311+
valid_type_lo = isinstance(domain_lo, int) or isinstance(domain_lo, float)
312+
valid_type_hi = isinstance(domain_hi, int) or isinstance(domain_hi, float)
316313

317-
if not valid_type_lo:
318-
raise ValueError(
319-
f"Expected integer value for tensor domain low specifier, got {domain_lo}"
320-
)
321-
elif not valid_type_hi:
322-
raise ValueError(
323-
f"Expected integer value for tensor domain high specifier, got {domain_hi}"
324-
)
314+
if not valid_type_lo:
315+
raise ValueError(
316+
f"Expected value for tensor domain low specifier, got {domain_lo}"
317+
)
318+
elif not valid_type_hi:
319+
raise ValueError(
320+
f"Expected value for tensor domain high specifier, got {domain_hi}"
321+
)
325322

326-
if lo_domain_missing and hi_domain_missing:
327-
result_domain = (
328-
Input.low_tensor_domain_incl,
329-
Input.high_tensor_domain_excl,
330-
)
331-
elif lo_domain_missing:
332-
result_domain = (domain_hi - Input.DOMAIN_OFFSET, domain_hi)
333-
elif hi_domain_missing:
334-
result_domain = (domain_lo, domain_lo + Input.DOMAIN_OFFSET)
335-
else:
336323
if domain_hi <= domain_lo:
337324
raise ValueError(
338325
"Expected provided integer range to have low tensor domain value "
339326
+ f"< high tensor domain value, got invalid range [{domain_lo}, {domain_hi})"
340327
)
341-
result_domain = (domain_lo, domain_hi)
342-
343-
def val_exceeds_int32_t_repr(value: int):
344-
"""Returns true if the input integer value exceeds C++ int32_t repr, else false"""
345-
if not isinstance(value, int):
346-
raise ValueError(f"Expected int input, got {type(value)}")
347-
return (value == 0x80000000) or (abs(value) > 0x80000000)
348-
349-
if isinstance(result_domain[0], int) and val_exceeds_int32_t_repr(
350-
result_domain[0]
351-
):
352-
raise ValueError(
353-
f"Determined low-bound on tensor domain does not fit in a 32-bit signed int"
354-
)
355-
elif isinstance(result_domain[1], int) and val_exceeds_int32_t_repr(
356-
result_domain[1]
357-
):
328+
result_domain = (float(domain_lo), float(domain_hi))
329+
else:
358330
raise ValueError(
359-
f"Determined high-bound on tensor domain does not fit in a 32-bit signed int"
331+
f"Expected 2 values for domain, got {len(domain)}: {domain}"
360332
)
361333

362334
return result_domain

0 commit comments

Comments
 (0)