feat: Add support for providing input datatypes in TRTorch #510
Conversation
There are some changes that do not conform to Python style guidelines:
Reformatting /workspace/tests/modules/hub.py
Reformatting /workspace/tests/py/test_to_backend_api.py
Reformatting /workspace/tests/py/test_trt_intercompatability.py
Reformatting /workspace/tests/py/test_multi_gpu.py
Reformatting /workspace/tests/py/model_test_case.py
Reformatting /workspace/tests/py/test_ptq_to_backend.py
Reformatting /workspace/tests/py/test_ptq_dataloader_calibrator.py
Reformatting /workspace/tests/py/test_api_dla.py
Reformatting /workspace/tests/py/test_api.py
Reformatting /workspace/tests/py/test_ptq_trt_calibrator.py
Reformatting /workspace/cpp/ptq/training/vgg16/export_ckpt.py
Reformatting /workspace/cpp/ptq/training/vgg16/vgg16.py
Reformatting /workspace/cpp/ptq/training/vgg16/main.py
--- /workspace/py/trtorch/_compile_spec.py (original)
+++ /workspace/py/trtorch/_compile_spec.py (reformatted)
@@ -140,6 +140,7 @@
return info
+
def _parse_input_dtypes(input_dtypes: List) -> List:
parsed_input_dtypes = []
for dtype in input_dtypes:
@@ -155,9 +156,12 @@
elif dtype == torch.bool:
parsed_input_dtypes.append(_types.dtype.bool)
else:
- raise TypeError("Invalid input dtype. Supported input datatypes include float|half|int8|int32|bool), got: " + str(dtype))
+ raise TypeError(
+ "Invalid input dtype. Supported input datatypes include float|half|int8|int32|bool), got: " +
+ str(dtype))
return parsed_input_dtypes
+
def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
info = trtorch._C.CompileSpec()
Reformatting /workspace/py/trtorch/_types.py
Reformatting /workspace/py/trtorch/logging.py
Reformatting /workspace/py/trtorch/_compile_spec.py
Reformatting /workspace/py/trtorch/__init__.py
Reformatting /workspace/py/trtorch/ptq.py
Reformatting /workspace/py/trtorch/_compiler.py
Reformatting /workspace/py/setup.py
ERROR: Some files do not conform to style guidelines
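For readers skimming the yapf diff above: the `_parse_input_dtypes` helper it reformats just maps each entry of a dtype list onto TRTorch's internal enum, raising on anything unsupported. A minimal self-contained sketch of that pattern follows; the string table stands in for `torch` dtypes and `trtorch._types.dtype` members, which are not imported here, so the names are illustrative only.

```python
from typing import Any, List

# Stand-in for the torch.dtype -> trtorch._types.dtype mapping;
# keys/values are illustrative strings, not the real enum members.
_SUPPORTED = {
    "float32": "float",
    "float16": "half",
    "int8": "int8",
    "int32": "int32",
    "bool": "bool",
}


def parse_input_dtypes(input_dtypes: List[Any]) -> List[str]:
    """Map each requested dtype to its internal name, rejecting unknowns."""
    parsed = []
    for dtype in input_dtypes:
        key = str(dtype)
        if key not in _SUPPORTED:
            raise TypeError(
                "Invalid input dtype. Supported input datatypes include "
                "float|half|int8|int32|bool, got: " + key)
        parsed.append(_SUPPORTED[key])
    return parsed
```

The real function does the same thing with an `if`/`elif` chain over `torch` dtype objects instead of a dict lookup.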
Code conforms to C++ style guidelines
We also need
I'm also wondering if we should create a unified input type that encapsulates shape, dtype, and format, so something like:

struct Input/Var/Placeholder (use the PyTorch term here) {
    shape: InputRange
    dtype: DataType::kFloat
    format: NCHW
};

Then we can have a constructor which accepts just an input range to provide backwards compatibility, or something like that. Users would provide a list of shapes, or a list of these if they want to specify dtypes.
Or do you think it is easier to have separate aligned lists? It might be less involved.
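To make the trade-off concrete, here is a sketch of the unified option in Python. The `Input` dataclass and `normalize_inputs` helper, the string dtypes, and the defaults are all illustrative stand-ins, not the actual TRTorch API; the point is only that one constructor can accept bare shapes for backwards compatibility alongside full specs.

```python
from dataclasses import dataclass
from typing import List, Sequence, Union


@dataclass
class Input:
    """Hypothetical unified input spec: shape + dtype + format in one object."""
    shape: List[int]
    dtype: str = "float"   # stands in for DataType::kFloat
    format: str = "NCHW"   # stands in for contiguous (linear) format


def normalize_inputs(specs: Sequence[Union[Sequence[int], Input]]) -> List[Input]:
    """Accept either bare shapes (backwards compatible) or full Input specs."""
    out = []
    for spec in specs:
        if isinstance(spec, Input):
            out.append(spec)
        else:
            # A bare shape list picks up the default dtype and format.
            out.append(Input(shape=list(spec)))
    return out
```

With separate aligned lists, by contrast, the caller passes e.g. `shapes=[...], dtypes=[...]` and the library must validate that the lists line up; the unified object makes each input's settings self-describing.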
There are some changes that do not conform to C++ style guidelines:
diff --git a/workspace/core/ir/ir.h b/tmp/changes.txt
index 8e64aa4..a243ef7 100644
--- a/workspace/core/ir/ir.h
+++ b/tmp/changes.txt
@@ -23,11 +23,8 @@ struct InputRange {
// Input(std::vector<int64_t> shape);
// Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape);
// Input(std::vector<int64_t> shape, DataType dtype=DataType::kFloat32);
-// Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, DataType dtype=DataType::kFloat32);
-// nvinfer1::Dims min;
-// nvinfer1::Dims max;
-// nvinfer1::Dims opt;
-// nvinfer1::DataType dtype;
+// Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, DataType
+// dtype=DataType::kFloat32); nvinfer1::Dims min; nvinfer1::Dims max; nvinfer1::Dims opt; nvinfer1::DataType dtype;
// }
} // namespace ir
diff --git a/workspace/cpp/api/include/trtorch/trtorch.h b/tmp/changes.txt
index 043ba41..7474702 100644
--- a/workspace/cpp/api/include/trtorch/trtorch.h
+++ b/tmp/changes.txt
@@ -215,14 +215,14 @@ struct TRTORCH_API CompileSpec {
*
* @param opt
*/
- Input(std::vector<int64_t> opt, DataType dtype=DataType::kFloat);
+ Input(std::vector<int64_t> opt, DataType dtype = DataType::kFloat);
/**
* @brief Construct a new Input Range object static input size from
* c10::ArrayRef (the type produced by tensor.sizes())
*
* @param opt
*/
- Input(c10::ArrayRef<int64_t> opt, DataType dtype=DataType::kFloat);
+ Input(c10::ArrayRef<int64_t> opt, DataType dtype = DataType::kFloat);
/**
* @brief Construct a new Input Range object dynamic input size from vectors
* for min, opt, and max supported sizes
@@ -231,7 +231,11 @@ struct TRTORCH_API CompileSpec {
* @param opt
* @param max
*/
- Input(std::vector<int64_t> min, std::vector<int64_t> opt, std::vector<int64_t> max, DataType dtype=DataType::kFloat);
+ Input(
+ std::vector<int64_t> min,
+ std::vector<int64_t> opt,
+ std::vector<int64_t> max,
+ DataType dtype = DataType::kFloat);
/**
* @brief Construct a new Input Range object dynamic input size from
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
@@ -241,7 +245,11 @@ struct TRTORCH_API CompileSpec {
* @param opt
* @param max
*/
- Input(c10::ArrayRef<int64_t> min, c10::ArrayRef<int64_t> opt, c10::ArrayRef<int64_t> max, DataType dtype=DataType::kFloat);
+ Input(
+ c10::ArrayRef<int64_t> min,
+ c10::ArrayRef<int64_t> opt,
+ c10::ArrayRef<int64_t> max,
+ DataType dtype = DataType::kFloat);
};
/**
ERROR: Some files do not conform to style guidelines
There are some changes that do not conform to C++ style guidelines:
diff --git a/workspace/core/ir/Input.cpp b/tmp/changes.txt
index f1f9f69..8612892 100644
--- a/workspace/core/ir/Input.cpp
+++ b/tmp/changes.txt
@@ -133,11 +133,20 @@ Input::Input(std::vector<int64_t> shape, nvinfer1::DataType dtype, nvinfer1::Ten
TRTORCH_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype);
this->dtype = dtype;
- TRTORCH_CHECK(valid_dtype_format_combo(dtype, format), "Unsupported combination of dtype and tensor format: (" << dtype << ", " << format << "), TRTorch only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
+ TRTORCH_CHECK(
+ valid_dtype_format_combo(dtype, format),
+ "Unsupported combination of dtype and tensor format: ("
+ << dtype << ", " << format
+ << "), TRTorch only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
this->format = format;
}
-Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, nvinfer1::DataType dtype, nvinfer1::TensorFormat format) {
+Input::Input(
+ std::vector<int64_t> min_shape,
+ std::vector<int64_t> opt_shape,
+ std::vector<int64_t> max_shape,
+ nvinfer1::DataType dtype,
+ nvinfer1::TensorFormat format) {
if (min_shape.size() > 5 || opt_shape.size() > 5 || max_shape.size() > 5) {
LOG_WARNING("Verify that this dim size is accepted");
}
@@ -178,7 +187,11 @@ Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std
TRTORCH_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype);
this->dtype = dtype;
- TRTORCH_CHECK(valid_dtype_format_combo(dtype, format), "Unsupported combination of dtype and tensor format: (" << dtype << ", " << format << "), TRTorch only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
+ TRTORCH_CHECK(
+ valid_dtype_format_combo(dtype, format),
+ "Unsupported combination of dtype and tensor format: ("
+ << dtype << ", " << format
+ << "), TRTorch only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
this->format = format;
}
@@ -186,7 +199,8 @@ std::ostream& operator<<(std::ostream& os, const Input& input) {
if (!input.input_is_dynamic) {
os << "Input(shape: " << input.input_shape << ", dtype: " << input.dtype << ", format: " << input.format << ')';
} else {
- os << "Input(shape: " << input.input_shape << ", min: " << input.min << ", opt: " << input.opt << ", max: " << input.max << ", dtype: " << input.dtype << ", format: " << input.format << ')';
+ os << "Input(shape: " << input.input_shape << ", min: " << input.min << ", opt: " << input.opt
+ << ", max: " << input.max << ", dtype: " << input.dtype << ", format: " << input.format << ')';
}
return os;
}
diff --git a/workspace/core/conversion/conversion.cpp b/tmp/changes.txt
index 324f3d0..1c2b963 100644
--- a/workspace/core/conversion/conversion.cpp
+++ b/tmp/changes.txt
@@ -125,10 +125,7 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
<< "please report this error to https://www.github.com/NVIDIA/TRTorch/issues");
}
-void AddInputs(
- ConversionCtx* ctx,
- at::ArrayRef<const torch::jit::Value*> inputs,
- std::vector<ir::Input>& input_specs) {
+void AddInputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> inputs, std::vector<ir::Input>& input_specs) {
std::vector<const torch::jit::Value*> input_tensors;
for (auto in : inputs) {
// Disregarding inputs that are not tensors
@@ -164,7 +161,8 @@ void AddInputs(
std::string name = std::string("input_") + std::to_string(ctx->num_inputs);
LOG_INFO(
ctx->logger,
- "Adding Input " << in->debugName() << " (named: " << name << "): " << spec << " in engine (conversion.AddInputs)");
+ "Adding Input " << in->debugName() << " (named: " << name << "): " << spec
+ << " in engine (conversion.AddInputs)");
auto trt_in = ctx->net->addInput(name.c_str(), spec.dtype, spec.input_shape);
TRTORCH_CHECK(trt_in, "Failed to add input node: " << in->debugName() << " (conversion.AddInputs)");
diff --git a/workspace/core/conversion/converters/impl/activation.cpp b/tmp/changes.txt
index a7e2a41..77931e4 100644
--- a/workspace/core/conversion/converters/impl/activation.cpp
+++ b/tmp/changes.txt
@@ -177,8 +177,9 @@ auto acthardtanh TRTORCH_UNUSED =
std::string pluginName = "CustomGeluPluginDynamic";
nvinfer1::PluginFieldCollection fc;
std::vector<nvinfer1::PluginField> f;
- //REVIEW is this right?
- int type_id = ctx->settings.enabled_precisions.find(nvinfer1::DataType::kHALF) == ctx->settings.enabled_precisions.end()
+ // REVIEW is this right?
+ int type_id = ctx->settings.enabled_precisions.find(nvinfer1::DataType::kHALF) ==
+ ctx->settings.enabled_precisions.end()
? 0
: 1; // Integer encoding the DataType (0: FP32, 1: FP16)
f.emplace_back(nvinfer1::PluginField("type_id", &type_id, nvinfer1::PluginFieldType::kINT32, 1));
diff --git a/workspace/core/conversion/conversionctx/ConversionCtx.cpp b/tmp/changes.txt
index 96a6261..50831a8 100644
--- a/workspace/core/conversion/conversionctx/ConversionCtx.cpp
+++ b/tmp/changes.txt
@@ -60,7 +60,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
LOG_DEBUG(build_settings);
cfg = builder->createBuilderConfig();
- for(auto p = settings.enabled_precisions.begin(); p != settings.enabled_precisions.end(); ++p) {
+ for (auto p = settings.enabled_precisions.begin(); p != settings.enabled_precisions.end(); ++p) {
switch (*p) {
case nvinfer1::DataType::kHALF:
TRTORCH_CHECK(builder->platformHasFastFp16(), "Requested inference in FP16 but platform does not support FP16");
@@ -121,7 +121,9 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
static_cast<int>(settings.device.dla_core) < nbDLACores,
"Configured DLA Core ID: " << settings.device.dla_core
<< " not available. Total number of available DLA Cores: " << nbDLACores);
- TRTORCH_CHECK(settings.enabled_precisions.find(nvinfer1::DataType::kFLOAT) == settings.enabled_precisions.end(), "DLA supports only fp16 or int8 precision");
+ TRTORCH_CHECK(
+ settings.enabled_precisions.find(nvinfer1::DataType::kFLOAT) == settings.enabled_precisions.end(),
+ "DLA supports only fp16 or int8 precision");
cfg->setDLACore(settings.device.dla_core);
}
}
diff --git a/workspace/core/ir/ir.h b/tmp/changes.txt
index d524ded..3d14491 100644
--- a/workspace/core/ir/ir.h
+++ b/tmp/changes.txt
@@ -1,7 +1,7 @@
#pragma once
-#include <vector>
#include <iostream>
+#include <vector>
#include "NvInfer.h"
namespace trtorch {
@@ -9,10 +9,18 @@ namespace core {
namespace ir {
struct Input {
- //Input(std::vector<int64_t> shape);
- //Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape);
- Input(std::vector<int64_t> shape, nvinfer1::DataType dtype=nvinfer1::DataType::kFLOAT, nvinfer1::TensorFormat format=nvinfer1::TensorFormat::kLINEAR);
- Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, nvinfer1::DataType dtype=nvinfer1::DataType::kFLOAT, nvinfer1::TensorFormat format=nvinfer1::TensorFormat::kLINEAR);
+ // Input(std::vector<int64_t> shape);
+ // Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape);
+ Input(
+ std::vector<int64_t> shape,
+ nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT,
+ nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR);
+ Input(
+ std::vector<int64_t> min_shape,
+ std::vector<int64_t> opt_shape,
+ std::vector<int64_t> max_shape,
+ nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT,
+ nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR);
friend std::ostream& operator<<(std::ostream& os, const Input& input);
bool input_is_dynamic = false;
diff --git a/workspace/core/conversion/conversion.h b/tmp/changes.txt
index fe79669..253dce7 100644
--- a/workspace/core/conversion/conversion.h
+++ b/tmp/changes.txt
@@ -14,8 +14,7 @@ namespace conversion {
struct ConversionInfo {
std::vector<ir::Input> inputs;
BuilderSettings engine_settings;
- ConversionInfo(std::vector<ir::Input> inputs)
- : inputs(std::move(inputs)), engine_settings(BuilderSettings()) {}
+ ConversionInfo(std::vector<ir::Input> inputs) : inputs(std::move(inputs)), engine_settings(BuilderSettings()) {}
};
// TODO: REMOVE GRAPH AND PARAMS AND MOVE FULLY TO INLINED CONSTANTS
diff --git a/workspace/core/conversion/conversionctx/ConversionCtx.h b/tmp/changes.txt
index 9aee3b4..0570bf5 100644
--- a/workspace/core/conversion/conversionctx/ConversionCtx.h
+++ b/tmp/changes.txt
@@ -2,8 +2,8 @@
#include <map>
#include <memory>
-#include <unordered_map>
#include <set>
+#include <unordered_map>
#include "NvInfer.h"
#include "torch/csrc/jit/ir/ir.h"
diff --git a/workspace/cpp/api/src/compile_spec.cpp b/tmp/changes.txt
index fb54547..7fc08aa 100644
--- a/workspace/cpp/api/src/compile_spec.cpp
+++ b/tmp/changes.txt
@@ -61,8 +61,7 @@ CompileSpec::DataType::DataType(c10::ScalarType t) {
CompileSpec::TensorFormat::TensorFormat(at::MemoryFormat t) {
TRTORCH_CHECK(
- t == at::MemoryFormat::Contiguous || t == at::MemoryFormat::ChannelsLast, "Tensor format is unsupported"
- );
+ t == at::MemoryFormat::Contiguous || t == at::MemoryFormat::ChannelsLast, "Tensor format is unsupported");
switch (t) {
case at::MemoryFormat::ChannelsLast:
@@ -136,7 +135,12 @@ CompileSpec::Input::Input(c10::IntArrayRef shape, DataType dtype, TensorFormat f
this->input_is_dynamic = false;
}
-CompileSpec::Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, DataType dtype, TensorFormat format) {
+CompileSpec::Input::Input(
+ std::vector<int64_t> min_shape,
+ std::vector<int64_t> opt_shape,
+ std::vector<int64_t> max_shape,
+ DataType dtype,
+ TensorFormat format) {
this->opt_shape = opt_shape;
this->min_shape = min_shape;
this->max_shape = max_shape;
@@ -146,7 +150,12 @@ CompileSpec::Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> o
this->input_is_dynamic = true;
}
-CompileSpec::Input::Input(c10::IntArrayRef min_shape, c10::IntArrayRef opt_shape, c10::IntArrayRef max_shape, DataType dtype, TensorFormat format) {
+CompileSpec::Input::Input(
+ c10::IntArrayRef min_shape,
+ c10::IntArrayRef opt_shape,
+ c10::IntArrayRef max_shape,
+ DataType dtype,
+ TensorFormat format) {
this->opt_shape = core::util::toVec(opt_shape);
this->min_shape = core::util::toVec(min_shape);
this->max_shape = core::util::toVec(max_shape);
@@ -184,7 +193,7 @@ std::vector<core::ir::Input> to_vec_internal_inputs(std::vector<CompileSpec::Inp
core::CompileSpec to_internal_compile_spec(CompileSpec external) {
core::CompileSpec internal(to_vec_internal_inputs(external.inputs));
- if (external.input_ranges.size() > 0 ) {
+ if (external.input_ranges.size() > 0) {
internal = core::CompileSpec(to_vec_internal_inputs(external.input_ranges));
} else {
TRTORCH_CHECK(external.inputs.size() > 0, "Compilation requires at least one input specification");
@@ -194,7 +203,7 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
if (external.enabled_precisions.size() <= 1 && toTRTDataType(external.op_precision) != nvinfer1::DataType::kFLOAT) {
internal.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(external.op_precision));
} else {
- for(auto p : external.enabled_precisions) {
+ for (auto p : external.enabled_precisions) {
internal.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
}
}
@@ -237,7 +246,8 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
internal.convert_info.engine_settings.num_avg_timing_iters = external.num_avg_timing_iters;
internal.convert_info.engine_settings.workspace_size = external.workspace_size;
- if (internal.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) != internal.convert_info.engine_settings.enabled_precisions.end()) {
+ if (internal.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) !=
+ internal.convert_info.engine_settings.enabled_precisions.end()) {
internal.convert_info.engine_settings.calibrator = external.ptq_calibrator;
} else {
internal.convert_info.engine_settings.calibrator = nullptr;
diff --git a/workspace/cpp/api/include/trtorch/trtorch.h b/tmp/changes.txt
index cd3abf6..5dd383c 100644
--- a/workspace/cpp/api/include/trtorch/trtorch.h
+++ b/tmp/changes.txt
@@ -10,9 +10,9 @@
#include <cuda_runtime.h>
#include <memory>
+#include <set>
#include <string>
#include <vector>
-#include <set>
// Just include the .h?
#ifndef DOXYGEN_SHOULD_SKIP_THIS
@@ -264,7 +264,7 @@ struct TRTORCH_API CompileSpec {
};
class TRTORCH_API TensorFormat {
- public:
+ public:
/**
* Underlying enum class to support the TensorFormat Class
*
@@ -365,7 +365,8 @@ struct TRTORCH_API CompileSpec {
std::vector<int64_t> opt_shape;
/// Maximum acceptable input size into the engine
std::vector<int64_t> max_shape;
- /// Input shape to be fed to TensorRT, in the event of a dynamic shape, -1's will hold the place of variable dimensions
+ /// Input shape to be fed to TensorRT, in the event of a dynamic shape, -1's will hold the place of variable
+ /// dimensions
std::vector<int64_t> shape;
/// Expected data type for the input
DataType dtype;
@@ -380,7 +381,10 @@ struct TRTORCH_API CompileSpec {
* @param dtype Expected data type for the input (Defaults to Float32)
* @param format Expected tensor format for the input (Defaults to contiguous)
*/
- Input(std::vector<int64_t> shape, DataType dtype=DataType::kFloat, TensorFormat format=TensorFormat::kContiguous);
+ Input(
+ std::vector<int64_t> shape,
+ DataType dtype = DataType::kFloat,
+ TensorFormat format = TensorFormat::kContiguous);
/**
* @brief Construct a new Input spec object for static input size from
* c10::ArrayRef (the type produced by tensor.sizes()), vector, optional arguments
@@ -390,7 +394,10 @@ struct TRTORCH_API CompileSpec {
* @param dtype Expected data type for the input (Defaults to Float32)
* @param format Expected tensor format for the input (Defaults to contiguous)
*/
- Input(c10::ArrayRef<int64_t> shape, DataType dtype=DataType::kFloat, TensorFormat format=TensorFormat::kContiguous);
+ Input(
+ c10::ArrayRef<int64_t> shape,
+ DataType dtype = DataType::kFloat,
+ TensorFormat format = TensorFormat::kContiguous);
/**
* @brief Construct a new Input spec object for a dynamic input size from vectors
* for minimum shape, optimal shape, and max shape supported sizes optional arguments
@@ -402,7 +409,12 @@ struct TRTORCH_API CompileSpec {
* @param dtype Expected data type for the input (Defaults to Float32)
* @param format Expected tensor format for the input (Defaults to contiguous)
*/
- Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, DataType dtype=DataType::kFloat, TensorFormat format=TensorFormat::kContiguous);
+ Input(
+ std::vector<int64_t> min_shape,
+ std::vector<int64_t> opt_shape,
+ std::vector<int64_t> max_shape,
+ DataType dtype = DataType::kFloat,
+ TensorFormat format = TensorFormat::kContiguous);
/**
* @brief Construct a new Input Range object dynamic input size from
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
@@ -414,9 +426,14 @@ struct TRTORCH_API CompileSpec {
* @param dtype Expected data type for the input (Defaults to Float32)
* @param format Expected tensor format for the input (Defaults to contiguous)
*/
- Input(c10::ArrayRef<int64_t> min_shape, c10::ArrayRef<int64_t> opt_shape, c10::ArrayRef<int64_t> max_shape, DataType dtype=DataType::kFloat, TensorFormat format=TensorFormat::kContiguous);
+ Input(
+ c10::ArrayRef<int64_t> min_shape,
+ c10::ArrayRef<int64_t> opt_shape,
+ c10::ArrayRef<int64_t> max_shape,
+ DataType dtype = DataType::kFloat,
+ TensorFormat format = TensorFormat::kContiguous);
- private:
+ private:
bool input_is_dynamic;
};
@@ -441,16 +458,16 @@ struct TRTORCH_API CompileSpec {
*
* @param opt
*/
- [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]]
- InputRange(std::vector<int64_t> opt);
+ [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] InputRange(
+ std::vector<int64_t> opt);
/**
* @brief Construct a new Input Range object static input size from
* c10::ArrayRef (the type produced by tensor.sizes())
*
* @param opt
*/
- [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]]
- InputRange(c10::ArrayRef<int64_t> opt);
+ [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] InputRange(
+ c10::ArrayRef<int64_t> opt);
/**
* @brief Construct a new Input Range object dynamic input size from vectors
* for min, opt, and max supported sizes
@@ -459,8 +476,10 @@ struct TRTORCH_API CompileSpec {
* @param opt
* @param max
*/
- [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]]
- InputRange(std::vector<int64_t> min, std::vector<int64_t> opt, std::vector<int64_t> max);
+ [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] InputRange(
+ std::vector<int64_t> min,
+ std::vector<int64_t> opt,
+ std::vector<int64_t> max);
/**
* @brief Construct a new Input Range object dynamic input size from
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
@@ -470,8 +489,10 @@ struct TRTORCH_API CompileSpec {
* @param opt
* @param max
*/
- [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]]
- InputRange(c10::ArrayRef<int64_t> min, c10::ArrayRef<int64_t> opt, c10::ArrayRef<int64_t> max);
+ [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] InputRange(
+ c10::ArrayRef<int64_t> min,
+ c10::ArrayRef<int64_t> opt,
+ c10::ArrayRef<int64_t> max);
};
/**
@@ -512,8 +533,9 @@ struct TRTORCH_API CompileSpec {
*
* @param input_ranges
*/
- [[deprecated("trtorch::CompileSpec::CompileSpec(std::vector<InputRange> input_ranges) is being deprecated in favor of trtorch::CompileSpec::CompileSpec(std::vector<Input> inputs). trtorch::CompileSpec::CompileSpec(std::vector<InputRange> input_ranges) will be removed in TRTorch v0.5.0")]]
- CompileSpec(std::vector<InputRange> input_ranges) : input_ranges(std::move(input_ranges)) {}
+ [[deprecated("trtorch::CompileSpec::CompileSpec(std::vector<InputRange> input_ranges) is being deprecated in favor of trtorch::CompileSpec::CompileSpec(std::vector<Input> inputs). trtorch::CompileSpec::CompileSpec(std::vector<InputRange> input_ranges) will be removed in TRTorch v0.5.0")]] CompileSpec(
+ std::vector<InputRange> input_ranges)
+ : input_ranges(std::move(input_ranges)) {}
/**
* @brief Construct a new Extra Info object
* Convienence constructor to set fixed input size from vectors describing
@@ -522,8 +544,8 @@ struct TRTORCH_API CompileSpec {
*
* @param fixed_sizes
*/
- [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]]
- CompileSpec(std::vector<std::vector<int64_t>> fixed_sizes);
+ [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] CompileSpec(
+ std::vector<std::vector<int64_t>> fixed_sizes);
/**
* @brief Construct a new Extra Info object
* Convienence constructor to set fixed input size from c10::ArrayRef's (the
@@ -531,14 +553,14 @@ struct TRTORCH_API CompileSpec {
* the vector represents a input and should be provided in call order.
* @param fixed_sizes
*/
- [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]]
- CompileSpec(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);
+ [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] CompileSpec(
+ std::vector<c10::ArrayRef<int64_t>> fixed_sizes);
// Defaults should reflect TensorRT defaults for BuilderConfig
/**
- * @brief Specifications for inputs to the engine, can either be a single size or a range defined by min, opt and max sizes
- * Users can also specify expected input type as well as tensor memory format
+ * @brief Specifications for inputs to the engine, can either be a single size or a range defined by min, opt and max
+ * sizes Users can also specify expected input type as well as tensor memory format
*
* Order in vector should match call order for the function
*/
@@ -550,14 +572,17 @@ struct TRTORCH_API CompileSpec {
*
* Order is should match call order
*/
- [[deprecated("trtorch::CompileSpec::input_ranges is being deprecated in favor of trtorch::CompileSpec::inputs. trtorch::CompileSpec::input_ranges will be removed in TRTorch v0.5.0")]]
- std::vector<InputRange> input_ranges;
+ [[deprecated(
+ "trtorch::CompileSpec::input_ranges is being deprecated in favor of trtorch::CompileSpec::inputs. trtorch::CompileSpec::input_ranges will be removed in TRTorch v0.5.0")]] std::
+ vector<InputRange>
+ input_ranges;
/**
* Default operating precision for the engine
*/
- [[deprecated("trtorch::CompileSpec::op_precision is being deprecated in favor of trtorch::CompileSpec::enabled_precisions, a set of all enabled precisions to use during compilation, trtorch::CompileSpec::op_precision will be removed in TRTorch v0.5.0")]]
- DataType op_precision = DataType::kFloat;
+ [[deprecated(
+ "trtorch::CompileSpec::op_precision is being deprecated in favor of trtorch::CompileSpec::enabled_precisions, a set of all enabled precisions to use during compilation, trtorch::CompileSpec::op_precision will be removed in TRTorch v0.5.0")]] DataType
+ op_precision = DataType::kFloat;
/**
* @brief The set of precisions TensorRT is allowed to use for kernels during compilation
diff --git a/workspace/py/trtorch/csrc/register_tensorrt_classes.cpp b/tmp/changes.txt
index 9080f33..3ca490c 100644
--- a/workspace/py/trtorch/csrc/register_tensorrt_classes.cpp
+++ b/tmp/changes.txt
@@ -9,10 +9,9 @@ namespace {
(registry).def("_get_" #field_name, &class_name::get_##field_name);
void RegisterTRTCompileSpec() {
- static auto TRTORCH_UNUSED TRTInputRangeTSRegistration =
- torch::class_<trtorch::pyapi::Input>("tensorrt", "_Input")
- .def(torch::init<>())
- .def("__str__", &trtorch::pyapi::Input::to_str);
+ static auto TRTORCH_UNUSED TRTInputRangeTSRegistration = torch::class_<trtorch::pyapi::Input>("tensorrt", "_Input")
+ .def(torch::init<>())
+ .def("__str__", &trtorch::pyapi::Input::to_str);
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, min);
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, opt);
@@ -21,7 +20,6 @@ void RegisterTRTCompileSpec() {
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, format);
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, input_is_dynamic);
-
static auto TRTORCH_UNUSED TRTDeviceTSRegistration = torch::class_<trtorch::pyapi::Device>("tensorrt", "_Device")
.def(torch::init<>())
.def("__str__", &trtorch::pyapi::Device::to_str);
ERROR: Some files do not conform to style guidelines
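A recurring theme in these hunks is the `op_precision` → `enabled_precisions` migration: a single precision enum becomes a set, and behavior is then gated by membership tests. In plain Python terms (an illustrative sketch using string precision names, not TRTorch code):

```python
def fp16_enabled(enabled_precisions):
    """True when FP16 kernels may be selected: a membership test on the set,
    where the old API compared a single op_precision value."""
    return "half" in enabled_precisions

def check_dla_precisions(enabled_precisions):
    """Mirror the DLA precision guard: full-precision float is rejected,
    since DLA supports only fp16 or int8."""
    if "float" in enabled_precisions:
        raise ValueError("DLA supports only fp16 or int8 precision")
    return True
```

Because it is a set, FP16 and Int8 can now be enabled together without catch-all rules.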
There are some changes that do not conform to Python style guidelines:
Reformatting /workspace/cpp/ptq/training/vgg16/export_ckpt.py
Reformatting /workspace/cpp/ptq/training/vgg16/vgg16.py
Reformatting /workspace/cpp/ptq/training/vgg16/main.py
Reformatting /workspace/tests/py/test_api_dla.py
Reformatting /workspace/tests/py/test_trt_intercompatability.py
Reformatting /workspace/tests/py/model_test_case.py
Reformatting /workspace/tests/py/test_multi_gpu.py
Reformatting /workspace/tests/py/test_ptq_to_backend.py
Reformatting /workspace/tests/py/test_api.py
Reformatting /workspace/tests/modules/hub.py
Reformatting /workspace/tests/py/test_to_backend_api.py
Reformatting /workspace/tests/py/test_ptq_trt_calibrator.py
Reformatting /workspace/tests/py/test_ptq_dataloader_calibrator.py
--- /workspace/py/trtorch/Input.py (original)
+++ /workspace/py/trtorch/Input.py (reformatted)
@@ -5,6 +5,7 @@
from trtorch import _types
import trtorch._C
+
class Input(object):
"""
@@ -59,39 +60,45 @@
if len(args) == 1:
if not Input._supported_input_size_type(args[0]):
raise TypeError(
- "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
- + str(type(args[0])))
+ "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
+ + str(type(args[0])))
if any(k in kwargs for k in ["min_shape", "opt_shape", "max_shape"]):
- raise ValueError("Found that both shape (as a positional argument), and one or more of min_shape, opt_shape, max_shape were specified\nclass Input expects that only either shape or all three of min_shape, opt_shape, max_shape are defined")
+ raise ValueError(
+ "Found that both shape (as a positional argument), and one or more of min_shape, opt_shape, max_shape were specified\nclass Input expects that only either shape or all three of min_shape, opt_shape, max_shape are defined"
+ )
self.shape = tuple(args[0])
self.shape_mode = Input._ShapeMode.STATIC
elif len(args) == 0:
- if not ("shape" in kwargs) and not(all(k in kwargs for k in ["min_shape", "opt_shape", "max_shape"])):
- raise ValueError("Missing required arguments for class Input\nEither shape or all three of min_shape, opt_shape, max_shape must be defined")
+ if not ("shape" in kwargs) and not (all(k in kwargs for k in ["min_shape", "opt_shape", "max_shape"])):
+ raise ValueError(
+ "Missing required arguments for class Input\nEither shape or all three of min_shape, opt_shape, max_shape must be defined"
+ )
elif ("shape" in kwargs) and all(k in kwargs for k in ["min_shape", "opt_shape", "max_shape"]):
- raise ValueError("Found that both shape, and one or more of min_shape, opt_shape, max_shape were specified\nclass Input expects that only either shape or all three of min_shape, opt_shape, max_shape are defined")
+ raise ValueError(
+ "Found that both shape, and one or more of min_shape, opt_shape, max_shape were specified\nclass Input expects that only either shape or all three of min_shape, opt_shape, max_shape are defined"
+ )
if "shape" in kwargs:
if not Input._supported_input_size_type(kwargs["shape"]):
raise TypeError(
- "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
- + str(type(kwargs["shape"])))
+ "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
+ + str(type(kwargs["shape"])))
self.shape = tuple(kwargs["shape"])
self.shape_mode = Input._ShapeMode.STATIC
else:
if not Input._supported_input_size_type(kwargs["min_shape"]):
raise TypeError(
- "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
- + str(type(kwargs["min_shape"])) + " for min_shape")
+ "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
+ + str(type(kwargs["min_shape"])) + " for min_shape")
if not Input._supported_input_size_type(kwargs["opt_shape"]):
raise TypeError(
- "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
- + str(type(kwargs["opt_shape"])) + " for opt_shape")
+ "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
+ + str(type(kwargs["opt_shape"])) + " for opt_shape")
if not Input._supported_input_size_type(kwargs["max_shape"]):
raise TypeError(
- "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
- + str(type(kwargs["max_shape"])) + " for max_shape")
+ "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
+ + str(type(kwargs["max_shape"])) + " for max_shape")
self.shape = {
"min_shape": tuple(kwargs["min_shape"]),
@@ -101,7 +108,9 @@
self.shape_mode = Input._ShapeMode.DYNAMIC
else:
- raise ValueError("Unexpected number of positional arguments for class Input \n Found {} arguments, expected either zero or a single positional arguments".format(len(args)))
+ raise ValueError(
+ "Unexpected number of positional arguments for class Input \n Found {} arguments, expected either zero or a single positional arguments"
+ .format(len(args)))
if "dtype" in kwargs:
self.dtype = Input._parse_dtype(kwargs["dtype"])
@@ -113,7 +122,9 @@
if self.shape_mode == Input._ShapeMode.STATIC:
return "Input(shape={}, dtype={}, format={})".format(self.shape, str(self.dtype), str(self.format))
elif self.shape_mode == Input._ShapeMode.DYNAMIC:
- return "Input(min_shape={}, opt_shape={}, max_shape={}, dtype={}, format={})".format(self.shape["min_shape"], self.shape["min_shape"], self.shape["min_shape"], str(self.dtype), str(self.format))
+ return "Input(min_shape={}, opt_shape={}, max_shape={}, dtype={}, format={})".format(
+ self.shape["min_shape"], self.shape["min_shape"], self.shape["min_shape"], str(self.dtype),
+ str(self.format))
else:
raise RuntimeError("Unknown input shape mode")
@@ -154,8 +165,9 @@
elif dtype == torch.bool:
return _types.dtype.bool
else:
- raise TypeError("Provided an unsupported data type as an input data type (support: bool, int32, half, float), got: " +
- str(dtype))
+ raise TypeError(
+ "Provided an unsupported data type as an input data type (support: bool, int32, half, float), got: "
+ + str(dtype))
elif isinstance(dtype, _types.DataTypes):
return dtype
@@ -172,10 +184,12 @@
elif format == torch.channels_last:
return _types.TensorFormat.channel_last
else:
- raise ValueError("Provided an unsupported tensor format (support: NHCW/contiguous_format, NHWC/channel_last)")
+ raise ValueError(
+ "Provided an unsupported tensor format (support: NHCW/contiguous_format, NHWC/channel_last)")
elif isinstance(format, _types.TensorFormat):
return format
else:
- raise TypeError("Tensor format needs to be specified with either torch.memory_format or trtorch.TensorFormat")
+ raise TypeError(
+ "Tensor format needs to be specified with either torch.memory_format or trtorch.TensorFormat")
--- /workspace/py/trtorch/_compile_spec.py (original)
+++ /workspace/py/trtorch/_compile_spec.py (reformatted)
@@ -29,7 +29,8 @@
for i in input_sizes:
if isinstance(i, dict):
if all(k in i for k in ["min", "opt", "min"]):
- parsed_input_sizes.append(Input(min_shape=i["min"], opt_shape=i["opt"], max_shape=i["max"])._to_internal())
+ parsed_input_sizes.append(
+ Input(min_shape=i["min"], opt_shape=i["opt"], max_shape=i["max"])._to_internal())
elif "opt" in i:
parsed_input_sizes.append(Input(shape=i["opt"])._to_internal())
@@ -78,6 +79,7 @@
else:
parsed_precisions.add(_parse_op_precision(precisions))
return parsed_precisions
+
def _parse_device_type(device: Any) -> _types.DeviceType:
if isinstance(device, torch.device):
@@ -139,6 +141,7 @@
return info
+
def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
info = trtorch._C.CompileSpec()
if "input_shapes" not in compile_spec and "inputs" not in compile_spec:
@@ -152,11 +155,13 @@
)
if "input_shapes" in compile_spec:
- warnings.warn("Key \"input_shapes\" is deprecated in favor of \"inputs\". Support for \"input_shapes\" will be removed in TRTorch v0.5.0", DeprecationWarning)
+ warnings.warn(
+ "Key \"input_shapes\" is deprecated in favor of \"inputs\". Support for \"input_shapes\" will be removed in TRTorch v0.5.0",
+ DeprecationWarning)
info.inputs = _parse_input_ranges(compile_spec["input_shapes"])
if "inputs" in compile_spec:
- info.inputs = [ i._to_internal() for i in compile_spec["inputs"] ]
+ info.inputs = [i._to_internal() for i in compile_spec["inputs"]]
if "op_precision" in compile_spec and "enabled_precisions" in compile_spec:
raise KeyError(
@@ -164,7 +169,9 @@
)
if "op_precision" in compile_spec:
- warnings.warn("Key \"op_precision\" is being deprecated in favor of \"enabled_precision\" which expects a set of precisions to be enabled during compilation (FP32 will always be enabled), Support for \"op_precision\" will be removed in TRTorch v0.5.0", DeprecationWarning)
+ warnings.warn(
+ "Key \"op_precision\" is being deprecated in favor of \"enabled_precision\" which expects a set of precisions to be enabled during compilation (FP32 will always be enabled), Support for \"op_precision\" will be removed in TRTorch v0.5.0",
+ DeprecationWarning)
info.enabled_precisions = _parse_enabled_precisions(compile_spec["op_precision"])
if "enabled_precisions" in compile_spec:
Reformatting /workspace/py/trtorch/_compiler.py
Reformatting /workspace/py/trtorch/__init__.py
Reformatting /workspace/py/trtorch/Input.py
Reformatting /workspace/py/trtorch/_compile_spec.py
Reformatting /workspace/py/trtorch/_types.py
Reformatting /workspace/py/trtorch/logging.py
Reformatting /workspace/py/trtorch/ptq.py
Reformatting /workspace/py/setup.py
ERROR: Some files do not conform to style guidelines
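Beyond formatting, the `Input.py` hunks above document the class's argument contract: a static `shape` is mutually exclusive with the `min_shape`/`opt_shape`/`max_shape` trio, and one or the other must be supplied. A stand-alone sketch of that rule (plain Python, not the trtorch class itself):

```python
from enum import Enum

class ShapeMode(Enum):
    STATIC = 0
    DYNAMIC = 1

def resolve_shape_mode(**kwargs):
    """Mirror trtorch.Input's mutual-exclusion rules (illustrative sketch):
    pass either `shape` alone, or all of min_shape/opt_shape/max_shape."""
    dyn_keys = ("min_shape", "opt_shape", "max_shape")
    if "shape" in kwargs and any(k in kwargs for k in dyn_keys):
        raise ValueError("shape is exclusive with min_shape/opt_shape/max_shape")
    if "shape" in kwargs:
        return ShapeMode.STATIC, tuple(kwargs["shape"])
    if all(k in kwargs for k in dyn_keys):
        return ShapeMode.DYNAMIC, {k: tuple(kwargs[k]) for k in dyn_keys}
    raise ValueError(
        "Either shape or all three of min_shape, opt_shape, max_shape must be defined")
```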
Signed-off-by: Dheeraj Peri <[email protected]> Signed-off-by: Dheeraj Peri <[email protected]>
Signed-off-by: Dheeraj Peri <[email protected]>
This commit implements the new input spec type trtorch::core::ir::Input, which encapsulates InputRange and adds the new dtype and tensor format arguments. It also changes DataType op_precision in the engine settings to std::set<nvinfer1::DataType> enabled_precisions, allowing the compiler to set more than a single precision without resorting to catch-all rules such as FP32 and Int8 without FP16. Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
Input type and enabled_precisions set BREAKING CHANGE: This commit introduces the next iteration of the Python TRTorch API. Starting in TRTorch v0.5.0 support for the "input_shapes" and "op_precision" compile spec keys will be removed. Users should port forward to using the "inputs" key, which expects a list of trtorch.Input objects, and the "enabled_precisions" key, which expects a set of data-type enums. Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
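The port forward described in this commit message amounts to a key rename plus wrapping shapes in `Input`-style specs. As a rough dictionary-level sketch (a hypothetical helper, with precisions as strings rather than torch enums; not part of TRTorch):

```python
def port_forward(spec):
    """Rewrite a deprecated compile spec ("input_shapes"/"op_precision")
    into the new key layout ("inputs"/"enabled_precisions") -- a sketch."""
    new = dict(spec)
    if "input_shapes" in new:
        # dict entries carried min/opt/max ranges; plain lists were static shapes
        new["inputs"] = [
            {"min_shape": s["min"], "opt_shape": s["opt"], "max_shape": s["max"]}
            if isinstance(s, dict) else {"shape": s}
            for s in new.pop("input_shapes")
        ]
    if "op_precision" in new:
        # a single precision becomes a *set* of enabled precisions
        new["enabled_precisions"] = {new.pop("op_precision")}
    return new
```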
enabled_types BREAKING CHANGE: This change deprecates InputRange, and the CompileSpec fields "input_shapes", "op_precision" and associated constructors and functions. These are replaced with Input, "inputs" and "enabled_precisions" respectively. Deprecated components will be removed in TRTorch v0.5.0 Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
behavior unless user explicitly overrides Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
This commit adds tests for the new Input class, including verifying that default behavior works properly. It also moves tests out of module and into cpp for the cpp api tests Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
trtorchc Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
…_type Signed-off-by: Naren Dasan <[email protected]>
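During the deprecation window both key sets are accepted and the old one warns, as the `_parse_compile_spec` hunk earlier shows. The pattern boils down to (an illustrative sketch, not the actual TRTorch parser):

```python
import warnings

def get_inputs(compile_spec):
    """Accept both the deprecated "input_shapes" key and the new "inputs"
    key, emitting a DeprecationWarning for the former."""
    if "input_shapes" not in compile_spec and "inputs" not in compile_spec:
        raise KeyError('Module input definitions are required: provide "inputs"')
    if "input_shapes" in compile_spec:
        warnings.warn(
            '"input_shapes" is deprecated in favor of "inputs"; '
            "support will be removed in TRTorch v0.5.0",
            DeprecationWarning)
        return compile_spec["input_shapes"]
    return compile_spec["inputs"]
```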
There are some changes that do not conform to Python style guidelines:
Reformatting /workspace/cpp/ptq/training/vgg16/export_ckpt.py
Reformatting /workspace/cpp/ptq/training/vgg16/vgg16.py
Reformatting /workspace/cpp/ptq/training/vgg16/main.py
Reformatting /workspace/tests/modules/hub.py
Reformatting /workspace/tests/py/test_to_backend_api.py
Reformatting /workspace/tests/py/test_ptq_trt_calibrator.py
Reformatting /workspace/tests/py/test_ptq_dataloader_calibrator.py
--- /workspace/tests/py/test_api.py (original)
+++ /workspace/tests/py/test_api.py (reformatted)
@@ -73,6 +73,7 @@
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
self.assertTrue(same < 2e-2)
+
class TestCompileHalf(ModelTestCase):
def setUp(self):
@@ -94,6 +95,7 @@
same = (trt_mod(self.input.half()) - self.scripted_model(self.input.half())).abs().max()
self.assertTrue(same < 2e-2)
+
class TestCompileHalfDefault(ModelTestCase):
def setUp(self):
@@ -114,6 +116,7 @@
trt_mod = trtorch.compile(self.scripted_model, compile_spec)
same = (trt_mod(self.input.half()) - self.scripted_model(self.input.half())).abs().max()
self.assertTrue(same < 2e-2)
+
class TestFallbackToTorch(ModelTestCase):
Reformatting /workspace/tests/py/test_api_dla.py
Reformatting /workspace/tests/py/test_trt_intercompatability.py
Reformatting /workspace/tests/py/model_test_case.py
Reformatting /workspace/tests/py/test_multi_gpu.py
Reformatting /workspace/tests/py/test_ptq_to_backend.py
Reformatting /workspace/tests/py/test_api.py
ERROR: Some files do not conform to style guidelines
There are some changes that do not conform to C++ style guidelines:
diff --git a/workspace/core/ir/Input.cpp b/tmp/changes.txt
index 8841f41..4f3c2c9 100644
--- a/workspace/core/ir/Input.cpp
+++ b/tmp/changes.txt
@@ -131,11 +131,20 @@ Input::Input(std::vector<int64_t> shape, nvinfer1::DataType dtype, nvinfer1::Ten
TRTORCH_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype);
this->dtype = dtype;
- TRTORCH_CHECK(valid_dtype_format_combo(dtype, format), "Unsupported combination of dtype and tensor format: (" << dtype << ", " << format << "), TRTorch only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
+ TRTORCH_CHECK(
+ valid_dtype_format_combo(dtype, format),
+ "Unsupported combination of dtype and tensor format: ("
+ << dtype << ", " << format
+ << "), TRTorch only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
this->format = format;
}
-Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, nvinfer1::DataType dtype, nvinfer1::TensorFormat format) {
+Input::Input(
+ std::vector<int64_t> min_shape,
+ std::vector<int64_t> opt_shape,
+ std::vector<int64_t> max_shape,
+ nvinfer1::DataType dtype,
+ nvinfer1::TensorFormat format) {
if (min_shape.size() > 5 || opt_shape.size() > 5 || max_shape.size() > 5) {
LOG_WARNING("Verify that this dim size is accepted");
}
@@ -174,7 +183,11 @@ Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std
TRTORCH_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype);
this->dtype = dtype;
- TRTORCH_CHECK(valid_dtype_format_combo(dtype, format), "Unsupported combination of dtype and tensor format: (" << dtype << ", " << format << "), TRTorch only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
+ TRTORCH_CHECK(
+ valid_dtype_format_combo(dtype, format),
+ "Unsupported combination of dtype and tensor format: ("
+ << dtype << ", " << format
+ << "), TRTorch only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
this->format = format;
}
@@ -182,7 +195,8 @@ std::ostream& operator<<(std::ostream& os, const Input& input) {
if (!input.input_is_dynamic) {
os << "Input(shape: " << input.input_shape << ", dtype: " << input.dtype << ", format: " << input.format << ')';
} else {
- os << "Input(shape: " << input.input_shape << ", min: " << input.min << ", opt: " << input.opt << ", max: " << input.max << ", dtype: " << input.dtype << ", format: " << input.format << ')';
+ os << "Input(shape: " << input.input_shape << ", min: " << input.min << ", opt: " << input.opt
+ << ", max: " << input.max << ", dtype: " << input.dtype << ", format: " << input.format << ')';
}
return os;
}
diff --git a/workspace/core/conversion/conversion.cpp b/tmp/changes.txt
index 324f3d0..1c2b963 100644
--- a/workspace/core/conversion/conversion.cpp
+++ b/tmp/changes.txt
@@ -125,10 +125,7 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
<< "please report this error to https://www.github.com/NVIDIA/TRTorch/issues");
}
-void AddInputs(
- ConversionCtx* ctx,
- at::ArrayRef<const torch::jit::Value*> inputs,
- std::vector<ir::Input>& input_specs) {
+void AddInputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> inputs, std::vector<ir::Input>& input_specs) {
std::vector<const torch::jit::Value*> input_tensors;
for (auto in : inputs) {
// Disregarding inputs that are not tensors
@@ -164,7 +161,8 @@ void AddInputs(
std::string name = std::string("input_") + std::to_string(ctx->num_inputs);
LOG_INFO(
ctx->logger,
- "Adding Input " << in->debugName() << " (named: " << name << "): " << spec << " in engine (conversion.AddInputs)");
+ "Adding Input " << in->debugName() << " (named: " << name << "): " << spec
+ << " in engine (conversion.AddInputs)");
auto trt_in = ctx->net->addInput(name.c_str(), spec.dtype, spec.input_shape);
TRTORCH_CHECK(trt_in, "Failed to add input node: " << in->debugName() << " (conversion.AddInputs)");
diff --git a/workspace/core/conversion/converters/impl/activation.cpp b/tmp/changes.txt
index a7e2a41..77931e4 100644
--- a/workspace/core/conversion/converters/impl/activation.cpp
+++ b/tmp/changes.txt
@@ -177,8 +177,9 @@ auto acthardtanh TRTORCH_UNUSED =
std::string pluginName = "CustomGeluPluginDynamic";
nvinfer1::PluginFieldCollection fc;
std::vector<nvinfer1::PluginField> f;
- //REVIEW is this right?
- int type_id = ctx->settings.enabled_precisions.find(nvinfer1::DataType::kHALF) == ctx->settings.enabled_precisions.end()
+ // REVIEW is this right?
+ int type_id = ctx->settings.enabled_precisions.find(nvinfer1::DataType::kHALF) ==
+ ctx->settings.enabled_precisions.end()
? 0
: 1; // Integer encoding the DataType (0: FP32, 1: FP16)
f.emplace_back(nvinfer1::PluginField("type_id", &type_id, nvinfer1::PluginFieldType::kINT32, 1));
diff --git a/workspace/core/conversion/conversionctx/ConversionCtx.cpp b/tmp/changes.txt
index 96a6261..50831a8 100644
--- a/workspace/core/conversion/conversionctx/ConversionCtx.cpp
+++ b/tmp/changes.txt
@@ -60,7 +60,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
LOG_DEBUG(build_settings);
cfg = builder->createBuilderConfig();
- for(auto p = settings.enabled_precisions.begin(); p != settings.enabled_precisions.end(); ++p) {
+ for (auto p = settings.enabled_precisions.begin(); p != settings.enabled_precisions.end(); ++p) {
switch (*p) {
case nvinfer1::DataType::kHALF:
TRTORCH_CHECK(builder->platformHasFastFp16(), "Requested inference in FP16 but platform does not support FP16");
@@ -121,7 +121,9 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
static_cast<int>(settings.device.dla_core) < nbDLACores,
"Configured DLA Core ID: " << settings.device.dla_core
<< " not available. Total number of available DLA Cores: " << nbDLACores);
- TRTORCH_CHECK(settings.enabled_precisions.find(nvinfer1::DataType::kFLOAT) == settings.enabled_precisions.end(), "DLA supports only fp16 or int8 precision");
+ TRTORCH_CHECK(
+ settings.enabled_precisions.find(nvinfer1::DataType::kFLOAT) == settings.enabled_precisions.end(),
+ "DLA supports only fp16 or int8 precision");
cfg->setDLACore(settings.device.dla_core);
}
}
diff --git a/workspace/core/ir/ir.h b/tmp/changes.txt
index d524ded..3d14491 100644
--- a/workspace/core/ir/ir.h
+++ b/tmp/changes.txt
@@ -1,7 +1,7 @@
#pragma once
-#include <vector>
#include <iostream>
+#include <vector>
#include "NvInfer.h"
namespace trtorch {
@@ -9,10 +9,18 @@ namespace core {
namespace ir {
struct Input {
- //Input(std::vector<int64_t> shape);
- //Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape);
- Input(std::vector<int64_t> shape, nvinfer1::DataType dtype=nvinfer1::DataType::kFLOAT, nvinfer1::TensorFormat format=nvinfer1::TensorFormat::kLINEAR);
- Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, nvinfer1::DataType dtype=nvinfer1::DataType::kFLOAT, nvinfer1::TensorFormat format=nvinfer1::TensorFormat::kLINEAR);
+ // Input(std::vector<int64_t> shape);
+ // Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape);
+ Input(
+ std::vector<int64_t> shape,
+ nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT,
+ nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR);
+ Input(
+ std::vector<int64_t> min_shape,
+ std::vector<int64_t> opt_shape,
+ std::vector<int64_t> max_shape,
+ nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT,
+ nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR);
friend std::ostream& operator<<(std::ostream& os, const Input& input);
bool input_is_dynamic = false;
diff --git a/workspace/core/conversion/conversion.h b/tmp/changes.txt
index fe79669..253dce7 100644
--- a/workspace/core/conversion/conversion.h
+++ b/tmp/changes.txt
@@ -14,8 +14,7 @@ namespace conversion {
struct ConversionInfo {
std::vector<ir::Input> inputs;
BuilderSettings engine_settings;
- ConversionInfo(std::vector<ir::Input> inputs)
- : inputs(std::move(inputs)), engine_settings(BuilderSettings()) {}
+ ConversionInfo(std::vector<ir::Input> inputs) : inputs(std::move(inputs)), engine_settings(BuilderSettings()) {}
};
// TODO: REMOVE GRAPH AND PARAMS AND MOVE FULLY TO INLINED CONSTANTS
diff --git a/workspace/core/conversion/conversionctx/ConversionCtx.h b/tmp/changes.txt
index 9aee3b4..0570bf5 100644
--- a/workspace/core/conversion/conversionctx/ConversionCtx.h
+++ b/tmp/changes.txt
@@ -2,8 +2,8 @@
#include <map>
#include <memory>
-#include <unordered_map>
#include <set>
+#include <unordered_map>
#include "NvInfer.h"
#include "torch/csrc/jit/ir/ir.h"
diff --git a/workspace/cpp/trtorchc/main.cpp b/tmp/changes.txt
index 794238c..df179ca 100644
--- a/workspace/cpp/trtorchc/main.cpp
+++ b/tmp/changes.txt
@@ -43,8 +43,7 @@ bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold) {
}
trtorch::CompileSpec::TensorFormat parseTensorFormat(std::string str) {
- std::transform(
- str.begin(), str.end(), str.begin(), [](unsigned char c) { return std::tolower(c); });
+ std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) { return std::tolower(c); });
if (str == "linear" || str == "nchw" || str == "chw" || str == "contiguous") {
return trtorch::CompileSpec::TensorFormat::kContiguous;
@@ -52,8 +51,8 @@ trtorch::CompileSpec::TensorFormat parseTensorFormat(std::string str) {
return trtorch::CompileSpec::TensorFormat::kChannelsLast;
} else {
trtorch::logging::log(
- trtorch::logging::Level::kERROR,
- "Invalid tensor format, options are [ linear | nchw | chw | contiguous | nhwc | hwc | channels_last ]");
+ trtorch::logging::Level::kERROR,
+ "Invalid tensor format, options are [ linear | nchw | chw | contiguous | nhwc | hwc | channels_last ]");
return trtorch::CompileSpec::TensorFormat::kUnknown;
}
}
@@ -73,8 +72,8 @@ trtorch::CompileSpec::DataType parseDataType(std::string dtype_str) {
return trtorch::CompileSpec::DataType::kBool;
} else {
trtorch::logging::log(
- trtorch::logging::Level::kERROR,
- "Invalid precision, options are [ float | float32 | f32 | half | float16 | f16 | char | int8 | i8 | int | int32 | i32 | bool | b]");
+ trtorch::logging::Level::kERROR,
+ "Invalid precision, options are [ float | float32 | f32 | half | float16 | f16 | char | int8 | i8 | int | int32 | i32 | bool | b]");
return trtorch::CompileSpec::DataType::kUnknown;
}
}
@@ -190,10 +189,7 @@ int main(int argc, char** argv) {
args::Flag build_debuggable_engine(
parser, "build-debuggable-engine", "Creates a debuggable engine", {"build-debuggable-engine"});
args::Flag use_strict_types(
- parser,
- "use-strict-types",
- "Restrict operating type to only use set operation precision",
- {"use-strict-types"});
+ parser, "use-strict-types", "Restrict operating type to only use set operation precision", {"use-strict-types"});
args::Flag allow_gpu_fallback(
parser,
"allow-gpu-fallback",
@@ -275,7 +271,8 @@ int main(int argc, char** argv) {
}
std::vector<trtorch::CompileSpec::Input> ranges;
- const std::string spec_err_str = "Dimensions should be specified in one of these types \"(N,..,C,H,W)\" \"[(MIN_N,..,MIN_C,MIN_H,MIN_W);(OPT_N,..,OPT_C,OPT_H,OPT_W);(MAX_N,..,MAX_C,MAX_H,MAX_W)]\"\n e.g \"(3,3,300,300)\" \"[(3,3,100,100);(3,3,200,200);(3,3,300,300)]\"\nTo specify input type append an @ followed by the precision\n e.g. \"(3,3,300,300)@f32\"\nTo specify input format append an \% followed by the format [contiguous | channel_last]\n e.g. \"(3,3,300,300)@f32\%channel_last\"";
+ const std::string spec_err_str =
+ "Dimensions should be specified in one of these types \"(N,..,C,H,W)\" \"[(MIN_N,..,MIN_C,MIN_H,MIN_W);(OPT_N,..,OPT_C,OPT_H,OPT_W);(MAX_N,..,MAX_C,MAX_H,MAX_W)]\"\n e.g \"(3,3,300,300)\" \"[(3,3,100,100);(3,3,200,200);(3,3,300,300)]\"\nTo specify input type append an @ followed by the precision\n e.g. \"(3,3,300,300)@f32\"\nTo specify input format append an \% followed by the format [contiguous | channel_last]\n e.g. \"(3,3,300,300)@f32\%channel_last\"";
for (const auto spec : args::get(input_shapes)) {
std::string shapes;
std::string dtype;
@@ -306,13 +303,14 @@ int main(int argc, char** argv) {
ranges.push_back(trtorch::CompileSpec::Input(parseSingleDim(shapes), parsed_dtype, parsed_format));
} else if (shapes.rfind("[", 0) == 0) {
auto dyn_shapes = parseDynamicDim(shapes);
- ranges.push_back(trtorch::CompileSpec::Input(dyn_shapes[0], dyn_shapes[1], dyn_shapes[2], parsed_dtype, parsed_format));
+ ranges.push_back(
+ trtorch::CompileSpec::Input(dyn_shapes[0], dyn_shapes[1], dyn_shapes[2], parsed_dtype, parsed_format));
} else {
trtorch::logging::log(trtorch::logging::Level::kERROR, spec_err_str);
std::cerr << parser;
exit(1);
}
- // THERE IS NO SPEC FOR FORMAT
+ // THERE IS NO SPEC FOR FORMAT
} else {
std::string shapes = spec.substr(0, spec.find('@'));
std::string dtype = spec.substr(spec.find('@') + 1, spec.size());
@@ -334,7 +332,7 @@ int main(int argc, char** argv) {
exit(1);
}
}
- // THERE IS A SPEC FOR FORMAT BUT NOT DTYPE
+ // THERE IS A SPEC FOR FORMAT BUT NOT DTYPE
} else if (spec.find('%') != std::string::npos) {
std::string shapes = spec.substr(0, spec.find('%'));
std::string format = spec.substr(spec.find('%') + 1, spec.size());
@@ -355,7 +353,7 @@ int main(int argc, char** argv) {
std::cerr << parser;
exit(1);
}
- // JUST SHAPE USE DEFAULT DTYPE
+ // JUST SHAPE USE DEFAULT DTYPE
} else {
if (spec.rfind("(", 0) == 0) {
ranges.push_back(trtorch::CompileSpec::Input(parseSingleDim(spec)));
@@ -449,7 +447,6 @@ int main(int argc, char** argv) {
}
}
-
if (engine_capability) {
auto capability = args::get(engine_capability);
std::transform(
@@ -510,7 +507,9 @@ int main(int argc, char** argv) {
} else {
auto trt_mod = trtorch::CompileGraph(mod, compile_settings);
- if (compile_settings.enabled_precisions.size() == 1 && compile_settings.enabled_precisions.find(trtorch::CompileSpec::DataType::kFloat) != compile_settings.enabled_precisions.end()) {
+ if (compile_settings.enabled_precisions.size() == 1 &&
+ compile_settings.enabled_precisions.find(trtorch::CompileSpec::DataType::kFloat) !=
+ compile_settings.enabled_precisions.end()) {
double threshold_val = 2e-5;
if (threshold) {
threshold_val = args::get(threshold);
diff --git a/workspace/cpp/api/src/compile_spec.cpp b/tmp/changes.txt
index 1d55443..5e44c6b 100644
--- a/workspace/cpp/api/src/compile_spec.cpp
+++ b/tmp/changes.txt
@@ -62,14 +62,16 @@ std::ostream& operator<<(std::ostream& os, const CompileSpec::Input& input) {
};
if (!input.input_is_dynamic) {
- os << "Input(shape: " << vec_to_str(input.shape) << ", dtype: " << input.dtype << ", format: " << input.format << ')';
+ os << "Input(shape: " << vec_to_str(input.shape) << ", dtype: " << input.dtype << ", format: " << input.format
+ << ')';
} else {
- os << "Input(shape: " << vec_to_str(input.shape) << ", min: " << vec_to_str(input.min_shape) << ", opt: " << vec_to_str(input.opt_shape) << ", max: " << vec_to_str(input.max_shape) << ", dtype: " << input.dtype << ", format: " << input.format << ')';
+ os << "Input(shape: " << vec_to_str(input.shape) << ", min: " << vec_to_str(input.min_shape)
+ << ", opt: " << vec_to_str(input.opt_shape) << ", max: " << vec_to_str(input.max_shape)
+ << ", dtype: " << input.dtype << ", format: " << input.format << ')';
}
return os;
}
-
nvinfer1::DataType toTRTDataType(CompileSpec::DataType value) {
switch (value) {
case CompileSpec::DataType::kChar:
@@ -122,8 +124,7 @@ CompileSpec::DataType::DataType(c10::ScalarType t) {
CompileSpec::TensorFormat::TensorFormat(at::MemoryFormat t) {
TRTORCH_CHECK(
- t == at::MemoryFormat::Contiguous || t == at::MemoryFormat::ChannelsLast, "Tensor format is unsupported"
- );
+ t == at::MemoryFormat::Contiguous || t == at::MemoryFormat::ChannelsLast, "Tensor format is unsupported");
switch (t) {
case at::MemoryFormat::ChannelsLast:
@@ -221,7 +222,11 @@ CompileSpec::Input::Input(c10::IntArrayRef shape, DataType dtype, TensorFormat f
this->input_is_dynamic = false;
}
-CompileSpec::Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, TensorFormat format) {
+CompileSpec::Input::Input(
+ std::vector<int64_t> min_shape,
+ std::vector<int64_t> opt_shape,
+ std::vector<int64_t> max_shape,
+ TensorFormat format) {
this->opt_shape = opt_shape;
this->min_shape = min_shape;
this->max_shape = max_shape;
@@ -232,7 +237,12 @@ CompileSpec::Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> o
this->input_is_dynamic = true;
}
-CompileSpec::Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, DataType dtype, TensorFormat format) {
+CompileSpec::Input::Input(
+ std::vector<int64_t> min_shape,
+ std::vector<int64_t> opt_shape,
+ std::vector<int64_t> max_shape,
+ DataType dtype,
+ TensorFormat format) {
this->opt_shape = opt_shape;
this->min_shape = min_shape;
this->max_shape = max_shape;
@@ -243,7 +253,11 @@ CompileSpec::Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> o
this->input_is_dynamic = true;
}
-CompileSpec::Input::Input(c10::IntArrayRef min_shape, c10::IntArrayRef opt_shape, c10::IntArrayRef max_shape, TensorFormat format) {
+CompileSpec::Input::Input(
+ c10::IntArrayRef min_shape,
+ c10::IntArrayRef opt_shape,
+ c10::IntArrayRef max_shape,
+ TensorFormat format) {
this->opt_shape = core::util::toVec(opt_shape);
this->min_shape = core::util::toVec(min_shape);
this->max_shape = core::util::toVec(max_shape);
@@ -254,7 +268,12 @@ CompileSpec::Input::Input(c10::IntArrayRef min_shape, c10::IntArrayRef opt_shape
this->input_is_dynamic = true;
}
-CompileSpec::Input::Input(c10::IntArrayRef min_shape, c10::IntArrayRef opt_shape, c10::IntArrayRef max_shape, DataType dtype, TensorFormat format) {
+CompileSpec::Input::Input(
+ c10::IntArrayRef min_shape,
+ c10::IntArrayRef opt_shape,
+ c10::IntArrayRef max_shape,
+ DataType dtype,
+ TensorFormat format) {
this->opt_shape = core::util::toVec(opt_shape);
this->min_shape = core::util::toVec(min_shape);
this->max_shape = core::util::toVec(max_shape);
@@ -306,17 +325,19 @@ core::runtime::CudaDevice to_internal_cuda_device(CompileSpec::Device device) {
core::CompileSpec to_internal_compile_spec(CompileSpec external) {
core::CompileSpec internal(to_vec_internal_inputs(external.inputs));
- if (external.input_ranges.size() > 0 ) {
+ if (external.input_ranges.size() > 0) {
internal = core::CompileSpec(to_vec_internal_inputs(external.input_ranges));
} else {
TRTORCH_CHECK(external.inputs.size() > 0, "Compilation requires at least one input specification");
internal = core::CompileSpec(to_vec_internal_inputs(external.inputs));
}
- if (external.enabled_precisions.size() <= 1 && toTRTDataType(*external.enabled_precisions.begin()) == nvinfer1::DataType::kFLOAT && toTRTDataType(external.op_precision) != nvinfer1::DataType::kFLOAT) {
+ if (external.enabled_precisions.size() <= 1 &&
+ toTRTDataType(*external.enabled_precisions.begin()) == nvinfer1::DataType::kFLOAT &&
+ toTRTDataType(external.op_precision) != nvinfer1::DataType::kFLOAT) {
internal.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(external.op_precision));
} else {
- for(auto p : external.enabled_precisions) {
+ for (auto p : external.enabled_precisions) {
internal.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
}
}
@@ -375,7 +396,8 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
internal.convert_info.engine_settings.num_avg_timing_iters = external.num_avg_timing_iters;
internal.convert_info.engine_settings.workspace_size = external.workspace_size;
- if (internal.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) != internal.convert_info.engine_settings.enabled_precisions.end()) {
+ if (internal.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) !=
+ internal.convert_info.engine_settings.enabled_precisions.end()) {
internal.convert_info.engine_settings.calibrator = external.ptq_calibrator;
} else {
internal.convert_info.engine_settings.calibrator = nullptr;
diff --git a/workspace/cpp/api/include/trtorch/trtorch.h b/tmp/changes.txt
index 8e492e4..e0fd531 100644
--- a/workspace/cpp/api/include/trtorch/trtorch.h
+++ b/tmp/changes.txt
@@ -11,9 +11,9 @@
#include <cuda_runtime.h>
#include <iostream>
#include <memory>
+#include <set>
#include <string>
#include <vector>
-#include <set>
// Just include the .h?
#ifndef DOXYGEN_SHOULD_SKIP_THIS
@@ -268,7 +268,7 @@ struct TRTORCH_API CompileSpec {
};
class TRTORCH_API TensorFormat {
- public:
+ public:
/**
* Underlying enum class to support the TensorFormat Class
*
@@ -352,7 +352,6 @@ struct TRTORCH_API CompileSpec {
return value != other;
}
-
private:
friend std::ostream& operator<<(std::ostream& os, const TensorFormat& format);
Value value;
@@ -373,7 +372,8 @@ struct TRTORCH_API CompileSpec {
std::vector<int64_t> opt_shape;
/// Maximum acceptable input size into the engine
std::vector<int64_t> max_shape;
- /// Input shape to be fed to TensorRT, in the event of a dynamic shape, -1's will hold the place of variable dimensions
+ /// Input shape to be fed to TensorRT, in the event of a dynamic shape, -1's will hold the place of variable
+ /// dimensions
std::vector<int64_t> shape;
/// Expected data type for the input
DataType dtype;
@@ -390,7 +390,7 @@ struct TRTORCH_API CompileSpec {
* @param dtype Expected data type for the input (Defaults to Float32)
* @param format Expected tensor format for the input (Defaults to contiguous)
*/
- Input(std::vector<int64_t> shape, TensorFormat format=TensorFormat::kContiguous);
+ Input(std::vector<int64_t> shape, TensorFormat format = TensorFormat::kContiguous);
/**
* @brief Construct a new Input spec object for static input size from
@@ -401,7 +401,7 @@ struct TRTORCH_API CompileSpec {
* @param dtype Expected data type for the input (Defaults to Float32)
* @param format Expected tensor format for the input (Defaults to contiguous)
*/
- Input(std::vector<int64_t> shape, DataType dtype, TensorFormat format=TensorFormat::kContiguous);
+ Input(std::vector<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);
/**
* @brief Construct a new Input spec object for static input size from
@@ -413,7 +413,7 @@ struct TRTORCH_API CompileSpec {
* @param shape Input tensor shape
* @param format Expected tensor format for the input (Defaults to contiguous)
*/
- Input(c10::ArrayRef<int64_t> shape, TensorFormat format=TensorFormat::kContiguous);
+ Input(c10::ArrayRef<int64_t> shape, TensorFormat format = TensorFormat::kContiguous);
/**
* @brief Construct a new Input spec object for static input size from
@@ -424,7 +424,7 @@ struct TRTORCH_API CompileSpec {
* @param dtype Expected data type for the input (Defaults to Float32)
* @param format Expected tensor format for the input (Defaults to contiguous)
*/
- Input(c10::ArrayRef<int64_t> shape, DataType dtype, TensorFormat format=TensorFormat::kContiguous);
+ Input(c10::ArrayRef<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);
/**
* @brief Construct a new Input Range object dynamic input size from
@@ -437,7 +437,11 @@ struct TRTORCH_API CompileSpec {
* @param max_shape Maximum acceptible shape for input tensor
* @param format Expected tensor format for the input (Defaults to contiguous)
*/
- Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, TensorFormat format=TensorFormat::kContiguous);
+ Input(
+ std::vector<int64_t> min_shape,
+ std::vector<int64_t> opt_shape,
+ std::vector<int64_t> max_shape,
+ TensorFormat format = TensorFormat::kContiguous);
/**
* @brief Construct a new Input spec object for a dynamic input size from vectors
@@ -450,7 +454,12 @@ struct TRTORCH_API CompileSpec {
* @param dtype Expected data type for the input (Defaults to Float32)
* @param format Expected tensor format for the input (Defaults to contiguous)
*/
- Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, DataType dtype, TensorFormat format=TensorFormat::kContiguous);
+ Input(
+ std::vector<int64_t> min_shape,
+ std::vector<int64_t> opt_shape,
+ std::vector<int64_t> max_shape,
+ DataType dtype,
+ TensorFormat format = TensorFormat::kContiguous);
/**
* @brief Construct a new Input Range object dynamic input size from
@@ -463,7 +472,11 @@ struct TRTORCH_API CompileSpec {
* @param max_shape Maximum acceptible shape for input tensor
* @param format Expected tensor format for the input (Defaults to contiguous)
*/
- Input(c10::ArrayRef<int64_t> min_shape, c10::ArrayRef<int64_t> opt_shape, c10::ArrayRef<int64_t> max_shape, TensorFormat format=TensorFormat::kContiguous);
+ Input(
+ c10::ArrayRef<int64_t> min_shape,
+ c10::ArrayRef<int64_t> opt_shape,
+ c10::ArrayRef<int64_t> max_shape,
+ TensorFormat format = TensorFormat::kContiguous);
/**
* @brief Construct a new Input Range object dynamic input size from
@@ -476,10 +489,18 @@ struct TRTORCH_API CompileSpec {
* @param dtype Expected data type for the input (Defaults to Float32)
* @param format Expected tensor format for the input (Defaults to contiguous)
*/
- Input(c10::ArrayRef<int64_t> min_shape, c10::ArrayRef<int64_t> opt_shape, c10::ArrayRef<int64_t> max_shape, DataType dtype, TensorFormat format=TensorFormat::kContiguous);
+ Input(
+ c10::ArrayRef<int64_t> min_shape,
+ c10::ArrayRef<int64_t> opt_shape,
+ c10::ArrayRef<int64_t> max_shape,
+ DataType dtype,
+ TensorFormat format = TensorFormat::kContiguous);
- bool get_explicit_set_dtype() {return explicit_set_dtype;}
- private:
+ bool get_explicit_set_dtype() {
+ return explicit_set_dtype;
+ }
+
+ private:
friend std::ostream& operator<<(std::ostream& os, const Input& input);
bool input_is_dynamic;
bool explicit_set_dtype;
@@ -506,16 +527,16 @@ struct TRTORCH_API CompileSpec {
*
* @param opt
*/
- [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]]
- InputRange(std::vector<int64_t> opt);
+ [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] InputRange(
+ std::vector<int64_t> opt);
/**
* @brief Construct a new Input Range object static input size from
* c10::ArrayRef (the type produced by tensor.sizes())
*
* @param opt
*/
- [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]]
- InputRange(c10::ArrayRef<int64_t> opt);
+ [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] InputRange(
+ c10::ArrayRef<int64_t> opt);
/**
* @brief Construct a new Input Range object dynamic input size from vectors
* for min, opt, and max supported sizes
@@ -524,8 +545,10 @@ struct TRTORCH_API CompileSpec {
* @param opt
* @param max
*/
- [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]]
- InputRange(std::vector<int64_t> min, std::vector<int64_t> opt, std::vector<int64_t> max);
+ [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] InputRange(
+ std::vector<int64_t> min,
+ std::vector<int64_t> opt,
+ std::vector<int64_t> max);
/**
* @brief Construct a new Input Range object dynamic input size from
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
@@ -535,8 +558,10 @@ struct TRTORCH_API CompileSpec {
* @param opt
* @param max
*/
- [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]]
- InputRange(c10::ArrayRef<int64_t> min, c10::ArrayRef<int64_t> opt, c10::ArrayRef<int64_t> max);
+ [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] InputRange(
+ c10::ArrayRef<int64_t> min,
+ c10::ArrayRef<int64_t> opt,
+ c10::ArrayRef<int64_t> max);
};
/**
@@ -577,8 +602,9 @@ struct TRTORCH_API CompileSpec {
*
* @param input_ranges
*/
- [[deprecated("trtorch::CompileSpec::CompileSpec(std::vector<InputRange> input_ranges) is being deprecated in favor of trtorch::CompileSpec::CompileSpec(std::vector<Input> inputs). Please use CompileSpec(std::vector<Input> inputs). trtorch::CompileSpec::CompileSpec(std::vector<InputRange> input_ranges) will be removed in TRTorch v0.5.0")]]
- CompileSpec(std::vector<InputRange> input_ranges) : input_ranges(std::move(input_ranges)) {}
+ [[deprecated("trtorch::CompileSpec::CompileSpec(std::vector<InputRange> input_ranges) is being deprecated in favor of trtorch::CompileSpec::CompileSpec(std::vector<Input> inputs). Please use CompileSpec(std::vector<Input> inputs). trtorch::CompileSpec::CompileSpec(std::vector<InputRange> input_ranges) will be removed in TRTorch v0.5.0")]] CompileSpec(
+ std::vector<InputRange> input_ranges)
+ : input_ranges(std::move(input_ranges)) {}
/**
* @brief Construct a new Extra Info object
* Convienence constructor to set fixed input size from vectors describing
@@ -586,7 +612,8 @@ struct TRTORCH_API CompileSpec {
* should be provided in call order.
*
* This constructor should be use as a convience in the case that all inputs are static sized and
- * you are okay with default input dtype and formats (FP32 for FP32 and INT8 weights, FP16 for FP16 weights, contiguous)
+ * you are okay with default input dtype and formats (FP32 for FP32 and INT8 weights, FP16 for FP16 weights,
+ * contiguous)
*
* @param fixed_sizes
*/
@@ -599,7 +626,8 @@ struct TRTORCH_API CompileSpec {
* the vector represents a input and should be provided in call order.
*
* This constructor should be use as a convience in the case that all inputs are static sized and
- * you are okay with default input dtype and formats (FP32 for FP32 and INT8 weights, FP16 for FP16 weights, contiguous)
+ * you are okay with default input dtype and formats (FP32 for FP32 and INT8 weights, FP16 for FP16 weights,
+ * contiguous)
*
* @param fixed_sizes
*/
@@ -619,8 +647,8 @@ struct TRTORCH_API CompileSpec {
// Defaults should reflect TensorRT defaults for BuilderConfig
/**
- * @brief Specifications for inputs to the engine, can either be a single size or a range defined by min, opt and max sizes
- * Users can also specify expected input type as well as tensor memory format
+ * @brief Specifications for inputs to the engine, can either be a single size or a range defined by min, opt and max
+ * sizes Users can also specify expected input type as well as tensor memory format
*
* Order in vector should match call order for the function
*/
@@ -632,14 +660,17 @@ struct TRTORCH_API CompileSpec {
*
* Order is should match call order
*/
- [[deprecated("trtorch::CompileSpec::input_ranges is being deprecated in favor of trtorch::CompileSpec::inputs. trtorch::CompileSpec::input_ranges will be removed in TRTorch v0.5.0")]]
- std::vector<InputRange> input_ranges;
+ [[deprecated(
+ "trtorch::CompileSpec::input_ranges is being deprecated in favor of trtorch::CompileSpec::inputs. trtorch::CompileSpec::input_ranges will be removed in TRTorch v0.5.0")]] std::
+ vector<InputRange>
+ input_ranges;
/**
* Default operating precision for the engine
*/
- [[deprecated("trtorch::CompileSpec::op_precision is being deprecated in favor of trtorch::CompileSpec::enabled_precisions, a set of all enabled precisions to use during compilation, trtorch::CompileSpec::op_precision will be removed in TRTorch v0.5.0")]]
- DataType op_precision = DataType::kFloat;
+ [[deprecated(
+ "trtorch::CompileSpec::op_precision is being deprecated in favor of trtorch::CompileSpec::enabled_precisions, a set of all enabled precisions to use during compilation, trtorch::CompileSpec::op_precision will be removed in TRTorch v0.5.0")]] DataType
+ op_precision = DataType::kFloat;
/**
* @brief The set of precisions TensorRT is allowed to use for kernels during compilation
diff --git a/workspace/tests/cpp/test_modules_as_engines.cpp b/tmp/changes.txt
index e6a72e3..58e7699 100644
--- a/workspace/tests/cpp/test_modules_as_engines.cpp
+++ b/tmp/changes.txt
@@ -1,5 +1,5 @@
-#include "cpp_api_test.h"
#include "core/runtime/runtime.h"
+#include "cpp_api_test.h"
TEST_P(CppAPITests, ModuleAsEngineIsClose) {
std::vector<at::Tensor> inputs;
diff --git a/workspace/tests/cpp/test_default_input_types.cpp b/tmp/changes.txt
index f033f00..fa814fd 100644
--- a/workspace/tests/cpp/test_default_input_types.cpp
+++ b/tmp/changes.txt
@@ -13,7 +13,6 @@ TEST_P(CppAPITests, InputsUseDefault) {
auto spec = trtorch::CompileSpec({in});
spec.enabled_precisions.insert(trtorch::CompileSpec::DataType::kHalf);
-
mod.to(torch::kHalf);
auto trt_mod = trtorch::CompileGraph(mod, spec);
@@ -26,5 +25,4 @@ TEST_P(CppAPITests, InputsUseDefault) {
INSTANTIATE_TEST_SUITE_P(
CompiledModuleForwardIsCloseSuite,
CppAPITests,
- testing::Values(
- PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}})));
+ testing::Values(PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}})));
diff --git a/workspace/py/trtorch/csrc/register_tensorrt_classes.cpp b/tmp/changes.txt
index 9080f33..3ca490c 100644
--- a/workspace/py/trtorch/csrc/register_tensorrt_classes.cpp
+++ b/tmp/changes.txt
@@ -9,10 +9,9 @@ namespace {
(registry).def("_get_" #field_name, &class_name::get_##field_name);
void RegisterTRTCompileSpec() {
- static auto TRTORCH_UNUSED TRTInputRangeTSRegistration =
- torch::class_<trtorch::pyapi::Input>("tensorrt", "_Input")
- .def(torch::init<>())
- .def("__str__", &trtorch::pyapi::Input::to_str);
+ static auto TRTORCH_UNUSED TRTInputRangeTSRegistration = torch::class_<trtorch::pyapi::Input>("tensorrt", "_Input")
+ .def(torch::init<>())
+ .def("__str__", &trtorch::pyapi::Input::to_str);
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, min);
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, opt);
@@ -21,7 +20,6 @@ void RegisterTRTCompileSpec() {
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, format);
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, input_is_dynamic);
-
static auto TRTORCH_UNUSED TRTDeviceTSRegistration = torch::class_<trtorch::pyapi::Device>("tensorrt", "_Device")
.def(torch::init<>())
.def("__str__", &trtorch::pyapi::Device::to_str);
ERROR: Some files do not conform to style guidelines
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
Code conforms to C++ style guidelines
There are some changes that do not conform to Python style guidelines:
Reformatting /workspace/cpp/ptq/training/vgg16/export_ckpt.py
Reformatting /workspace/cpp/ptq/training/vgg16/vgg16.py
Reformatting /workspace/cpp/ptq/training/vgg16/main.py
Reformatting /workspace/tests/modules/hub.py
Reformatting /workspace/tests/py/test_to_backend_api.py
Reformatting /workspace/tests/py/test_ptq_trt_calibrator.py
Reformatting /workspace/tests/py/test_ptq_dataloader_calibrator.py
--- /workspace/tests/py/test_api.py (original)
+++ /workspace/tests/py/test_api.py (reformatted)
@@ -73,6 +73,7 @@
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
self.assertTrue(same < 2e-2)
+
class TestCompileHalf(ModelTestCase):
def setUp(self):
@@ -94,6 +95,7 @@
same = (trt_mod(self.input.half()) - self.scripted_model(self.input.half())).abs().max()
self.assertTrue(same < 2e-2)
+
class TestCompileHalfDefault(ModelTestCase):
def setUp(self):
@@ -114,6 +116,7 @@
trt_mod = trtorch.compile(self.scripted_model, compile_spec)
same = (trt_mod(self.input.half()) - self.scripted_model(self.input.half())).abs().max()
self.assertTrue(same < 2e-2)
+
class TestFallbackToTorch(ModelTestCase):
Reformatting /workspace/tests/py/test_api_dla.py
Reformatting /workspace/tests/py/test_trt_intercompatability.py
Reformatting /workspace/tests/py/model_test_case.py
Reformatting /workspace/tests/py/test_multi_gpu.py
Reformatting /workspace/tests/py/test_ptq_to_backend.py
Reformatting /workspace/tests/py/test_api.py
ERROR: Some files do not conform to style guidelines
py/trtorch/_compiler.py
Outdated
@@ -41,6 +43,7 @@ def compile(module: torch.jit.ScriptModule, compile_spec: Any) -> torch.jit.Scri
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
},
"op_precision": torch.half, # Operating precision set to FP16
enabled precisions
Pending local testing, otherwise LGTM with minor comments
core/ir/Input.cpp
Outdated
// }

// input_shape = util::toDims(dyn_shape);
// }
I guess this would be part of your clean up.
/// Bool
kBool,
/// Sentinel value
kUnknown
what is this used for?
@@ -415,6 +593,129 @@ struct TRTORCH_API CompileSpec {
TorchFallback(bool enabled, uint64_t min_size) : enabled(enabled), min_block_size(min_size) {}
};

/**
* @brief Construct a new Extra Info object from input ranges.
Maybe a minor thing. Extra Info -> CompileSpec in the docstring
os << "half";
break;
case CompileSpec::DataType::kInt:
os << "int";
can we use "int32" here to be explicit?
I wanted to use the same term as PyTorch and the API at least for the top level; it switches to int32 in the core
these are mostly debugging aids for us, I don't think they are used anywhere in the codebase
yeah I think torch has both torch::kInt32 and torch::kInt or at::Int which are aliases
std::ostream& operator<<(std::ostream& os, const CompileSpec::DataType& dtype) {
switch (dtype) {
case CompileSpec::DataType::kChar:
os << "char";
shall we print INT8 instead?
I wanted to use the same term as PyTorch and the API at least for the top level; it switches to int8 in the core
cpp/api/src/compile_spec.cpp
Outdated
TRTORCH_CHECK(t == at::kHalf || t == at::kFloat || t == at::kChar, "Data type is unsupported");
TRTORCH_CHECK(
t == at::kHalf || t == at::kFloat || t == at::kChar || t == at::kInt || t == at::kBool,
"Data type is unsupported");
Maybe we can add this after "Data type is unsupported": "Supported input datatypes include float|half|int8|int32|bool"
cpp/api/src/compile_spec.cpp
Outdated
CompileSpec::TensorFormat::TensorFormat(at::MemoryFormat t) {
TRTORCH_CHECK(
t == at::MemoryFormat::Contiguous || t == at::MemoryFormat::ChannelsLast, "Tensor format is unsupported");
"Supported options include ChannelsLast and Contiguous"
cpp/api/src/compile_spec.cpp
Outdated
if (external.enabled_precisions.size() <= 1 &&
toTRTDataType(*external.enabled_precisions.begin()) == nvinfer1::DataType::kFLOAT &&
toTRTDataType(external.op_precision) != nvinfer1::DataType::kFLOAT) {
what is happening here in these conditions?
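For context on the question above: the condition appears to handle backward compatibility between the deprecated `op_precision` field and the newer `enabled_precisions` set — if `enabled_precisions` is still at its default (a single float entry) but the user explicitly set `op_precision` to something else, the old field wins; otherwise the new set is used as-is. A rough Python sketch of that precedence rule (names are hypothetical, not the TRTorch API):

```python
def resolve_precisions(enabled_precisions, op_precision, default="float"):
    """Simplified mimic of the precedence between the deprecated
    op_precision field and the newer enabled_precisions set."""
    # enabled_precisions still at its default, but the deprecated
    # op_precision was explicitly changed -> honor op_precision
    if (len(enabled_precisions) <= 1
            and default in enabled_precisions
            and op_precision != default):
        return {op_precision}
    # otherwise the new-style set takes effect unchanged
    return set(enabled_precisions)
```

For example, `resolve_precisions({"float"}, "half")` yields `{"half"}` (legacy field honored), while `resolve_precisions({"float", "half"}, "float")` keeps the full set.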
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
There are some changes that do not conform to Python style guidelines:
Reformatting /workspace/cpp/ptq/training/vgg16/export_ckpt.py
Reformatting /workspace/cpp/ptq/training/vgg16/vgg16.py
Reformatting /workspace/cpp/ptq/training/vgg16/main.py
--- /workspace/tests/py/test_api.py (original)
+++ /workspace/tests/py/test_api.py (reformatted)
@@ -73,6 +73,7 @@
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
self.assertTrue(same < 2e-2)
+
class TestCompileHalf(ModelTestCase):
def setUp(self):
@@ -94,6 +95,7 @@
same = (trt_mod(self.input.half()) - self.scripted_model(self.input.half())).abs().max()
self.assertTrue(same < 2e-2)
+
class TestCompileHalfDefault(ModelTestCase):
def setUp(self):
@@ -114,6 +116,7 @@
trt_mod = trtorch.compile(self.scripted_model, compile_spec)
same = (trt_mod(self.input.half()) - self.scripted_model(self.input.half())).abs().max()
self.assertTrue(same < 2e-2)
+
class TestFallbackToTorch(ModelTestCase):
Reformatting /workspace/tests/py/test_api_dla.py
Reformatting /workspace/tests/py/test_trt_intercompatability.py
Reformatting /workspace/tests/py/model_test_case.py
Reformatting /workspace/tests/py/test_multi_gpu.py
Reformatting /workspace/tests/py/test_ptq_to_backend.py
Reformatting /workspace/tests/py/test_api.py
Reformatting /workspace/tests/modules/hub.py
Reformatting /workspace/tests/py/test_to_backend_api.py
Reformatting /workspace/tests/py/test_ptq_trt_calibrator.py
Reformatting /workspace/tests/py/test_ptq_dataloader_calibrator.py
ERROR: Some files do not conform to style guidelines
Code conforms to C++ style guidelines
feat(//py): add user level device class in py for embed engine
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
There are some changes that do not conform to Python style guidelines:
Reformatting /workspace/cpp/ptq/training/vgg16/export_ckpt.py
Reformatting /workspace/cpp/ptq/training/vgg16/vgg16.py
Reformatting /workspace/cpp/ptq/training/vgg16/main.py
Reformatting /workspace/tests/modules/hub.py
Reformatting /workspace/tests/py/test_to_backend_api.py
Reformatting /workspace/tests/py/test_ptq_trt_calibrator.py
Reformatting /workspace/tests/py/test_ptq_dataloader_calibrator.py
--- /workspace/tests/py/test_api.py (original)
+++ /workspace/tests/py/test_api.py (reformatted)
@@ -73,6 +73,7 @@
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
self.assertTrue(same < 2e-2)
+
class TestCompileHalf(ModelTestCase):
def setUp(self):
@@ -94,6 +95,7 @@
same = (trt_mod(self.input.half()) - self.scripted_model(self.input.half())).abs().max()
self.assertTrue(same < 2e-2)
+
class TestCompileHalfDefault(ModelTestCase):
def setUp(self):
@@ -114,6 +116,7 @@
trt_mod = trtorch.compile(self.scripted_model, compile_spec)
same = (trt_mod(self.input.half()) - self.scripted_model(self.input.half())).abs().max()
self.assertTrue(same < 2e-2)
+
class TestFallbackToTorch(ModelTestCase):
Reformatting /workspace/tests/py/test_api_dla.py
Reformatting /workspace/tests/py/test_trt_intercompatability.py
Reformatting /workspace/tests/py/model_test_case.py
Reformatting /workspace/tests/py/test_multi_gpu.py
Reformatting /workspace/tests/py/test_ptq_to_backend.py
Reformatting /workspace/tests/py/test_api.py
ERROR: Some files do not conform to style guidelines
Code conforms to C++ style guidelines
Code conforms to Python style guidelines
Code conforms to C++ style guidelines
@peri044 should be good to go
Code conforms to C++ style guidelines
Code conforms to Python style guidelines
Description
This feature adds support for providing input datatypes explicitly. It is used for BERT and other common cases that require non-float inputs (e.g., integers for passing shapes). Handles case 3 in #412 (comment)
Pending: integrate this into trtorchc
Fixes #388
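The dtype-validation behavior visible in the reformatter output earlier in this thread (`_parse_input_dtypes`) can be sketched in plain Python. This is an illustration only: string names stand in for the `torch` dtype objects the real helper receives, and the output names stand in for the `_types.dtype` values it produces.

```python
# Simplified stand-in for trtorch._compile_spec._parse_input_dtypes;
# string names replace torch/_types dtype objects for illustration.
SUPPORTED_DTYPES = {
    "float": "float32",
    "half": "float16",
    "int8": "int8",
    "int32": "int32",
    "bool": "bool",
}

def parse_input_dtypes(input_dtypes):
    parsed = []
    for dtype in input_dtypes:
        if dtype not in SUPPORTED_DTYPES:
            # mirrors the TypeError raised by the real helper
            raise TypeError(
                "Invalid input dtype. Supported input datatypes include "
                "float|half|int8|int32|bool, got: " + str(dtype))
        parsed.append(SUPPORTED_DTYPES[dtype])
    return parsed
```

One entry is expected per input tensor, matching call order, so a BERT-style module with an embedding-index input could pass an int32 entry for that position while leaving the others as float.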
Type of change
Checklist: