Skip to content

Commit 621bc67

Browse files
committed
refactor!: Removing deprecated InputRange, op_precision and input_shapes
APIs BREAKING CHANGE: This removes the InputRange Class and op_precision and input shape fields which were deprecated in TRTorch v0.4.0 Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 15e6863 commit 621bc67

File tree

5 files changed

+4
-192
lines changed

5 files changed

+4
-192
lines changed

Diff for: cpp/include/trtorch/trtorch.h

-88
Original file line numberDiff line numberDiff line change
@@ -506,64 +506,6 @@ struct TRTORCH_API CompileSpec {
506506
bool explicit_set_dtype;
507507
};
508508

509-
/**
510-
* @brief A struct to hold an input range (used by TensorRT Optimization
511-
* profile)
512-
*
513-
* This struct can either hold a single vector representing an input shape,
514-
* signifying a static input shape or a set of three input shapes representing
515-
* the min, optiminal and max input shapes allowed for the engine.
516-
*/
517-
struct TRTORCH_API InputRange {
518-
/// Minimum acceptable input size into the engine
519-
std::vector<int64_t> min;
520-
/// Optimal input size into the engine (gets best performace)
521-
std::vector<int64_t> opt;
522-
/// Maximum acceptable input size into the engine
523-
std::vector<int64_t> max;
524-
/**
525-
* @brief Construct a new Input Range object for static input size from
526-
* vector
527-
*
528-
* @param opt
529-
*/
530-
[[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] InputRange(
531-
std::vector<int64_t> opt);
532-
/**
533-
* @brief Construct a new Input Range object static input size from
534-
* c10::ArrayRef (the type produced by tensor.sizes())
535-
*
536-
* @param opt
537-
*/
538-
[[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] InputRange(
539-
c10::ArrayRef<int64_t> opt);
540-
/**
541-
* @brief Construct a new Input Range object dynamic input size from vectors
542-
* for min, opt, and max supported sizes
543-
*
544-
* @param min
545-
* @param opt
546-
* @param max
547-
*/
548-
[[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] InputRange(
549-
std::vector<int64_t> min,
550-
std::vector<int64_t> opt,
551-
std::vector<int64_t> max);
552-
/**
553-
* @brief Construct a new Input Range object dynamic input size from
554-
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
555-
* supported sizes
556-
*
557-
* @param min
558-
* @param opt
559-
* @param max
560-
*/
561-
[[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] InputRange(
562-
c10::ArrayRef<int64_t> min,
563-
c10::ArrayRef<int64_t> opt,
564-
c10::ArrayRef<int64_t> max);
565-
};
566-
567509
/**
568510
* @brief A struct to hold fallback info
569511
*/
@@ -596,18 +538,6 @@ struct TRTORCH_API CompileSpec {
596538
TorchFallback(bool enabled, uint64_t min_size) : enabled(enabled), min_block_size(min_size) {}
597539
};
598540

599-
/**
600-
* @brief Construct a new Extra Info object from input ranges.
601-
* Each entry in the vector represents a input and should be provided in call
602-
* order.
603-
*
604-
* Use this constructor if you want to use dynamic shape
605-
*
606-
* @param input_ranges
607-
*/
608-
[[deprecated("trtorch::CompileSpec::CompileSpec(std::vector<InputRange> input_ranges) is being deprecated in favor of trtorch::CompileSpec::CompileSpec(std::vector<Input> inputs). Please use CompileSpec(std::vector<Input> inputs). trtorch::CompileSpec::CompileSpec(std::vector<InputRange> input_ranges) will be removed in TRTorch v0.5.0")]] CompileSpec(
609-
std::vector<InputRange> input_ranges)
610-
: input_ranges(std::move(input_ranges)) {}
611541
/**
612542
* @brief Construct a new Extra Info object
613543
* Convienence constructor to set fixed input size from vectors describing
@@ -657,24 +587,6 @@ struct TRTORCH_API CompileSpec {
657587
*/
658588
std::vector<Input> inputs;
659589

660-
/**
661-
* Sizes for inputs to engine, can either be a single size or a range
662-
* defined by Min, Optimal, Max sizes
663-
*
664-
* Order is should match call order
665-
*/
666-
[[deprecated(
667-
"trtorch::CompileSpec::input_ranges is being deprecated in favor of trtorch::CompileSpec::inputs. trtorch::CompileSpec::input_ranges will be removed in TRTorch v0.5.0")]] std::
668-
vector<InputRange>
669-
input_ranges;
670-
671-
/**
672-
* Default operating precision for the engine
673-
*/
674-
[[deprecated(
675-
"trtorch::CompileSpec::op_precision is being deprecated in favor of trtorch::CompileSpec::enabled_precisions, a set of all enabled precisions to use during compilation, trtorch::CompileSpec::op_precision will be removed in TRTorch v0.5.0")]] DataType
676-
op_precision = DataType::kFloat;
677-
678590
/**
679591
* @brief The set of precisions TensorRT is allowed to use for kernels during compilation
680592
*

Diff for: cpp/src/compile_spec.cpp

+2-53
Original file line numberDiff line numberDiff line change
@@ -144,30 +144,6 @@ CompileSpec::Device::DeviceType::DeviceType(c10::DeviceType t) {
144144
value = DeviceType::kGPU;
145145
}
146146

147-
CompileSpec::InputRange::InputRange(std::vector<int64_t> opt) {
148-
this->opt = opt;
149-
this->min = opt;
150-
this->max = opt;
151-
}
152-
153-
CompileSpec::InputRange::InputRange(c10::IntArrayRef opt) {
154-
this->opt = core::util::toVec(opt);
155-
this->min = core::util::toVec(opt);
156-
this->max = core::util::toVec(opt);
157-
}
158-
159-
CompileSpec::InputRange::InputRange(std::vector<int64_t> min, std::vector<int64_t> opt, std::vector<int64_t> max) {
160-
this->opt = opt;
161-
this->min = min;
162-
this->max = max;
163-
}
164-
165-
CompileSpec::InputRange::InputRange(c10::IntArrayRef min, c10::IntArrayRef opt, c10::IntArrayRef max) {
166-
this->opt = core::util::toVec(opt);
167-
this->min = core::util::toVec(min);
168-
this->max = core::util::toVec(max);
169-
}
170-
171147
CompileSpec::CompileSpec(std::vector<c10::ArrayRef<int64_t>> fixed_sizes) {
172148
for (auto in : fixed_sizes) {
173149
inputs.push_back(Input(in));
@@ -289,22 +265,10 @@ CompileSpec::Input::Input(
289265

290266
/* ==========================================*/
291267

292-
core::ir::Input to_internal_input(CompileSpec::InputRange& i) {
293-
return core::ir::Input(i.min, i.opt, i.max);
294-
}
295-
296268
core::ir::Input to_internal_input(CompileSpec::Input& i) {
297269
return core::ir::Input(i.min_shape, i.opt_shape, i.max_shape, toTRTDataType(i.dtype), toTRTTensorFormat(i.format));
298270
}
299271

300-
std::vector<core::ir::Input> to_vec_internal_inputs(std::vector<CompileSpec::InputRange>& external) {
301-
std::vector<core::ir::Input> internal;
302-
for (auto range : external) {
303-
internal.push_back(to_internal_input(range));
304-
}
305-
return internal;
306-
}
307-
308272
std::vector<core::ir::Input> to_vec_internal_inputs(std::vector<CompileSpec::Input>& external) {
309273
std::vector<core::ir::Input> internal;
310274
for (auto range : external) {
@@ -328,24 +292,9 @@ core::runtime::CudaDevice to_internal_cuda_device(CompileSpec::Device device) {
328292

329293
core::CompileSpec to_internal_compile_spec(CompileSpec external) {
330294
core::CompileSpec internal(to_vec_internal_inputs(external.inputs));
331-
if (external.input_ranges.size() > 0 && external.inputs.size() > 0) {
332-
TRTORCH_THROW_ERROR(
333-
"Saw both input specs listed for inputs and input_ranges in CompileSpec. input_ranges is deprecated and will be removed in v0.5.0. Please port forward to using inputs");
334-
} else if (external.input_ranges.size() > 0) {
335-
internal = core::CompileSpec(to_vec_internal_inputs(external.input_ranges));
336-
} else {
337-
TRTORCH_CHECK(external.inputs.size() > 0, "Compilation requires at least one input specification");
338-
internal = core::CompileSpec(to_vec_internal_inputs(external.inputs));
339-
}
340295

341-
if (external.enabled_precisions.size() == 1 &&
342-
toTRTDataType(*external.enabled_precisions.begin()) == nvinfer1::DataType::kFLOAT &&
343-
toTRTDataType(external.op_precision) != nvinfer1::DataType::kFLOAT) {
344-
internal.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(external.op_precision));
345-
} else {
346-
for (auto p : external.enabled_precisions) {
347-
internal.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
348-
}
296+
for (auto p : external.enabled_precisions) {
297+
internal.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
349298
}
350299

351300
/* We want default behavior for types to match PyTorch, so in the case the user did not explicitly set the dtype for

Diff for: examples/int8/training/vgg16/test_qat.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def test(model, dataloader, crit):
8181
import trtorch
8282
# trtorch.logging.set_reportable_log_level(trtorch.logging.Level.Debug)
8383
compile_settings = {
84-
"input_shapes": [[1, 3, 32, 32]],
84+
"inputs": [trtorch.Input([1, 3, 32, 32])],
8585
"op_precision": torch.int8 # Run with FP16
8686
}
8787
new_mod = torch.jit.load('trained_vgg16_qat.jit.pt')

Diff for: py/trtorch/_compile_spec.py

+1-18
Original file line numberDiff line numberDiff line change
@@ -157,22 +157,11 @@ def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> trtorch._C.TorchFall
157157

158158
def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
159159
info = trtorch._C.CompileSpec()
160-
if "input_shapes" not in compile_spec and "inputs" not in compile_spec:
160+
if "inputs" not in compile_spec:
161161
raise KeyError(
162162
"Module input definitions are requried to compile module. Provide a list of trtorch.Input keyed to \"inputs\" in the compile spec"
163163
)
164164

165-
if "input_shapes" in compile_spec and "inputs" in compile_spec:
166-
raise KeyError(
167-
"Found both key \"input_shapes\", and \"inputs\" in compile spec, please port forward to using only \"inputs\""
168-
)
169-
170-
if "input_shapes" in compile_spec:
171-
warnings.warn(
172-
"Key \"input_shapes\" is deprecated in favor of \"inputs\". Support for \"input_shapes\" will be removed in TRTorch v0.5.0",
173-
DeprecationWarning)
174-
info.inputs = _parse_input_ranges(compile_spec["input_shapes"])
175-
176165
if "inputs" in compile_spec:
177166
info.inputs = [i._to_internal() for i in compile_spec["inputs"]]
178167

@@ -181,12 +170,6 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
181170
"Found both key \"op_precision\", and \"enabled_precisions\" in compile spec, please port forward to using only \"enabled_precisions\""
182171
)
183172

184-
if "op_precision" in compile_spec:
185-
warnings.warn(
186-
"Key \"op_precision\" is being deprecated in favor of \"enabled_precision\" which expects a set of precisions to be enabled during compilation (FP32 will always be enabled), Support for \"op_precision\" will be removed in TRTorch v0.5.0",
187-
DeprecationWarning)
188-
info.enabled_precisions = _parse_enabled_precisions(compile_spec["op_precision"])
189-
190173
if "enabled_precisions" in compile_spec:
191174
info.enabled_precisions = _parse_enabled_precisions(compile_spec["enabled_precisions"])
192175
# We want default behavior to match PyTorch, so in the case the user did not explicitly set the dtype for inputs they

Diff for: tests/py/test_api.py

-32
Original file line numberDiff line numberDiff line change
@@ -13,38 +13,6 @@ def setUp(self):
1313
self.traced_model = torch.jit.trace(self.model, [self.input])
1414
self.scripted_model = torch.jit.script(self.model)
1515

16-
def test_compile_traced_deprecated(self):
17-
compile_spec = {
18-
"input_shapes": [self.input.shape],
19-
"device": {
20-
"device_type": trtorch.DeviceType.GPU,
21-
"gpu_id": 0,
22-
"dla_core": 0,
23-
"allow_gpu_fallback": False,
24-
"disable_tf32": False
25-
}
26-
}
27-
28-
trt_mod = trtorch.compile(self.traced_model, compile_spec)
29-
same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max()
30-
self.assertTrue(same < 2e-3)
31-
32-
def test_compile_script_deprecated(self):
33-
compile_spec = {
34-
"input_shapes": [self.input.shape],
35-
"device": {
36-
"device_type": trtorch.DeviceType.GPU,
37-
"gpu_id": 0,
38-
"dla_core": 0,
39-
"allow_gpu_fallback": False,
40-
"disable_tf32": False
41-
}
42-
}
43-
44-
trt_mod = trtorch.compile(self.scripted_model, compile_spec)
45-
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
46-
self.assertTrue(same < 2e-3)
47-
4816
def test_compile_traced(self):
4917
compile_spec = {
5018
"inputs": [trtorch.Input(self.input.shape, dtype=torch.float, format=torch.contiguous_format)],

0 commit comments

Comments
 (0)