Skip to content

feat: Add support for providing input datatypes in TRTorch #510

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Jul 22, 2021

Conversation

peri044
Copy link
Collaborator

@peri044 peri044 commented Jun 25, 2021

Description

This feature adds support for providing input datatypes explicitly. Used for BERT and other common cases where we would require non-float inputs (eg: integers for passing shapes etc). Handles case 3 in #412 (comment)
Pending : Integrate this into trtorchc

Fixes #388

Type of change

  • New feature (non-breaking change which adds functionality)
  • This change requires a documentation update

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes

@peri044 peri044 requested a review from narendasan June 25, 2021 02:59
@github-actions github-actions bot added component: api [Python] Issues re: Python API component: api [C++] Issues re: C++ API component: conversion Issues re: Conversion stage component: core Issues re: The core compiler labels Jun 25, 2021
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

Reformatting /workspace/tests/modules/hub.py
Reformatting /workspace/tests/py/test_to_backend_api.py
Reformatting /workspace/tests/py/test_trt_intercompatability.py
Reformatting /workspace/tests/py/test_multi_gpu.py
Reformatting /workspace/tests/py/model_test_case.py
Reformatting /workspace/tests/py/test_ptq_to_backend.py
Reformatting /workspace/tests/py/test_ptq_dataloader_calibrator.py
Reformatting /workspace/tests/py/test_api_dla.py
Reformatting /workspace/tests/py/test_api.py
Reformatting /workspace/tests/py/test_ptq_trt_calibrator.py
Reformatting /workspace/cpp/ptq/training/vgg16/export_ckpt.py
Reformatting /workspace/cpp/ptq/training/vgg16/vgg16.py
Reformatting /workspace/cpp/ptq/training/vgg16/main.py
--- /workspace/py/trtorch/_compile_spec.py	(original)
+++ /workspace/py/trtorch/_compile_spec.py	(reformatted)
@@ -140,6 +140,7 @@

    return info

+
def _parse_input_dtypes(input_dtypes: List) -> List:
    parsed_input_dtypes = []
    for dtype in input_dtypes:
@@ -155,9 +156,12 @@
            elif dtype == torch.bool:
                parsed_input_dtypes.append(_types.dtype.bool)
            else:
-                raise TypeError("Invalid input dtype. Supported input datatypes include float|half|int8|int32|bool), got: " + str(dtype))
+                raise TypeError(
+                    "Invalid input dtype. Supported input datatypes include float|half|int8|int32|bool), got: " +
+                    str(dtype))

    return parsed_input_dtypes
+

def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
    info = trtorch._C.CompileSpec()
Reformatting /workspace/py/trtorch/_types.py
Reformatting /workspace/py/trtorch/logging.py
Reformatting /workspace/py/trtorch/_compile_spec.py
Reformatting /workspace/py/trtorch/__init__.py
Reformatting /workspace/py/trtorch/ptq.py
Reformatting /workspace/py/trtorch/_compiler.py
Reformatting /workspace/py/setup.py
ERROR: Some files do not conform to style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

Reformatting /workspace/tests/py/model_test_case.py
Reformatting /workspace/tests/py/test_ptq_to_backend.py
Reformatting /workspace/tests/py/test_ptq_dataloader_calibrator.py
Reformatting /workspace/tests/py/test_api_dla.py
Reformatting /workspace/tests/py/test_to_backend_api.py
Reformatting /workspace/tests/py/test_trt_intercompatability.py
Reformatting /workspace/tests/py/test_multi_gpu.py
Reformatting /workspace/tests/modules/hub.py
Reformatting /workspace/tests/py/test_api.py
Reformatting /workspace/tests/py/test_ptq_trt_calibrator.py
Reformatting /workspace/cpp/ptq/training/vgg16/export_ckpt.py
Reformatting /workspace/cpp/ptq/training/vgg16/vgg16.py
Reformatting /workspace/cpp/ptq/training/vgg16/main.py
--- /workspace/py/trtorch/_compile_spec.py	(original)
+++ /workspace/py/trtorch/_compile_spec.py	(reformatted)
@@ -140,6 +140,7 @@

    return info

+
def _parse_input_dtypes(input_dtypes: List) -> List:
    parsed_input_dtypes = []
    for dtype in input_dtypes:
@@ -155,9 +156,12 @@
            elif dtype == torch.bool:
                parsed_input_dtypes.append(_types.dtype.bool)
            else:
-                raise TypeError("Invalid input dtype. Supported input datatypes include float|half|int8|int32|bool), got: " + str(dtype))
+                raise TypeError(
+                    "Invalid input dtype. Supported input datatypes include float|half|int8|int32|bool), got: " +
+                    str(dtype))

    return parsed_input_dtypes
+

def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
    info = trtorch._C.CompileSpec()
Reformatting /workspace/py/trtorch/_types.py
Reformatting /workspace/py/trtorch/logging.py
Reformatting /workspace/py/trtorch/_compile_spec.py
Reformatting /workspace/py/trtorch/__init__.py
Reformatting /workspace/py/trtorch/ptq.py
Reformatting /workspace/py/trtorch/_compiler.py
Reformatting /workspace/py/setup.py
ERROR: Some files do not conform to style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

@narendasan
Copy link
Collaborator

We also need trtorchc to be updated

@narendasan
Copy link
Collaborator

narendasan commented Jun 28, 2021

im also wondering if we should create a unified input type that encapsulates: shape, dtype and format, so something like:

struct Input/Var/Placeholder (Use the PyTorch term here) {
    shape: InputRange 
    dtype: DataType::kFloat
    format: NCHW
};

Then we can have a constructor which accepts a just an input range to provide backwards compatibility, or something like that, users would provide a list of shapes or a list of these if they want to specify dtypes

@narendasan
Copy link
Collaborator

Or do you think it is easier to have separate aligned lists? It might be less involved.

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

Reformatting /workspace/cpp/ptq/training/vgg16/export_ckpt.py
Reformatting /workspace/cpp/ptq/training/vgg16/vgg16.py
Reformatting /workspace/cpp/ptq/training/vgg16/main.py
Reformatting /workspace/tests/modules/hub.py
Reformatting /workspace/tests/py/test_to_backend_api.py
Reformatting /workspace/tests/py/test_ptq_trt_calibrator.py
Reformatting /workspace/tests/py/test_api_dla.py
Reformatting /workspace/tests/py/test_trt_intercompatability.py
Reformatting /workspace/tests/py/model_test_case.py
Reformatting /workspace/tests/py/test_multi_gpu.py
Reformatting /workspace/tests/py/test_ptq_to_backend.py
Reformatting /workspace/tests/py/test_api.py
Reformatting /workspace/tests/py/test_ptq_dataloader_calibrator.py
--- /workspace/py/trtorch/_compile_spec.py	(original)
+++ /workspace/py/trtorch/_compile_spec.py	(reformatted)
@@ -140,6 +140,7 @@

    return info

+
def _parse_input_dtypes(input_dtypes: List) -> List:
    parsed_input_dtypes = []
    for dtype in input_dtypes:
@@ -155,9 +156,12 @@
            elif dtype == torch.bool:
                parsed_input_dtypes.append(_types.dtype.bool)
            else:
-                raise TypeError("Invalid input dtype. Supported input datatypes include float|half|int8|int32|bool), got: " + str(dtype))
+                raise TypeError(
+                    "Invalid input dtype. Supported input datatypes include float|half|int8|int32|bool), got: " +
+                    str(dtype))

    return parsed_input_dtypes
+

def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
    info = trtorch._C.CompileSpec()
Reformatting /workspace/py/trtorch/_compiler.py
Reformatting /workspace/py/trtorch/__init__.py
Reformatting /workspace/py/trtorch/_compile_spec.py
Reformatting /workspace/py/trtorch/_types.py
Reformatting /workspace/py/trtorch/logging.py
Reformatting /workspace/py/trtorch/ptq.py
Reformatting /workspace/py/setup.py
ERROR: Some files do not conform to style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to C++ style guidelines:

diff --git a/workspace/core/ir/ir.h b/tmp/changes.txt
index 8e64aa4..a243ef7 100644
--- a/workspace/core/ir/ir.h
+++ b/tmp/changes.txt
@@ -23,11 +23,8 @@ struct InputRange {
//   Input(std::vector<int64_t> shape);
//   Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape);
//   Input(std::vector<int64_t> shape, DataType dtype=DataType::kFloat32);
-//   Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, DataType dtype=DataType::kFloat32);
-//   nvinfer1::Dims min;
-//   nvinfer1::Dims max;
-//   nvinfer1::Dims opt;
-//   nvinfer1::DataType dtype;
+//   Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, DataType
+//   dtype=DataType::kFloat32); nvinfer1::Dims min; nvinfer1::Dims max; nvinfer1::Dims opt; nvinfer1::DataType dtype;
// }

} // namespace ir
diff --git a/workspace/cpp/api/include/trtorch/trtorch.h b/tmp/changes.txt
index 043ba41..7474702 100644
--- a/workspace/cpp/api/include/trtorch/trtorch.h
+++ b/tmp/changes.txt
@@ -215,14 +215,14 @@ struct TRTORCH_API CompileSpec {
     *
     * @param opt
     */
-    Input(std::vector<int64_t> opt, DataType dtype=DataType::kFloat);
+    Input(std::vector<int64_t> opt, DataType dtype = DataType::kFloat);
    /**
     * @brief Construct a new Input Range object static input size from
     * c10::ArrayRef (the type produced by tensor.sizes())
     *
     * @param opt
     */
-    Input(c10::ArrayRef<int64_t> opt, DataType dtype=DataType::kFloat);
+    Input(c10::ArrayRef<int64_t> opt, DataType dtype = DataType::kFloat);
    /**
     * @brief Construct a new Input Range object dynamic input size from vectors
     * for min, opt, and max supported sizes
@@ -231,7 +231,11 @@ struct TRTORCH_API CompileSpec {
     * @param opt
     * @param max
     */
-    Input(std::vector<int64_t> min, std::vector<int64_t> opt, std::vector<int64_t> max, DataType dtype=DataType::kFloat);
+    Input(
+        std::vector<int64_t> min,
+        std::vector<int64_t> opt,
+        std::vector<int64_t> max,
+        DataType dtype = DataType::kFloat);
    /**
     * @brief Construct a new Input Range object dynamic input size from
     * c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
@@ -241,7 +245,11 @@ struct TRTORCH_API CompileSpec {
     * @param opt
     * @param max
     */
-    Input(c10::ArrayRef<int64_t> min, c10::ArrayRef<int64_t> opt, c10::ArrayRef<int64_t> max, DataType dtype=DataType::kFloat);
+    Input(
+        c10::ArrayRef<int64_t> min,
+        c10::ArrayRef<int64_t> opt,
+        c10::ArrayRef<int64_t> max,
+        DataType dtype = DataType::kFloat);
  };

  /**
ERROR: Some files do not conform to style guidelines

@narendasan narendasan marked this pull request as draft June 30, 2021 19:58
@narendasan narendasan added the WIP Work is in progress, pull request should not be merged yet label Jun 30, 2021
@github-actions github-actions bot added component: converters Issues re: Specific op converters component: tests Issues re: Tests labels Jul 19, 2021
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to C++ style guidelines:

diff --git a/workspace/core/ir/Input.cpp b/tmp/changes.txt
index f1f9f69..8612892 100644
--- a/workspace/core/ir/Input.cpp
+++ b/tmp/changes.txt
@@ -133,11 +133,20 @@ Input::Input(std::vector<int64_t> shape, nvinfer1::DataType dtype, nvinfer1::Ten

  TRTORCH_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype);
  this->dtype = dtype;
-  TRTORCH_CHECK(valid_dtype_format_combo(dtype, format), "Unsupported combination of dtype and tensor format: (" << dtype << ", " << format << "), TRTorch only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
+  TRTORCH_CHECK(
+      valid_dtype_format_combo(dtype, format),
+      "Unsupported combination of dtype and tensor format: ("
+          << dtype << ", " << format
+          << "), TRTorch only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
  this->format = format;
}

-Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, nvinfer1::DataType dtype, nvinfer1::TensorFormat format) {
+Input::Input(
+    std::vector<int64_t> min_shape,
+    std::vector<int64_t> opt_shape,
+    std::vector<int64_t> max_shape,
+    nvinfer1::DataType dtype,
+    nvinfer1::TensorFormat format) {
  if (min_shape.size() > 5 || opt_shape.size() > 5 || max_shape.size() > 5) {
    LOG_WARNING("Verify that this dim size is accepted");
  }
@@ -178,7 +187,11 @@ Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std

  TRTORCH_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype);
  this->dtype = dtype;
-  TRTORCH_CHECK(valid_dtype_format_combo(dtype, format), "Unsupported combination of dtype and tensor format: (" << dtype << ", " << format << "), TRTorch only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
+  TRTORCH_CHECK(
+      valid_dtype_format_combo(dtype, format),
+      "Unsupported combination of dtype and tensor format: ("
+          << dtype << ", " << format
+          << "), TRTorch only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
  this->format = format;
}

@@ -186,7 +199,8 @@ std::ostream& operator<<(std::ostream& os, const Input& input) {
  if (!input.input_is_dynamic) {
    os << "Input(shape: " << input.input_shape << ", dtype: " << input.dtype << ", format: " << input.format << ')';
  } else {
-    os << "Input(shape: " << input.input_shape << ", min: " << input.min << ", opt: " << input.opt << ", max: " << input.max << ", dtype: " << input.dtype << ", format: " << input.format << ')';
+    os << "Input(shape: " << input.input_shape << ", min: " << input.min << ", opt: " << input.opt
+       << ", max: " << input.max << ", dtype: " << input.dtype << ", format: " << input.format << ')';
  }
  return os;
}
diff --git a/workspace/core/conversion/conversion.cpp b/tmp/changes.txt
index 324f3d0..1c2b963 100644
--- a/workspace/core/conversion/conversion.cpp
+++ b/tmp/changes.txt
@@ -125,10 +125,7 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
                       << "please report this error to https://www.github.com/NVIDIA/TRTorch/issues");
}

-void AddInputs(
-    ConversionCtx* ctx,
-    at::ArrayRef<const torch::jit::Value*> inputs,
-    std::vector<ir::Input>& input_specs) {
+void AddInputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> inputs, std::vector<ir::Input>& input_specs) {
  std::vector<const torch::jit::Value*> input_tensors;
  for (auto in : inputs) {
    // Disregarding inputs that are not tensors
@@ -164,7 +161,8 @@ void AddInputs(
    std::string name = std::string("input_") + std::to_string(ctx->num_inputs);
    LOG_INFO(
        ctx->logger,
-        "Adding Input " << in->debugName() << " (named: " << name << "): " << spec << " in engine (conversion.AddInputs)");
+        "Adding Input " << in->debugName() << " (named: " << name << "): " << spec
+                        << " in engine (conversion.AddInputs)");

    auto trt_in = ctx->net->addInput(name.c_str(), spec.dtype, spec.input_shape);
    TRTORCH_CHECK(trt_in, "Failed to add input node: " << in->debugName() << " (conversion.AddInputs)");
diff --git a/workspace/core/conversion/converters/impl/activation.cpp b/tmp/changes.txt
index a7e2a41..77931e4 100644
--- a/workspace/core/conversion/converters/impl/activation.cpp
+++ b/tmp/changes.txt
@@ -177,8 +177,9 @@ auto acthardtanh TRTORCH_UNUSED =
                    std::string pluginName = "CustomGeluPluginDynamic";
                    nvinfer1::PluginFieldCollection fc;
                    std::vector<nvinfer1::PluginField> f;
-                    //REVIEW is this right?
-                    int type_id = ctx->settings.enabled_precisions.find(nvinfer1::DataType::kHALF) == ctx->settings.enabled_precisions.end()
+                    // REVIEW is this right?
+                    int type_id = ctx->settings.enabled_precisions.find(nvinfer1::DataType::kHALF) ==
+                            ctx->settings.enabled_precisions.end()
                        ? 0
                        : 1; // Integer encoding the DataType (0: FP32, 1: FP16)
                    f.emplace_back(nvinfer1::PluginField("type_id", &type_id, nvinfer1::PluginFieldType::kINT32, 1));
diff --git a/workspace/core/conversion/conversionctx/ConversionCtx.cpp b/tmp/changes.txt
index 96a6261..50831a8 100644
--- a/workspace/core/conversion/conversionctx/ConversionCtx.cpp
+++ b/tmp/changes.txt
@@ -60,7 +60,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
  LOG_DEBUG(build_settings);
  cfg = builder->createBuilderConfig();

-  for(auto p = settings.enabled_precisions.begin(); p != settings.enabled_precisions.end(); ++p) {
+  for (auto p = settings.enabled_precisions.begin(); p != settings.enabled_precisions.end(); ++p) {
    switch (*p) {
      case nvinfer1::DataType::kHALF:
        TRTORCH_CHECK(builder->platformHasFastFp16(), "Requested inference in FP16 but platform does not support FP16");
@@ -121,7 +121,9 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
        static_cast<int>(settings.device.dla_core) < nbDLACores,
        "Configured DLA Core ID: " << settings.device.dla_core
                                   << " not available. Total number of available DLA Cores: " << nbDLACores);
-    TRTORCH_CHECK(settings.enabled_precisions.find(nvinfer1::DataType::kFLOAT) == settings.enabled_precisions.end(), "DLA supports only fp16 or int8 precision");
+    TRTORCH_CHECK(
+        settings.enabled_precisions.find(nvinfer1::DataType::kFLOAT) == settings.enabled_precisions.end(),
+        "DLA supports only fp16 or int8 precision");
    cfg->setDLACore(settings.device.dla_core);
  }
}
diff --git a/workspace/core/ir/ir.h b/tmp/changes.txt
index d524ded..3d14491 100644
--- a/workspace/core/ir/ir.h
+++ b/tmp/changes.txt
@@ -1,7 +1,7 @@
#pragma once

-#include <vector>
#include <iostream>
+#include <vector>
#include "NvInfer.h"

namespace trtorch {
@@ -9,10 +9,18 @@ namespace core {
namespace ir {

struct Input {
-  //Input(std::vector<int64_t> shape);
-  //Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape);
-  Input(std::vector<int64_t> shape, nvinfer1::DataType dtype=nvinfer1::DataType::kFLOAT, nvinfer1::TensorFormat format=nvinfer1::TensorFormat::kLINEAR);
-  Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, nvinfer1::DataType dtype=nvinfer1::DataType::kFLOAT, nvinfer1::TensorFormat format=nvinfer1::TensorFormat::kLINEAR);
+  // Input(std::vector<int64_t> shape);
+  // Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape);
+  Input(
+      std::vector<int64_t> shape,
+      nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT,
+      nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR);
+  Input(
+      std::vector<int64_t> min_shape,
+      std::vector<int64_t> opt_shape,
+      std::vector<int64_t> max_shape,
+      nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT,
+      nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR);
  friend std::ostream& operator<<(std::ostream& os, const Input& input);

  bool input_is_dynamic = false;
diff --git a/workspace/core/conversion/conversion.h b/tmp/changes.txt
index fe79669..253dce7 100644
--- a/workspace/core/conversion/conversion.h
+++ b/tmp/changes.txt
@@ -14,8 +14,7 @@ namespace conversion {
struct ConversionInfo {
  std::vector<ir::Input> inputs;
  BuilderSettings engine_settings;
-  ConversionInfo(std::vector<ir::Input> inputs)
-      : inputs(std::move(inputs)), engine_settings(BuilderSettings()) {}
+  ConversionInfo(std::vector<ir::Input> inputs) : inputs(std::move(inputs)), engine_settings(BuilderSettings()) {}
};

// TODO: REMOVE GRAPH AND PARAMS AND MOVE FULLY TO INLINED CONSTANTS
diff --git a/workspace/core/conversion/conversionctx/ConversionCtx.h b/tmp/changes.txt
index 9aee3b4..0570bf5 100644
--- a/workspace/core/conversion/conversionctx/ConversionCtx.h
+++ b/tmp/changes.txt
@@ -2,8 +2,8 @@

#include <map>
#include <memory>
-#include <unordered_map>
#include <set>
+#include <unordered_map>

#include "NvInfer.h"
#include "torch/csrc/jit/ir/ir.h"
diff --git a/workspace/cpp/api/src/compile_spec.cpp b/tmp/changes.txt
index fb54547..7fc08aa 100644
--- a/workspace/cpp/api/src/compile_spec.cpp
+++ b/tmp/changes.txt
@@ -61,8 +61,7 @@ CompileSpec::DataType::DataType(c10::ScalarType t) {

CompileSpec::TensorFormat::TensorFormat(at::MemoryFormat t) {
  TRTORCH_CHECK(
-    t == at::MemoryFormat::Contiguous || t == at::MemoryFormat::ChannelsLast, "Tensor format is unsupported"
-  );
+      t == at::MemoryFormat::Contiguous || t == at::MemoryFormat::ChannelsLast, "Tensor format is unsupported");

  switch (t) {
    case at::MemoryFormat::ChannelsLast:
@@ -136,7 +135,12 @@ CompileSpec::Input::Input(c10::IntArrayRef shape, DataType dtype, TensorFormat f
  this->input_is_dynamic = false;
}

-CompileSpec::Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, DataType dtype, TensorFormat format) {
+CompileSpec::Input::Input(
+    std::vector<int64_t> min_shape,
+    std::vector<int64_t> opt_shape,
+    std::vector<int64_t> max_shape,
+    DataType dtype,
+    TensorFormat format) {
  this->opt_shape = opt_shape;
  this->min_shape = min_shape;
  this->max_shape = max_shape;
@@ -146,7 +150,12 @@ CompileSpec::Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> o
  this->input_is_dynamic = true;
}

-CompileSpec::Input::Input(c10::IntArrayRef min_shape, c10::IntArrayRef opt_shape, c10::IntArrayRef max_shape, DataType dtype, TensorFormat format) {
+CompileSpec::Input::Input(
+    c10::IntArrayRef min_shape,
+    c10::IntArrayRef opt_shape,
+    c10::IntArrayRef max_shape,
+    DataType dtype,
+    TensorFormat format) {
  this->opt_shape = core::util::toVec(opt_shape);
  this->min_shape = core::util::toVec(min_shape);
  this->max_shape = core::util::toVec(max_shape);
@@ -184,7 +193,7 @@ std::vector<core::ir::Input> to_vec_internal_inputs(std::vector<CompileSpec::Inp

core::CompileSpec to_internal_compile_spec(CompileSpec external) {
  core::CompileSpec internal(to_vec_internal_inputs(external.inputs));
-  if (external.input_ranges.size() > 0 ) {
+  if (external.input_ranges.size() > 0) {
    internal = core::CompileSpec(to_vec_internal_inputs(external.input_ranges));
  } else {
    TRTORCH_CHECK(external.inputs.size() > 0, "Compilation requires at least one input specification");
@@ -194,7 +203,7 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
  if (external.enabled_precisions.size() <= 1 && toTRTDataType(external.op_precision) != nvinfer1::DataType::kFLOAT) {
    internal.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(external.op_precision));
  } else {
-    for(auto p : external.enabled_precisions) {
+    for (auto p : external.enabled_precisions) {
      internal.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
    }
  }
@@ -237,7 +246,8 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
  internal.convert_info.engine_settings.num_avg_timing_iters = external.num_avg_timing_iters;
  internal.convert_info.engine_settings.workspace_size = external.workspace_size;

-  if (internal.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) != internal.convert_info.engine_settings.enabled_precisions.end()) {
+  if (internal.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) !=
+      internal.convert_info.engine_settings.enabled_precisions.end()) {
    internal.convert_info.engine_settings.calibrator = external.ptq_calibrator;
  } else {
    internal.convert_info.engine_settings.calibrator = nullptr;
diff --git a/workspace/cpp/api/include/trtorch/trtorch.h b/tmp/changes.txt
index cd3abf6..5dd383c 100644
--- a/workspace/cpp/api/include/trtorch/trtorch.h
+++ b/tmp/changes.txt
@@ -10,9 +10,9 @@

#include <cuda_runtime.h>
#include <memory>
+#include <set>
#include <string>
#include <vector>
-#include <set>

// Just include the .h?
#ifndef DOXYGEN_SHOULD_SKIP_THIS
@@ -264,7 +264,7 @@ struct TRTORCH_API CompileSpec {
  };

  class TRTORCH_API TensorFormat {
-    public:
+   public:
    /**
     * Underlying enum class to support the TensorFormat Class
     *
@@ -365,7 +365,8 @@ struct TRTORCH_API CompileSpec {
    std::vector<int64_t> opt_shape;
    /// Maximum acceptable input size into the engine
    std::vector<int64_t> max_shape;
-    /// Input shape to be fed to TensorRT, in the event of a dynamic shape, -1's will hold the place of variable dimensions
+    /// Input shape to be fed to TensorRT, in the event of a dynamic shape, -1's will hold the place of variable
+    /// dimensions
    std::vector<int64_t> shape;
    /// Expected data type for the input
    DataType dtype;
@@ -380,7 +381,10 @@ struct TRTORCH_API CompileSpec {
     * @param dtype Expected data type for the input (Defaults to Float32)
     * @param format Expected tensor format for the input (Defaults to contiguous)
     */
-    Input(std::vector<int64_t> shape, DataType dtype=DataType::kFloat, TensorFormat format=TensorFormat::kContiguous);
+    Input(
+        std::vector<int64_t> shape,
+        DataType dtype = DataType::kFloat,
+        TensorFormat format = TensorFormat::kContiguous);
    /**
     * @brief Construct a new Input spec object for static input size from
     * c10::ArrayRef (the type produced by tensor.sizes()), vector, optional arguments
@@ -390,7 +394,10 @@ struct TRTORCH_API CompileSpec {
     * @param dtype Expected data type for the input (Defaults to Float32)
     * @param format Expected tensor format for the input (Defaults to contiguous)
     */
-    Input(c10::ArrayRef<int64_t> shape, DataType dtype=DataType::kFloat, TensorFormat format=TensorFormat::kContiguous);
+    Input(
+        c10::ArrayRef<int64_t> shape,
+        DataType dtype = DataType::kFloat,
+        TensorFormat format = TensorFormat::kContiguous);
    /**
     * @brief Construct a new Input spec object for a dynamic input size from vectors
     * for minimum shape, optimal shape, and max shape supported sizes optional arguments
@@ -402,7 +409,12 @@ struct TRTORCH_API CompileSpec {
     * @param dtype Expected data type for the input (Defaults to Float32)
     * @param format Expected tensor format for the input (Defaults to contiguous)
     */
-    Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, DataType dtype=DataType::kFloat, TensorFormat format=TensorFormat::kContiguous);
+    Input(
+        std::vector<int64_t> min_shape,
+        std::vector<int64_t> opt_shape,
+        std::vector<int64_t> max_shape,
+        DataType dtype = DataType::kFloat,
+        TensorFormat format = TensorFormat::kContiguous);
    /**
     * @brief Construct a new Input Range object dynamic input size from
     * c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
@@ -414,9 +426,14 @@ struct TRTORCH_API CompileSpec {
     * @param dtype Expected data type for the input (Defaults to Float32)
     * @param format Expected tensor format for the input (Defaults to contiguous)
     */
-    Input(c10::ArrayRef<int64_t> min_shape, c10::ArrayRef<int64_t> opt_shape, c10::ArrayRef<int64_t> max_shape, DataType dtype=DataType::kFloat, TensorFormat format=TensorFormat::kContiguous);
+    Input(
+        c10::ArrayRef<int64_t> min_shape,
+        c10::ArrayRef<int64_t> opt_shape,
+        c10::ArrayRef<int64_t> max_shape,
+        DataType dtype = DataType::kFloat,
+        TensorFormat format = TensorFormat::kContiguous);

-  private:
+   private:
    bool input_is_dynamic;
  };

@@ -441,16 +458,16 @@ struct TRTORCH_API CompileSpec {
     *
     * @param opt
     */
-    [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]]
-    InputRange(std::vector<int64_t> opt);
+    [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] InputRange(
+        std::vector<int64_t> opt);
    /**
     * @brief Construct a new Input Range object static input size from
     * c10::ArrayRef (the type produced by tensor.sizes())
     *
     * @param opt
     */
-    [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]]
-    InputRange(c10::ArrayRef<int64_t> opt);
+    [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] InputRange(
+        c10::ArrayRef<int64_t> opt);
    /**
     * @brief Construct a new Input Range object dynamic input size from vectors
     * for min, opt, and max supported sizes
@@ -459,8 +476,10 @@ struct TRTORCH_API CompileSpec {
     * @param opt
     * @param max
     */
-    [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]]
-    InputRange(std::vector<int64_t> min, std::vector<int64_t> opt, std::vector<int64_t> max);
+    [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] InputRange(
+        std::vector<int64_t> min,
+        std::vector<int64_t> opt,
+        std::vector<int64_t> max);
    /**
     * @brief Construct a new Input Range object dynamic input size from
     * c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
@@ -470,8 +489,10 @@ struct TRTORCH_API CompileSpec {
     * @param opt
     * @param max
     */
-    [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]]
-    InputRange(c10::ArrayRef<int64_t> min, c10::ArrayRef<int64_t> opt, c10::ArrayRef<int64_t> max);
+    [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] InputRange(
+        c10::ArrayRef<int64_t> min,
+        c10::ArrayRef<int64_t> opt,
+        c10::ArrayRef<int64_t> max);
  };

  /**
@@ -512,8 +533,9 @@ struct TRTORCH_API CompileSpec {
   *
   * @param input_ranges
   */
-  [[deprecated("trtorch::CompileSpec::CompileSpec(std::vector<InputRange> input_ranges) is being deprecated in favor of trtorch::CompileSpec::CompileSpec(std::vector<Input> inputs). trtorch::CompileSpec::CompileSpec(std::vector<InputRange> input_ranges) will be removed in TRTorch v0.5.0")]]
-  CompileSpec(std::vector<InputRange> input_ranges) : input_ranges(std::move(input_ranges)) {}
+  [[deprecated("trtorch::CompileSpec::CompileSpec(std::vector<InputRange> input_ranges) is being deprecated in favor of trtorch::CompileSpec::CompileSpec(std::vector<Input> inputs). trtorch::CompileSpec::CompileSpec(std::vector<InputRange> input_ranges) will be removed in TRTorch v0.5.0")]] CompileSpec(
+      std::vector<InputRange> input_ranges)
+      : input_ranges(std::move(input_ranges)) {}
  /**
   * @brief Construct a new Extra Info object
   * Convienence constructor to set fixed input size from vectors describing
@@ -522,8 +544,8 @@ struct TRTORCH_API CompileSpec {
   *
   * @param fixed_sizes
   */
-  [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]]
-  CompileSpec(std::vector<std::vector<int64_t>> fixed_sizes);
+  [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] CompileSpec(
+      std::vector<std::vector<int64_t>> fixed_sizes);
  /**
   * @brief Construct a new Extra Info object
   * Convienence constructor to set fixed input size from c10::ArrayRef's (the
@@ -531,14 +553,14 @@ struct TRTORCH_API CompileSpec {
   * the vector represents a input and should be provided in call order.
   * @param fixed_sizes
   */
-  [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]]
-  CompileSpec(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);
+  [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] CompileSpec(
+      std::vector<c10::ArrayRef<int64_t>> fixed_sizes);

  // Defaults should reflect TensorRT defaults for BuilderConfig

  /**
-   * @brief Specifications for inputs to the engine, can either be a single size or a range defined by min, opt and max sizes
-   * Users can also specify expected input type as well as tensor memory format
+   * @brief Specifications for inputs to the engine, can either be a single size or a range defined by min, opt and max
+   * sizes Users can also specify expected input type as well as tensor memory format
   *
   * Order in vector should match call order for the function
   */
@@ -550,14 +572,17 @@ struct TRTORCH_API CompileSpec {
   *
   * Order is should match call order
   */
-  [[deprecated("trtorch::CompileSpec::input_ranges is being deprecated in favor of trtorch::CompileSpec::inputs. trtorch::CompileSpec::input_ranges will be removed in TRTorch v0.5.0")]]
-  std::vector<InputRange> input_ranges;
+  [[deprecated(
+      "trtorch::CompileSpec::input_ranges is being deprecated in favor of trtorch::CompileSpec::inputs. trtorch::CompileSpec::input_ranges will be removed in TRTorch v0.5.0")]] std::
+      vector<InputRange>
+          input_ranges;

  /**
   * Default operating precision for the engine
   */
-  [[deprecated("trtorch::CompileSpec::op_precision is being deprecated in favor of trtorch::CompileSpec::enabled_precisions, a set of all enabled precisions to use during compilation, trtorch::CompileSpec::op_precision will be removed in TRTorch v0.5.0")]]
-  DataType op_precision = DataType::kFloat;
+  [[deprecated(
+      "trtorch::CompileSpec::op_precision is being deprecated in favor of trtorch::CompileSpec::enabled_precisions, a set of all enabled precisions to use during compilation, trtorch::CompileSpec::op_precision will be removed in TRTorch v0.5.0")]] DataType
+      op_precision = DataType::kFloat;

  /**
   * @brief The set of precisions TensorRT is allowed to use for kernels during compilation
diff --git a/workspace/py/trtorch/csrc/register_tensorrt_classes.cpp b/tmp/changes.txt
index 9080f33..3ca490c 100644
--- a/workspace/py/trtorch/csrc/register_tensorrt_classes.cpp
+++ b/tmp/changes.txt
@@ -9,10 +9,9 @@ namespace {
  (registry).def("_get_" #field_name, &class_name::get_##field_name);

void RegisterTRTCompileSpec() {
-  static auto TRTORCH_UNUSED TRTInputRangeTSRegistration =
-      torch::class_<trtorch::pyapi::Input>("tensorrt", "_Input")
-          .def(torch::init<>())
-          .def("__str__", &trtorch::pyapi::Input::to_str);
+  static auto TRTORCH_UNUSED TRTInputRangeTSRegistration = torch::class_<trtorch::pyapi::Input>("tensorrt", "_Input")
+                                                               .def(torch::init<>())
+                                                               .def("__str__", &trtorch::pyapi::Input::to_str);

  ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, min);
  ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, opt);
@@ -21,7 +20,6 @@ void RegisterTRTCompileSpec() {
  ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, format);
  ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, input_is_dynamic);

-
  static auto TRTORCH_UNUSED TRTDeviceTSRegistration = torch::class_<trtorch::pyapi::Device>("tensorrt", "_Device")
                                                           .def(torch::init<>())
                                                           .def("__str__", &trtorch::pyapi::Device::to_str);
ERROR: Some files do not conform to style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

Reformatting /workspace/cpp/ptq/training/vgg16/export_ckpt.py
Reformatting /workspace/cpp/ptq/training/vgg16/vgg16.py
Reformatting /workspace/cpp/ptq/training/vgg16/main.py
Reformatting /workspace/tests/py/test_api_dla.py
Reformatting /workspace/tests/py/test_trt_intercompatability.py
Reformatting /workspace/tests/py/model_test_case.py
Reformatting /workspace/tests/py/test_multi_gpu.py
Reformatting /workspace/tests/py/test_ptq_to_backend.py
Reformatting /workspace/tests/py/test_api.py
Reformatting /workspace/tests/modules/hub.py
Reformatting /workspace/tests/py/test_to_backend_api.py
Reformatting /workspace/tests/py/test_ptq_trt_calibrator.py
Reformatting /workspace/tests/py/test_ptq_dataloader_calibrator.py
--- /workspace/py/trtorch/Input.py	(original)
+++ /workspace/py/trtorch/Input.py	(reformatted)
@@ -5,6 +5,7 @@

from trtorch import _types
import trtorch._C
+

class Input(object):
    """
@@ -59,39 +60,45 @@
        if len(args) == 1:
            if not Input._supported_input_size_type(args[0]):
                raise TypeError(
-                "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
-                + str(type(args[0])))
+                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
+                    + str(type(args[0])))
            if any(k in kwargs for k in ["min_shape", "opt_shape", "max_shape"]):
-                raise ValueError("Found that both shape (as a positional argument), and one or more of min_shape, opt_shape, max_shape were specified\nclass Input expects that only either shape or all three of min_shape, opt_shape, max_shape are defined")
+                raise ValueError(
+                    "Found that both shape (as a positional argument), and one or more of min_shape, opt_shape, max_shape were specified\nclass Input expects that only either shape or all three of min_shape, opt_shape, max_shape are defined"
+                )
            self.shape = tuple(args[0])
            self.shape_mode = Input._ShapeMode.STATIC

        elif len(args) == 0:
-            if not ("shape" in kwargs) and not(all(k in kwargs for k in ["min_shape", "opt_shape", "max_shape"])):
-                raise ValueError("Missing required arguments for class Input\nEither shape or all three of min_shape, opt_shape, max_shape must be defined")
+            if not ("shape" in kwargs) and not (all(k in kwargs for k in ["min_shape", "opt_shape", "max_shape"])):
+                raise ValueError(
+                    "Missing required arguments for class Input\nEither shape or all three of min_shape, opt_shape, max_shape must be defined"
+                )
            elif ("shape" in kwargs) and all(k in kwargs for k in ["min_shape", "opt_shape", "max_shape"]):
-                raise ValueError("Found that both shape, and one or more of min_shape, opt_shape, max_shape were specified\nclass Input expects that only either shape or all three of min_shape, opt_shape, max_shape are defined")
+                raise ValueError(
+                    "Found that both shape, and one or more of min_shape, opt_shape, max_shape were specified\nclass Input expects that only either shape or all three of min_shape, opt_shape, max_shape are defined"
+                )

            if "shape" in kwargs:
                if not Input._supported_input_size_type(kwargs["shape"]):
                    raise TypeError(
-                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
-                    + str(type(kwargs["shape"])))
+                        "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
+                        + str(type(kwargs["shape"])))
                self.shape = tuple(kwargs["shape"])
                self.shape_mode = Input._ShapeMode.STATIC
            else:
                if not Input._supported_input_size_type(kwargs["min_shape"]):
                    raise TypeError(
-                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
-                    + str(type(kwargs["min_shape"])) + " for min_shape")
+                        "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
+                        + str(type(kwargs["min_shape"])) + " for min_shape")
                if not Input._supported_input_size_type(kwargs["opt_shape"]):
                    raise TypeError(
-                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
-                    + str(type(kwargs["opt_shape"])) + " for opt_shape")
+                        "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
+                        + str(type(kwargs["opt_shape"])) + " for opt_shape")
                if not Input._supported_input_size_type(kwargs["max_shape"]):
                    raise TypeError(
-                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
-                    + str(type(kwargs["max_shape"])) + " for max_shape")
+                        "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
+                        + str(type(kwargs["max_shape"])) + " for max_shape")

                self.shape = {
                    "min_shape": tuple(kwargs["min_shape"]),
@@ -101,7 +108,9 @@
                self.shape_mode = Input._ShapeMode.DYNAMIC

        else:
-            raise ValueError("Unexpected number of positional arguments for class Input \n    Found {} arguments, expected either zero or a single positional arguments".format(len(args)))
+            raise ValueError(
+                "Unexpected number of positional arguments for class Input \n    Found {} arguments, expected either zero or a single positional arguments"
+                .format(len(args)))

        if "dtype" in kwargs:
            self.dtype = Input._parse_dtype(kwargs["dtype"])
@@ -113,7 +122,9 @@
        if self.shape_mode == Input._ShapeMode.STATIC:
            return "Input(shape={}, dtype={}, format={})".format(self.shape, str(self.dtype), str(self.format))
        elif self.shape_mode == Input._ShapeMode.DYNAMIC:
-            return "Input(min_shape={}, opt_shape={}, max_shape={}, dtype={}, format={})".format(self.shape["min_shape"], self.shape["min_shape"], self.shape["min_shape"], str(self.dtype), str(self.format))
+            return "Input(min_shape={}, opt_shape={}, max_shape={}, dtype={}, format={})".format(
+                self.shape["min_shape"], self.shape["min_shape"], self.shape["min_shape"], str(self.dtype),
+                str(self.format))
        else:
            raise RuntimeError("Unknown input shape mode")

@@ -154,8 +165,9 @@
            elif dtype == torch.bool:
                return _types.dtype.bool
            else:
-                raise TypeError("Provided an unsupported data type as an input data type (support: bool, int32, half, float), got: " +
-                                str(dtype))
+                raise TypeError(
+                    "Provided an unsupported data type as an input data type (support: bool, int32, half, float), got: "
+                    + str(dtype))

        elif isinstance(dtype, _types.DataTypes):
            return dtype
@@ -172,10 +184,12 @@
            elif format == torch.channels_last:
                return _types.TensorFormat.channel_last
            else:
-                raise ValueError("Provided an unsupported tensor format (support: NHCW/contiguous_format, NHWC/channel_last)")
+                raise ValueError(
+                    "Provided an unsupported tensor format (support: NHCW/contiguous_format, NHWC/channel_last)")

        elif isinstance(format, _types.TensorFormat):
            return format

        else:
-            raise TypeError("Tensor format needs to be specified with either torch.memory_format or trtorch.TensorFormat")
+            raise TypeError(
+                "Tensor format needs to be specified with either torch.memory_format or trtorch.TensorFormat")
--- /workspace/py/trtorch/_compile_spec.py	(original)
+++ /workspace/py/trtorch/_compile_spec.py	(reformatted)
@@ -29,7 +29,8 @@
    for i in input_sizes:
        if isinstance(i, dict):
            if all(k in i for k in ["min", "opt", "min"]):
-                parsed_input_sizes.append(Input(min_shape=i["min"], opt_shape=i["opt"], max_shape=i["max"])._to_internal())
+                parsed_input_sizes.append(
+                    Input(min_shape=i["min"], opt_shape=i["opt"], max_shape=i["max"])._to_internal())

            elif "opt" in i:
                parsed_input_sizes.append(Input(shape=i["opt"])._to_internal())
@@ -78,6 +79,7 @@
    else:
        parsed_precisions.add(_parse_op_precision(precisions))
    return parsed_precisions
+

def _parse_device_type(device: Any) -> _types.DeviceType:
    if isinstance(device, torch.device):
@@ -139,6 +141,7 @@

    return info

+
def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
    info = trtorch._C.CompileSpec()
    if "input_shapes" not in compile_spec and "inputs" not in compile_spec:
@@ -152,11 +155,13 @@
        )

    if "input_shapes" in compile_spec:
-        warnings.warn("Key \"input_shapes\" is deprecated in favor of \"inputs\". Support for \"input_shapes\" will be removed in TRTorch v0.5.0", DeprecationWarning)
+        warnings.warn(
+            "Key \"input_shapes\" is deprecated in favor of \"inputs\". Support for \"input_shapes\" will be removed in TRTorch v0.5.0",
+            DeprecationWarning)
        info.inputs = _parse_input_ranges(compile_spec["input_shapes"])

    if "inputs" in compile_spec:
-        info.inputs = [ i._to_internal() for i in compile_spec["inputs"] ]
+        info.inputs = [i._to_internal() for i in compile_spec["inputs"]]

    if "op_precision" in compile_spec and "enabled_precisions" in compile_spec:
        raise KeyError(
@@ -164,7 +169,9 @@
        )

    if "op_precision" in compile_spec:
-        warnings.warn("Key \"op_precision\" is being deprecated in favor of \"enabled_precision\" which expects a set of precisions to be enabled during compilation (FP32 will always be enabled), Support for \"op_precision\" will be removed in TRTorch v0.5.0", DeprecationWarning)
+        warnings.warn(
+            "Key \"op_precision\" is being deprecated in favor of \"enabled_precision\" which expects a set of precisions to be enabled during compilation (FP32 will always be enabled), Support for \"op_precision\" will be removed in TRTorch v0.5.0",
+            DeprecationWarning)
        info.enabled_precisions = _parse_enabled_precisions(compile_spec["op_precision"])

    if "enabled_precisions" in compile_spec:
Reformatting /workspace/py/trtorch/_compiler.py
Reformatting /workspace/py/trtorch/__init__.py
Reformatting /workspace/py/trtorch/Input.py
Reformatting /workspace/py/trtorch/_compile_spec.py
Reformatting /workspace/py/trtorch/_types.py
Reformatting /workspace/py/trtorch/logging.py
Reformatting /workspace/py/trtorch/ptq.py
Reformatting /workspace/py/setup.py
ERROR: Some files do not conform to style guidelines

peri044 and others added 8 commits July 20, 2021 20:27
This commit implements the new input spec type trtorch::core::ir::Input,
which incapsulates InputRange and adds the new dtype and tensor format arguments.

It also changes DataType op_precision in the engine settings to
std::set<nvinfer1::DataType> enabled_precisions, allowing the compiler
to set more than a single precision without resorting to catch all rules
such as FP32 and Int8 without FP16.

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
Input type and enabled_precisions set

BREAKING CHANGE: This commit introduces the next iteration of the Python
TRTorch API. Starting in TRTorch v0.5.0 support for the "input_shapes"
and "op_precision" compile spec keys will be removed. Users should port
forward to using the "inputs" key which expects a list of trtorch.Input
objects and the "enabled_precisions" key which expects a set of data
type specifying enums.

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
enabled_types

BREAKING CHANGE: This change deprecates InputRange, and the CompileSpec
fields "input_shapes", "op_precision" and associated contructors and
functions. These are replaced wtih Input, "inputs" and
"enabled_precisions" respectively. Deprecated components will be removed
in TRTorch v0.5.0

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
behavior unless user explicitly overrides

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
This commits adds tests for the new Input class including verifying that
default behavior works properly

It also moves tests out module and into cpp for cpp api tests

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
trtorchc

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
@narendasan narendasan removed the WIP Work is in progress, pull request should not be merged yet label Jul 21, 2021
@narendasan narendasan marked this pull request as ready for review July 21, 2021 03:35
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

Reformatting /workspace/cpp/ptq/training/vgg16/export_ckpt.py
Reformatting /workspace/cpp/ptq/training/vgg16/vgg16.py
Reformatting /workspace/cpp/ptq/training/vgg16/main.py
Reformatting /workspace/tests/modules/hub.py
Reformatting /workspace/tests/py/test_to_backend_api.py
Reformatting /workspace/tests/py/test_ptq_trt_calibrator.py
Reformatting /workspace/tests/py/test_ptq_dataloader_calibrator.py
--- /workspace/tests/py/test_api.py	(original)
+++ /workspace/tests/py/test_api.py	(reformatted)
@@ -73,6 +73,7 @@
        same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
        self.assertTrue(same < 2e-2)

+
class TestCompileHalf(ModelTestCase):

    def setUp(self):
@@ -94,6 +95,7 @@
        same = (trt_mod(self.input.half()) - self.scripted_model(self.input.half())).abs().max()
        self.assertTrue(same < 2e-2)

+
class TestCompileHalfDefault(ModelTestCase):

    def setUp(self):
@@ -114,6 +116,7 @@
        trt_mod = trtorch.compile(self.scripted_model, compile_spec)
        same = (trt_mod(self.input.half()) - self.scripted_model(self.input.half())).abs().max()
        self.assertTrue(same < 2e-2)
+

class TestFallbackToTorch(ModelTestCase):

Reformatting /workspace/tests/py/test_api_dla.py
Reformatting /workspace/tests/py/test_trt_intercompatability.py
Reformatting /workspace/tests/py/model_test_case.py
Reformatting /workspace/tests/py/test_multi_gpu.py
Reformatting /workspace/tests/py/test_ptq_to_backend.py
Reformatting /workspace/tests/py/test_api.py
ERROR: Some files do not conform to style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to C++ style guidelines:

diff --git a/workspace/core/ir/Input.cpp b/tmp/changes.txt
index 8841f41..4f3c2c9 100644
--- a/workspace/core/ir/Input.cpp
+++ b/tmp/changes.txt
@@ -131,11 +131,20 @@ Input::Input(std::vector<int64_t> shape, nvinfer1::DataType dtype, nvinfer1::Ten

  TRTORCH_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype);
  this->dtype = dtype;
-  TRTORCH_CHECK(valid_dtype_format_combo(dtype, format), "Unsupported combination of dtype and tensor format: (" << dtype << ", " << format << "), TRTorch only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
+  TRTORCH_CHECK(
+      valid_dtype_format_combo(dtype, format),
+      "Unsupported combination of dtype and tensor format: ("
+          << dtype << ", " << format
+          << "), TRTorch only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
  this->format = format;
}

-Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, nvinfer1::DataType dtype, nvinfer1::TensorFormat format) {
+Input::Input(
+    std::vector<int64_t> min_shape,
+    std::vector<int64_t> opt_shape,
+    std::vector<int64_t> max_shape,
+    nvinfer1::DataType dtype,
+    nvinfer1::TensorFormat format) {
  if (min_shape.size() > 5 || opt_shape.size() > 5 || max_shape.size() > 5) {
    LOG_WARNING("Verify that this dim size is accepted");
  }
@@ -174,7 +183,11 @@ Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std

  TRTORCH_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype);
  this->dtype = dtype;
-  TRTORCH_CHECK(valid_dtype_format_combo(dtype, format), "Unsupported combination of dtype and tensor format: (" << dtype << ", " << format << "), TRTorch only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
+  TRTORCH_CHECK(
+      valid_dtype_format_combo(dtype, format),
+      "Unsupported combination of dtype and tensor format: ("
+          << dtype << ", " << format
+          << "), TRTorch only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
  this->format = format;
}

@@ -182,7 +195,8 @@ std::ostream& operator<<(std::ostream& os, const Input& input) {
  if (!input.input_is_dynamic) {
    os << "Input(shape: " << input.input_shape << ", dtype: " << input.dtype << ", format: " << input.format << ')';
  } else {
-    os << "Input(shape: " << input.input_shape << ", min: " << input.min << ", opt: " << input.opt << ", max: " << input.max << ", dtype: " << input.dtype << ", format: " << input.format << ')';
+    os << "Input(shape: " << input.input_shape << ", min: " << input.min << ", opt: " << input.opt
+       << ", max: " << input.max << ", dtype: " << input.dtype << ", format: " << input.format << ')';
  }
  return os;
}
diff --git a/workspace/core/conversion/conversion.cpp b/tmp/changes.txt
index 324f3d0..1c2b963 100644
--- a/workspace/core/conversion/conversion.cpp
+++ b/tmp/changes.txt
@@ -125,10 +125,7 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
                       << "please report this error to https://www.github.com/NVIDIA/TRTorch/issues");
}

-void AddInputs(
-    ConversionCtx* ctx,
-    at::ArrayRef<const torch::jit::Value*> inputs,
-    std::vector<ir::Input>& input_specs) {
+void AddInputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> inputs, std::vector<ir::Input>& input_specs) {
  std::vector<const torch::jit::Value*> input_tensors;
  for (auto in : inputs) {
    // Disregarding inputs that are not tensors
@@ -164,7 +161,8 @@ void AddInputs(
    std::string name = std::string("input_") + std::to_string(ctx->num_inputs);
    LOG_INFO(
        ctx->logger,
-        "Adding Input " << in->debugName() << " (named: " << name << "): " << spec << " in engine (conversion.AddInputs)");
+        "Adding Input " << in->debugName() << " (named: " << name << "): " << spec
+                        << " in engine (conversion.AddInputs)");

    auto trt_in = ctx->net->addInput(name.c_str(), spec.dtype, spec.input_shape);
    TRTORCH_CHECK(trt_in, "Failed to add input node: " << in->debugName() << " (conversion.AddInputs)");
diff --git a/workspace/core/conversion/converters/impl/activation.cpp b/tmp/changes.txt
index a7e2a41..77931e4 100644
--- a/workspace/core/conversion/converters/impl/activation.cpp
+++ b/tmp/changes.txt
@@ -177,8 +177,9 @@ auto acthardtanh TRTORCH_UNUSED =
                    std::string pluginName = "CustomGeluPluginDynamic";
                    nvinfer1::PluginFieldCollection fc;
                    std::vector<nvinfer1::PluginField> f;
-                    //REVIEW is this right?
-                    int type_id = ctx->settings.enabled_precisions.find(nvinfer1::DataType::kHALF) == ctx->settings.enabled_precisions.end()
+                    // REVIEW is this right?
+                    int type_id = ctx->settings.enabled_precisions.find(nvinfer1::DataType::kHALF) ==
+                            ctx->settings.enabled_precisions.end()
                        ? 0
                        : 1; // Integer encoding the DataType (0: FP32, 1: FP16)
                    f.emplace_back(nvinfer1::PluginField("type_id", &type_id, nvinfer1::PluginFieldType::kINT32, 1));
diff --git a/workspace/core/conversion/conversionctx/ConversionCtx.cpp b/tmp/changes.txt
index 96a6261..50831a8 100644
--- a/workspace/core/conversion/conversionctx/ConversionCtx.cpp
+++ b/tmp/changes.txt
@@ -60,7 +60,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
  LOG_DEBUG(build_settings);
  cfg = builder->createBuilderConfig();

-  for(auto p = settings.enabled_precisions.begin(); p != settings.enabled_precisions.end(); ++p) {
+  for (auto p = settings.enabled_precisions.begin(); p != settings.enabled_precisions.end(); ++p) {
    switch (*p) {
      case nvinfer1::DataType::kHALF:
        TRTORCH_CHECK(builder->platformHasFastFp16(), "Requested inference in FP16 but platform does not support FP16");
@@ -121,7 +121,9 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
        static_cast<int>(settings.device.dla_core) < nbDLACores,
        "Configured DLA Core ID: " << settings.device.dla_core
                                   << " not available. Total number of available DLA Cores: " << nbDLACores);
-    TRTORCH_CHECK(settings.enabled_precisions.find(nvinfer1::DataType::kFLOAT) == settings.enabled_precisions.end(), "DLA supports only fp16 or int8 precision");
+    TRTORCH_CHECK(
+        settings.enabled_precisions.find(nvinfer1::DataType::kFLOAT) == settings.enabled_precisions.end(),
+        "DLA supports only fp16 or int8 precision");
    cfg->setDLACore(settings.device.dla_core);
  }
}
diff --git a/workspace/core/ir/ir.h b/tmp/changes.txt
index d524ded..3d14491 100644
--- a/workspace/core/ir/ir.h
+++ b/tmp/changes.txt
@@ -1,7 +1,7 @@
#pragma once

-#include <vector>
#include <iostream>
+#include <vector>
#include "NvInfer.h"

namespace trtorch {
@@ -9,10 +9,18 @@ namespace core {
namespace ir {

struct Input {
-  //Input(std::vector<int64_t> shape);
-  //Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape);
-  Input(std::vector<int64_t> shape, nvinfer1::DataType dtype=nvinfer1::DataType::kFLOAT, nvinfer1::TensorFormat format=nvinfer1::TensorFormat::kLINEAR);
-  Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, nvinfer1::DataType dtype=nvinfer1::DataType::kFLOAT, nvinfer1::TensorFormat format=nvinfer1::TensorFormat::kLINEAR);
+  // Input(std::vector<int64_t> shape);
+  // Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape);
+  Input(
+      std::vector<int64_t> shape,
+      nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT,
+      nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR);
+  Input(
+      std::vector<int64_t> min_shape,
+      std::vector<int64_t> opt_shape,
+      std::vector<int64_t> max_shape,
+      nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT,
+      nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR);
  friend std::ostream& operator<<(std::ostream& os, const Input& input);

  bool input_is_dynamic = false;
diff --git a/workspace/core/conversion/conversion.h b/tmp/changes.txt
index fe79669..253dce7 100644
--- a/workspace/core/conversion/conversion.h
+++ b/tmp/changes.txt
@@ -14,8 +14,7 @@ namespace conversion {
struct ConversionInfo {
  std::vector<ir::Input> inputs;
  BuilderSettings engine_settings;
-  ConversionInfo(std::vector<ir::Input> inputs)
-      : inputs(std::move(inputs)), engine_settings(BuilderSettings()) {}
+  ConversionInfo(std::vector<ir::Input> inputs) : inputs(std::move(inputs)), engine_settings(BuilderSettings()) {}
};

// TODO: REMOVE GRAPH AND PARAMS AND MOVE FULLY TO INLINED CONSTANTS
diff --git a/workspace/core/conversion/conversionctx/ConversionCtx.h b/tmp/changes.txt
index 9aee3b4..0570bf5 100644
--- a/workspace/core/conversion/conversionctx/ConversionCtx.h
+++ b/tmp/changes.txt
@@ -2,8 +2,8 @@

#include <map>
#include <memory>
-#include <unordered_map>
#include <set>
+#include <unordered_map>

#include "NvInfer.h"
#include "torch/csrc/jit/ir/ir.h"
diff --git a/workspace/cpp/trtorchc/main.cpp b/tmp/changes.txt
index 794238c..df179ca 100644
--- a/workspace/cpp/trtorchc/main.cpp
+++ b/tmp/changes.txt
@@ -43,8 +43,7 @@ bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold) {
}

trtorch::CompileSpec::TensorFormat parseTensorFormat(std::string str) {
-  std::transform(
-    str.begin(), str.end(), str.begin(), [](unsigned char c) { return std::tolower(c); });
+  std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) { return std::tolower(c); });

  if (str == "linear" || str == "nchw" || str == "chw" || str == "contiguous") {
    return trtorch::CompileSpec::TensorFormat::kContiguous;
@@ -52,8 +51,8 @@ trtorch::CompileSpec::TensorFormat parseTensorFormat(std::string str) {
    return trtorch::CompileSpec::TensorFormat::kChannelsLast;
  } else {
    trtorch::logging::log(
-          trtorch::logging::Level::kERROR,
-          "Invalid tensor format, options are [ linear | nchw | chw | contiguous | nhwc | hwc | channels_last ]");
+        trtorch::logging::Level::kERROR,
+        "Invalid tensor format, options are [ linear | nchw | chw | contiguous | nhwc | hwc | channels_last ]");
    return trtorch::CompileSpec::TensorFormat::kUnknown;
  }
}
@@ -73,8 +72,8 @@ trtorch::CompileSpec::DataType parseDataType(std::string dtype_str) {
    return trtorch::CompileSpec::DataType::kBool;
  } else {
    trtorch::logging::log(
-          trtorch::logging::Level::kERROR,
-          "Invalid precision, options are [ float | float32 | f32 | half | float16 | f16 | char | int8 | i8 | int | int32 | i32 | bool | b]");
+        trtorch::logging::Level::kERROR,
+        "Invalid precision, options are [ float | float32 | f32 | half | float16 | f16 | char | int8 | i8 | int | int32 | i32 | bool | b]");
    return trtorch::CompileSpec::DataType::kUnknown;
  }
}
@@ -190,10 +189,7 @@ int main(int argc, char** argv) {
  args::Flag build_debuggable_engine(
      parser, "build-debuggable-engine", "Creates a debuggable engine", {"build-debuggable-engine"});
  args::Flag use_strict_types(
-      parser,
-      "use-strict-types",
-      "Restrict operating type to only use set operation precision",
-      {"use-strict-types"});
+      parser, "use-strict-types", "Restrict operating type to only use set operation precision", {"use-strict-types"});
  args::Flag allow_gpu_fallback(
      parser,
      "allow-gpu-fallback",
@@ -275,7 +271,8 @@ int main(int argc, char** argv) {
  }

  std::vector<trtorch::CompileSpec::Input> ranges;
-  const std::string spec_err_str = "Dimensions should be specified in one of these types \"(N,..,C,H,W)\" \"[(MIN_N,..,MIN_C,MIN_H,MIN_W);(OPT_N,..,OPT_C,OPT_H,OPT_W);(MAX_N,..,MAX_C,MAX_H,MAX_W)]\"\n e.g \"(3,3,300,300)\" \"[(3,3,100,100);(3,3,200,200);(3,3,300,300)]\"\nTo specify input type append an @ followed by the precision\n e.g. \"(3,3,300,300)@f32\"\nTo specify input format append an \% followed by the format [contiguous | channel_last]\n e.g. \"(3,3,300,300)@f32\%channel_last\"";
+  const std::string spec_err_str =
+      "Dimensions should be specified in one of these types \"(N,..,C,H,W)\" \"[(MIN_N,..,MIN_C,MIN_H,MIN_W);(OPT_N,..,OPT_C,OPT_H,OPT_W);(MAX_N,..,MAX_C,MAX_H,MAX_W)]\"\n e.g \"(3,3,300,300)\" \"[(3,3,100,100);(3,3,200,200);(3,3,300,300)]\"\nTo specify input type append an @ followed by the precision\n e.g. \"(3,3,300,300)@f32\"\nTo specify input format append an \% followed by the format [contiguous | channel_last]\n e.g. \"(3,3,300,300)@f32\%channel_last\"";
  for (const auto spec : args::get(input_shapes)) {
    std::string shapes;
    std::string dtype;
@@ -306,13 +303,14 @@ int main(int argc, char** argv) {
          ranges.push_back(trtorch::CompileSpec::Input(parseSingleDim(shapes), parsed_dtype, parsed_format));
        } else if (shapes.rfind("[", 0) == 0) {
          auto dyn_shapes = parseDynamicDim(shapes);
-          ranges.push_back(trtorch::CompileSpec::Input(dyn_shapes[0], dyn_shapes[1], dyn_shapes[2], parsed_dtype, parsed_format));
+          ranges.push_back(
+              trtorch::CompileSpec::Input(dyn_shapes[0], dyn_shapes[1], dyn_shapes[2], parsed_dtype, parsed_format));
        } else {
          trtorch::logging::log(trtorch::logging::Level::kERROR, spec_err_str);
          std::cerr << parser;
          exit(1);
        }
-      // THERE IS NO SPEC FOR FORMAT
+        // THERE IS NO SPEC FOR FORMAT
      } else {
        std::string shapes = spec.substr(0, spec.find('@'));
        std::string dtype = spec.substr(spec.find('@') + 1, spec.size());
@@ -334,7 +332,7 @@ int main(int argc, char** argv) {
          exit(1);
        }
      }
-    // THERE IS A SPEC FOR FORMAT BUT NOT DTYPE
+      // THERE IS A SPEC FOR FORMAT BUT NOT DTYPE
    } else if (spec.find('%') != std::string::npos) {
      std::string shapes = spec.substr(0, spec.find('%'));
      std::string format = spec.substr(spec.find('%') + 1, spec.size());
@@ -355,7 +353,7 @@ int main(int argc, char** argv) {
        std::cerr << parser;
        exit(1);
      }
-    // JUST SHAPE USE DEFAULT DTYPE
+      // JUST SHAPE USE DEFAULT DTYPE
    } else {
      if (spec.rfind("(", 0) == 0) {
        ranges.push_back(trtorch::CompileSpec::Input(parseSingleDim(spec)));
@@ -449,7 +447,6 @@ int main(int argc, char** argv) {
    }
  }

-
  if (engine_capability) {
    auto capability = args::get(engine_capability);
    std::transform(
@@ -510,7 +507,9 @@ int main(int argc, char** argv) {
  } else {
    auto trt_mod = trtorch::CompileGraph(mod, compile_settings);

-    if (compile_settings.enabled_precisions.size() == 1 && compile_settings.enabled_precisions.find(trtorch::CompileSpec::DataType::kFloat) != compile_settings.enabled_precisions.end()) {
+    if (compile_settings.enabled_precisions.size() == 1 &&
+        compile_settings.enabled_precisions.find(trtorch::CompileSpec::DataType::kFloat) !=
+            compile_settings.enabled_precisions.end()) {
      double threshold_val = 2e-5;
      if (threshold) {
        threshold_val = args::get(threshold);
diff --git a/workspace/cpp/api/src/compile_spec.cpp b/tmp/changes.txt
index 1d55443..5e44c6b 100644
--- a/workspace/cpp/api/src/compile_spec.cpp
+++ b/tmp/changes.txt
@@ -62,14 +62,16 @@ std::ostream& operator<<(std::ostream& os, const CompileSpec::Input& input) {
  };

  if (!input.input_is_dynamic) {
-    os << "Input(shape: " << vec_to_str(input.shape) << ", dtype: " << input.dtype << ", format: " << input.format << ')';
+    os << "Input(shape: " << vec_to_str(input.shape) << ", dtype: " << input.dtype << ", format: " << input.format
+       << ')';
  } else {
-    os << "Input(shape: " << vec_to_str(input.shape) << ", min: " << vec_to_str(input.min_shape) << ", opt: " << vec_to_str(input.opt_shape) << ", max: " << vec_to_str(input.max_shape) << ", dtype: " << input.dtype << ", format: " << input.format << ')';
+    os << "Input(shape: " << vec_to_str(input.shape) << ", min: " << vec_to_str(input.min_shape)
+       << ", opt: " << vec_to_str(input.opt_shape) << ", max: " << vec_to_str(input.max_shape)
+       << ", dtype: " << input.dtype << ", format: " << input.format << ')';
  }
  return os;
}

-
nvinfer1::DataType toTRTDataType(CompileSpec::DataType value) {
  switch (value) {
    case CompileSpec::DataType::kChar:
@@ -122,8 +124,7 @@ CompileSpec::DataType::DataType(c10::ScalarType t) {

CompileSpec::TensorFormat::TensorFormat(at::MemoryFormat t) {
  TRTORCH_CHECK(
-    t == at::MemoryFormat::Contiguous || t == at::MemoryFormat::ChannelsLast, "Tensor format is unsupported"
-  );
+      t == at::MemoryFormat::Contiguous || t == at::MemoryFormat::ChannelsLast, "Tensor format is unsupported");

  switch (t) {
    case at::MemoryFormat::ChannelsLast:
@@ -221,7 +222,11 @@ CompileSpec::Input::Input(c10::IntArrayRef shape, DataType dtype, TensorFormat f
  this->input_is_dynamic = false;
}

-CompileSpec::Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, TensorFormat format) {
+CompileSpec::Input::Input(
+    std::vector<int64_t> min_shape,
+    std::vector<int64_t> opt_shape,
+    std::vector<int64_t> max_shape,
+    TensorFormat format) {
  this->opt_shape = opt_shape;
  this->min_shape = min_shape;
  this->max_shape = max_shape;
@@ -232,7 +237,12 @@ CompileSpec::Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> o
  this->input_is_dynamic = true;
}

-CompileSpec::Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, DataType dtype, TensorFormat format) {
+CompileSpec::Input::Input(
+    std::vector<int64_t> min_shape,
+    std::vector<int64_t> opt_shape,
+    std::vector<int64_t> max_shape,
+    DataType dtype,
+    TensorFormat format) {
  this->opt_shape = opt_shape;
  this->min_shape = min_shape;
  this->max_shape = max_shape;
@@ -243,7 +253,11 @@ CompileSpec::Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> o
  this->input_is_dynamic = true;
}

-CompileSpec::Input::Input(c10::IntArrayRef min_shape, c10::IntArrayRef opt_shape, c10::IntArrayRef max_shape, TensorFormat format) {
+CompileSpec::Input::Input(
+    c10::IntArrayRef min_shape,
+    c10::IntArrayRef opt_shape,
+    c10::IntArrayRef max_shape,
+    TensorFormat format) {
  this->opt_shape = core::util::toVec(opt_shape);
  this->min_shape = core::util::toVec(min_shape);
  this->max_shape = core::util::toVec(max_shape);
@@ -254,7 +268,12 @@ CompileSpec::Input::Input(c10::IntArrayRef min_shape, c10::IntArrayRef opt_shape
  this->input_is_dynamic = true;
}

-CompileSpec::Input::Input(c10::IntArrayRef min_shape, c10::IntArrayRef opt_shape, c10::IntArrayRef max_shape, DataType dtype, TensorFormat format) {
+CompileSpec::Input::Input(
+    c10::IntArrayRef min_shape,
+    c10::IntArrayRef opt_shape,
+    c10::IntArrayRef max_shape,
+    DataType dtype,
+    TensorFormat format) {
  this->opt_shape = core::util::toVec(opt_shape);
  this->min_shape = core::util::toVec(min_shape);
  this->max_shape = core::util::toVec(max_shape);
@@ -306,17 +325,19 @@ core::runtime::CudaDevice to_internal_cuda_device(CompileSpec::Device device) {

core::CompileSpec to_internal_compile_spec(CompileSpec external) {
  core::CompileSpec internal(to_vec_internal_inputs(external.inputs));
-  if (external.input_ranges.size() > 0 ) {
+  if (external.input_ranges.size() > 0) {
    internal = core::CompileSpec(to_vec_internal_inputs(external.input_ranges));
  } else {
    TRTORCH_CHECK(external.inputs.size() > 0, "Compilation requires at least one input specification");
    internal = core::CompileSpec(to_vec_internal_inputs(external.inputs));
  }

-  if (external.enabled_precisions.size() <= 1 && toTRTDataType(*external.enabled_precisions.begin()) ==  nvinfer1::DataType::kFLOAT && toTRTDataType(external.op_precision) != nvinfer1::DataType::kFLOAT) {
+  if (external.enabled_precisions.size() <= 1 &&
+      toTRTDataType(*external.enabled_precisions.begin()) == nvinfer1::DataType::kFLOAT &&
+      toTRTDataType(external.op_precision) != nvinfer1::DataType::kFLOAT) {
    internal.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(external.op_precision));
  } else {
-    for(auto p : external.enabled_precisions) {
+    for (auto p : external.enabled_precisions) {
      internal.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
    }
  }
@@ -375,7 +396,8 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
  internal.convert_info.engine_settings.num_avg_timing_iters = external.num_avg_timing_iters;
  internal.convert_info.engine_settings.workspace_size = external.workspace_size;

-  if (internal.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) != internal.convert_info.engine_settings.enabled_precisions.end()) {
+  if (internal.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) !=
+      internal.convert_info.engine_settings.enabled_precisions.end()) {
    internal.convert_info.engine_settings.calibrator = external.ptq_calibrator;
  } else {
    internal.convert_info.engine_settings.calibrator = nullptr;
diff --git a/workspace/cpp/api/include/trtorch/trtorch.h b/tmp/changes.txt
index 8e492e4..e0fd531 100644
--- a/workspace/cpp/api/include/trtorch/trtorch.h
+++ b/tmp/changes.txt
@@ -11,9 +11,9 @@
#include <cuda_runtime.h>
#include <iostream>
#include <memory>
+#include <set>
#include <string>
#include <vector>
-#include <set>

// Just include the .h?
#ifndef DOXYGEN_SHOULD_SKIP_THIS
@@ -268,7 +268,7 @@ struct TRTORCH_API CompileSpec {
  };

  class TRTORCH_API TensorFormat {
-    public:
+   public:
    /**
     * Underlying enum class to support the TensorFormat Class
     *
@@ -352,7 +352,6 @@ struct TRTORCH_API CompileSpec {
      return value != other;
    }

-
   private:
    friend std::ostream& operator<<(std::ostream& os, const TensorFormat& format);
    Value value;
@@ -373,7 +372,8 @@ struct TRTORCH_API CompileSpec {
    std::vector<int64_t> opt_shape;
    /// Maximum acceptable input size into the engine
    std::vector<int64_t> max_shape;
-    /// Input shape to be fed to TensorRT, in the event of a dynamic shape, -1's will hold the place of variable dimensions
+    /// Input shape to be fed to TensorRT, in the event of a dynamic shape, -1's will hold the place of variable
+    /// dimensions
    std::vector<int64_t> shape;
    /// Expected data type for the input
    DataType dtype;
@@ -390,7 +390,7 @@ struct TRTORCH_API CompileSpec {
     * @param dtype Expected data type for the input (Defaults to Float32)
     * @param format Expected tensor format for the input (Defaults to contiguous)
     */
-    Input(std::vector<int64_t> shape, TensorFormat format=TensorFormat::kContiguous);
+    Input(std::vector<int64_t> shape, TensorFormat format = TensorFormat::kContiguous);

    /**
     * @brief Construct a new Input spec object for static input size from
@@ -401,7 +401,7 @@ struct TRTORCH_API CompileSpec {
     * @param dtype Expected data type for the input (Defaults to Float32)
     * @param format Expected tensor format for the input (Defaults to contiguous)
     */
-    Input(std::vector<int64_t> shape, DataType dtype, TensorFormat format=TensorFormat::kContiguous);
+    Input(std::vector<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);

    /**
     * @brief Construct a new Input spec object for static input size from
@@ -413,7 +413,7 @@ struct TRTORCH_API CompileSpec {
     * @param shape Input tensor shape
     * @param format Expected tensor format for the input (Defaults to contiguous)
     */
-    Input(c10::ArrayRef<int64_t> shape, TensorFormat format=TensorFormat::kContiguous);
+    Input(c10::ArrayRef<int64_t> shape, TensorFormat format = TensorFormat::kContiguous);

    /**
     * @brief Construct a new Input spec object for static input size from
@@ -424,7 +424,7 @@ struct TRTORCH_API CompileSpec {
     * @param dtype Expected data type for the input (Defaults to Float32)
     * @param format Expected tensor format for the input (Defaults to contiguous)
     */
-    Input(c10::ArrayRef<int64_t> shape, DataType dtype, TensorFormat format=TensorFormat::kContiguous);
+    Input(c10::ArrayRef<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);

    /**
     * @brief Construct a new Input Range object dynamic input size from
@@ -437,7 +437,11 @@ struct TRTORCH_API CompileSpec {
     * @param max_shape Maximum acceptible shape for input tensor
     * @param format Expected tensor format for the input (Defaults to contiguous)
     */
-    Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, TensorFormat format=TensorFormat::kContiguous);
+    Input(
+        std::vector<int64_t> min_shape,
+        std::vector<int64_t> opt_shape,
+        std::vector<int64_t> max_shape,
+        TensorFormat format = TensorFormat::kContiguous);

    /**
     * @brief Construct a new Input spec object for a dynamic input size from vectors
@@ -450,7 +454,12 @@ struct TRTORCH_API CompileSpec {
     * @param dtype Expected data type for the input (Defaults to Float32)
     * @param format Expected tensor format for the input (Defaults to contiguous)
     */
-    Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape, DataType dtype, TensorFormat format=TensorFormat::kContiguous);
+    Input(
+        std::vector<int64_t> min_shape,
+        std::vector<int64_t> opt_shape,
+        std::vector<int64_t> max_shape,
+        DataType dtype,
+        TensorFormat format = TensorFormat::kContiguous);

    /**
     * @brief Construct a new Input Range object dynamic input size from
@@ -463,7 +472,11 @@ struct TRTORCH_API CompileSpec {
     * @param max_shape Maximum acceptible shape for input tensor
     * @param format Expected tensor format for the input (Defaults to contiguous)
     */
-    Input(c10::ArrayRef<int64_t> min_shape, c10::ArrayRef<int64_t> opt_shape, c10::ArrayRef<int64_t> max_shape, TensorFormat format=TensorFormat::kContiguous);
+    Input(
+        c10::ArrayRef<int64_t> min_shape,
+        c10::ArrayRef<int64_t> opt_shape,
+        c10::ArrayRef<int64_t> max_shape,
+        TensorFormat format = TensorFormat::kContiguous);

    /**
     * @brief Construct a new Input Range object dynamic input size from
@@ -476,10 +489,18 @@ struct TRTORCH_API CompileSpec {
     * @param dtype Expected data type for the input (Defaults to Float32)
     * @param format Expected tensor format for the input (Defaults to contiguous)
     */
-    Input(c10::ArrayRef<int64_t> min_shape, c10::ArrayRef<int64_t> opt_shape, c10::ArrayRef<int64_t> max_shape, DataType dtype, TensorFormat format=TensorFormat::kContiguous);
+    Input(
+        c10::ArrayRef<int64_t> min_shape,
+        c10::ArrayRef<int64_t> opt_shape,
+        c10::ArrayRef<int64_t> max_shape,
+        DataType dtype,
+        TensorFormat format = TensorFormat::kContiguous);

-    bool get_explicit_set_dtype() {return explicit_set_dtype;}
-  private:
+    bool get_explicit_set_dtype() {
+      return explicit_set_dtype;
+    }
+
+   private:
    friend std::ostream& operator<<(std::ostream& os, const Input& input);
    bool input_is_dynamic;
    bool explicit_set_dtype;
@@ -506,16 +527,16 @@ struct TRTORCH_API CompileSpec {
     *
     * @param opt
     */
-    [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]]
-    InputRange(std::vector<int64_t> opt);
+    [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] InputRange(
+        std::vector<int64_t> opt);
    /**
     * @brief Construct a new Input Range object static input size from
     * c10::ArrayRef (the type produced by tensor.sizes())
     *
     * @param opt
     */
-    [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]]
-    InputRange(c10::ArrayRef<int64_t> opt);
+    [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] InputRange(
+        c10::ArrayRef<int64_t> opt);
    /**
     * @brief Construct a new Input Range object dynamic input size from vectors
     * for min, opt, and max supported sizes
@@ -524,8 +545,10 @@ struct TRTORCH_API CompileSpec {
     * @param opt
     * @param max
     */
-    [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]]
-    InputRange(std::vector<int64_t> min, std::vector<int64_t> opt, std::vector<int64_t> max);
+    [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] InputRange(
+        std::vector<int64_t> min,
+        std::vector<int64_t> opt,
+        std::vector<int64_t> max);
    /**
     * @brief Construct a new Input Range object dynamic input size from
     * c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
@@ -535,8 +558,10 @@ struct TRTORCH_API CompileSpec {
     * @param opt
     * @param max
     */
-    [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]]
-    InputRange(c10::ArrayRef<int64_t> min, c10::ArrayRef<int64_t> opt, c10::ArrayRef<int64_t> max);
+    [[deprecated("trtorch::CompileSpec::InputRange is being deprecated in favor of trtorch::CompileSpec::Input. trtorch::CompileSpec::InputRange will be removed in TRTorch v0.5.0")]] InputRange(
+        c10::ArrayRef<int64_t> min,
+        c10::ArrayRef<int64_t> opt,
+        c10::ArrayRef<int64_t> max);
  };

  /**
@@ -577,8 +602,9 @@ struct TRTORCH_API CompileSpec {
   *
   * @param input_ranges
   */
-  [[deprecated("trtorch::CompileSpec::CompileSpec(std::vector<InputRange> input_ranges) is being deprecated in favor of trtorch::CompileSpec::CompileSpec(std::vector<Input> inputs). Please use CompileSpec(std::vector<Input> inputs). trtorch::CompileSpec::CompileSpec(std::vector<InputRange> input_ranges) will be removed in TRTorch v0.5.0")]]
-  CompileSpec(std::vector<InputRange> input_ranges) : input_ranges(std::move(input_ranges)) {}
+  [[deprecated("trtorch::CompileSpec::CompileSpec(std::vector<InputRange> input_ranges) is being deprecated in favor of trtorch::CompileSpec::CompileSpec(std::vector<Input> inputs). Please use CompileSpec(std::vector<Input> inputs). trtorch::CompileSpec::CompileSpec(std::vector<InputRange> input_ranges) will be removed in TRTorch v0.5.0")]] CompileSpec(
+      std::vector<InputRange> input_ranges)
+      : input_ranges(std::move(input_ranges)) {}
  /**
   * @brief Construct a new Extra Info object
   * Convienence constructor to set fixed input size from vectors describing
@@ -586,7 +612,8 @@ struct TRTORCH_API CompileSpec {
   * should be provided in call order.
   *
   * This constructor should be use as a convience in the case that all inputs are static sized and
-   * you are okay with default input dtype and formats (FP32 for FP32 and INT8 weights, FP16 for FP16 weights, contiguous)
+   * you are okay with default input dtype and formats (FP32 for FP32 and INT8 weights, FP16 for FP16 weights,
+   * contiguous)
   *
   * @param fixed_sizes
   */
@@ -599,7 +626,8 @@ struct TRTORCH_API CompileSpec {
   * the vector represents a input and should be provided in call order.
   *
   * This constructor should be use as a convience in the case that all inputs are static sized and
-   * you are okay with default input dtype and formats (FP32 for FP32 and INT8 weights, FP16 for FP16 weights, contiguous)
+   * you are okay with default input dtype and formats (FP32 for FP32 and INT8 weights, FP16 for FP16 weights,
+   * contiguous)
   *
   * @param fixed_sizes
   */
@@ -619,8 +647,8 @@ struct TRTORCH_API CompileSpec {
  // Defaults should reflect TensorRT defaults for BuilderConfig

  /**
-   * @brief Specifications for inputs to the engine, can either be a single size or a range defined by min, opt and max sizes
-   * Users can also specify expected input type as well as tensor memory format
+   * @brief Specifications for inputs to the engine, can either be a single size or a range defined by min, opt and max
+   * sizes Users can also specify expected input type as well as tensor memory format
   *
   * Order in vector should match call order for the function
   */
@@ -632,14 +660,17 @@ struct TRTORCH_API CompileSpec {
   *
   * Order is should match call order
   */
-  [[deprecated("trtorch::CompileSpec::input_ranges is being deprecated in favor of trtorch::CompileSpec::inputs. trtorch::CompileSpec::input_ranges will be removed in TRTorch v0.5.0")]]
-  std::vector<InputRange> input_ranges;
+  [[deprecated(
+      "trtorch::CompileSpec::input_ranges is being deprecated in favor of trtorch::CompileSpec::inputs. trtorch::CompileSpec::input_ranges will be removed in TRTorch v0.5.0")]] std::
+      vector<InputRange>
+          input_ranges;

  /**
   * Default operating precision for the engine
   */
-  [[deprecated("trtorch::CompileSpec::op_precision is being deprecated in favor of trtorch::CompileSpec::enabled_precisions, a set of all enabled precisions to use during compilation, trtorch::CompileSpec::op_precision will be removed in TRTorch v0.5.0")]]
-  DataType op_precision = DataType::kFloat;
+  [[deprecated(
+      "trtorch::CompileSpec::op_precision is being deprecated in favor of trtorch::CompileSpec::enabled_precisions, a set of all enabled precisions to use during compilation, trtorch::CompileSpec::op_precision will be removed in TRTorch v0.5.0")]] DataType
+      op_precision = DataType::kFloat;

  /**
   * @brief The set of precisions TensorRT is allowed to use for kernels during compilation
diff --git a/workspace/tests/cpp/test_modules_as_engines.cpp b/tmp/changes.txt
index e6a72e3..58e7699 100644
--- a/workspace/tests/cpp/test_modules_as_engines.cpp
+++ b/tmp/changes.txt
@@ -1,5 +1,5 @@
-#include "cpp_api_test.h"
#include "core/runtime/runtime.h"
+#include "cpp_api_test.h"

TEST_P(CppAPITests, ModuleAsEngineIsClose) {
  std::vector<at::Tensor> inputs;
diff --git a/workspace/tests/cpp/test_default_input_types.cpp b/tmp/changes.txt
index f033f00..fa814fd 100644
--- a/workspace/tests/cpp/test_default_input_types.cpp
+++ b/tmp/changes.txt
@@ -13,7 +13,6 @@ TEST_P(CppAPITests, InputsUseDefault) {
  auto spec = trtorch::CompileSpec({in});
  spec.enabled_precisions.insert(trtorch::CompileSpec::DataType::kHalf);

-
  mod.to(torch::kHalf);

  auto trt_mod = trtorch::CompileGraph(mod, spec);
@@ -26,5 +25,4 @@ TEST_P(CppAPITests, InputsUseDefault) {
INSTANTIATE_TEST_SUITE_P(
    CompiledModuleForwardIsCloseSuite,
    CppAPITests,
-    testing::Values(
-        PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}})));
+    testing::Values(PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}})));
diff --git a/workspace/py/trtorch/csrc/register_tensorrt_classes.cpp b/tmp/changes.txt
index 9080f33..3ca490c 100644
--- a/workspace/py/trtorch/csrc/register_tensorrt_classes.cpp
+++ b/tmp/changes.txt
@@ -9,10 +9,9 @@ namespace {
  (registry).def("_get_" #field_name, &class_name::get_##field_name);

void RegisterTRTCompileSpec() {
-  static auto TRTORCH_UNUSED TRTInputRangeTSRegistration =
-      torch::class_<trtorch::pyapi::Input>("tensorrt", "_Input")
-          .def(torch::init<>())
-          .def("__str__", &trtorch::pyapi::Input::to_str);
+  static auto TRTORCH_UNUSED TRTInputRangeTSRegistration = torch::class_<trtorch::pyapi::Input>("tensorrt", "_Input")
+                                                               .def(torch::init<>())
+                                                               .def("__str__", &trtorch::pyapi::Input::to_str);

  ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, min);
  ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, opt);
@@ -21,7 +20,6 @@ void RegisterTRTCompileSpec() {
  ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, format);
  ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, input_is_dynamic);

-
  static auto TRTORCH_UNUSED TRTDeviceTSRegistration = torch::class_<trtorch::pyapi::Device>("tensorrt", "_Device")
                                                           .def(torch::init<>())
                                                           .def("__str__", &trtorch::pyapi::Device::to_str);
ERROR: Some files do not conform to style guidelines

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

Reformatting /workspace/cpp/ptq/training/vgg16/export_ckpt.py
Reformatting /workspace/cpp/ptq/training/vgg16/vgg16.py
Reformatting /workspace/cpp/ptq/training/vgg16/main.py
Reformatting /workspace/tests/modules/hub.py
Reformatting /workspace/tests/py/test_to_backend_api.py
Reformatting /workspace/tests/py/test_ptq_trt_calibrator.py
Reformatting /workspace/tests/py/test_ptq_dataloader_calibrator.py
--- /workspace/tests/py/test_api.py	(original)
+++ /workspace/tests/py/test_api.py	(reformatted)
@@ -73,6 +73,7 @@
        same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
        self.assertTrue(same < 2e-2)

+
class TestCompileHalf(ModelTestCase):

    def setUp(self):
@@ -94,6 +95,7 @@
        same = (trt_mod(self.input.half()) - self.scripted_model(self.input.half())).abs().max()
        self.assertTrue(same < 2e-2)

+
class TestCompileHalfDefault(ModelTestCase):

    def setUp(self):
@@ -114,6 +116,7 @@
        trt_mod = trtorch.compile(self.scripted_model, compile_spec)
        same = (trt_mod(self.input.half()) - self.scripted_model(self.input.half())).abs().max()
        self.assertTrue(same < 2e-2)
+

class TestFallbackToTorch(ModelTestCase):

Reformatting /workspace/tests/py/test_api_dla.py
Reformatting /workspace/tests/py/test_trt_intercompatability.py
Reformatting /workspace/tests/py/model_test_case.py
Reformatting /workspace/tests/py/test_multi_gpu.py
Reformatting /workspace/tests/py/test_ptq_to_backend.py
Reformatting /workspace/tests/py/test_api.py
ERROR: Some files do not conform to style guidelines

@@ -41,6 +43,7 @@ def compile(module: torch.jit.ScriptModule, compile_spec: Any) -> torch.jit.Scri
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
},
"op_precision": torch.half, # Operating precision set to FP16
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

enabled precisions

Copy link
Collaborator Author

@peri044 peri044 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pending local testing otherwise LGTM with minor comments

// }

// input_shape = util::toDims(dyn_shape);
// }
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this would be part of your clean up.

/// Bool
kBool,
/// Sentinel value
kUnknown
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is this used for ?

@@ -415,6 +593,129 @@ struct TRTORCH_API CompileSpec {
TorchFallback(bool enabled, uint64_t min_size) : enabled(enabled), min_block_size(min_size) {}
};

/**
* @brief Construct a new Extra Info object from input ranges.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a minor thing. Extra Info -> CompileSpec in the docstring

os << "half";
break;
case CompileSpec::DataType::kInt:
os << "int";
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use "int32" here to be explicit ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wanted to use the same term as pytorch and the api at least for the top level, switches to int32 in the core

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these are mostly debugging aids for us, I dont think they are used anywhere in the codebase

Copy link
Collaborator Author

@peri044 peri044 Jul 22, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah I think torch has both torch::kInt32 and torch::kInt or at::Int which are aliases

std::ostream& operator<<(std::ostream& os, const CompileSpec::DataType& dtype) {
switch (dtype) {
case CompileSpec::DataType::kChar:
os << "char";
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we print INT8 instead ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wanted to use the same term as pytorch and the api at least for the top level, switches to int8 in the core

TRTORCH_CHECK(t == at::kHalf || t == at::kFloat || t == at::kChar, "Data type is unsupported");
TRTORCH_CHECK(
t == at::kHalf || t == at::kFloat || t == at::kChar || t == at::kInt || t == at::kBool,
"Data type is unsupported");
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can add this after Data type is unsupported. "Supported input datatypes include float|half|int8|int32|bool"


CompileSpec::TensorFormat::TensorFormat(at::MemoryFormat t) {
TRTORCH_CHECK(
t == at::MemoryFormat::Contiguous || t == at::MemoryFormat::ChannelsLast, "Tensor format is unsupported");
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Supported options include ChannelsLast and Contiguous"

Comment on lines 335 to 337
if (external.enabled_precisions.size() <= 1 &&
toTRTDataType(*external.enabled_precisions.begin()) == nvinfer1::DataType::kFLOAT &&
toTRTDataType(external.op_precision) != nvinfer1::DataType::kFLOAT) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is happening here in these condition ?

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
@github-actions github-actions bot added the documentation Improvements or additions to documentation label Jul 22, 2021
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

Reformatting /workspace/cpp/ptq/training/vgg16/export_ckpt.py
Reformatting /workspace/cpp/ptq/training/vgg16/vgg16.py
Reformatting /workspace/cpp/ptq/training/vgg16/main.py
--- /workspace/tests/py/test_api.py	(original)
+++ /workspace/tests/py/test_api.py	(reformatted)
@@ -73,6 +73,7 @@
        same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
        self.assertTrue(same < 2e-2)

+
class TestCompileHalf(ModelTestCase):

    def setUp(self):
@@ -94,6 +95,7 @@
        same = (trt_mod(self.input.half()) - self.scripted_model(self.input.half())).abs().max()
        self.assertTrue(same < 2e-2)

+
class TestCompileHalfDefault(ModelTestCase):

    def setUp(self):
@@ -114,6 +116,7 @@
        trt_mod = trtorch.compile(self.scripted_model, compile_spec)
        same = (trt_mod(self.input.half()) - self.scripted_model(self.input.half())).abs().max()
        self.assertTrue(same < 2e-2)
+

class TestFallbackToTorch(ModelTestCase):

Reformatting /workspace/tests/py/test_api_dla.py
Reformatting /workspace/tests/py/test_trt_intercompatability.py
Reformatting /workspace/tests/py/model_test_case.py
Reformatting /workspace/tests/py/test_multi_gpu.py
Reformatting /workspace/tests/py/test_ptq_to_backend.py
Reformatting /workspace/tests/py/test_api.py
Reformatting /workspace/tests/modules/hub.py
Reformatting /workspace/tests/py/test_to_backend_api.py
Reformatting /workspace/tests/py/test_ptq_trt_calibrator.py
Reformatting /workspace/tests/py/test_ptq_dataloader_calibrator.py
ERROR: Some files do not conform to style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

feat(//py): add user level device class in py for embed engine
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

Reformatting /workspace/cpp/ptq/training/vgg16/export_ckpt.py
Reformatting /workspace/cpp/ptq/training/vgg16/vgg16.py
Reformatting /workspace/cpp/ptq/training/vgg16/main.py
Reformatting /workspace/tests/modules/hub.py
Reformatting /workspace/tests/py/test_to_backend_api.py
Reformatting /workspace/tests/py/test_ptq_trt_calibrator.py
Reformatting /workspace/tests/py/test_ptq_dataloader_calibrator.py
--- /workspace/tests/py/test_api.py	(original)
+++ /workspace/tests/py/test_api.py	(reformatted)
@@ -73,6 +73,7 @@
        same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
        self.assertTrue(same < 2e-2)

+
class TestCompileHalf(ModelTestCase):

    def setUp(self):
@@ -94,6 +95,7 @@
        same = (trt_mod(self.input.half()) - self.scripted_model(self.input.half())).abs().max()
        self.assertTrue(same < 2e-2)

+
class TestCompileHalfDefault(ModelTestCase):

    def setUp(self):
@@ -114,6 +116,7 @@
        trt_mod = trtorch.compile(self.scripted_model, compile_spec)
        same = (trt_mod(self.input.half()) - self.scripted_model(self.input.half())).abs().max()
        self.assertTrue(same < 2e-2)
+

class TestFallbackToTorch(ModelTestCase):

Reformatting /workspace/tests/py/test_api_dla.py
Reformatting /workspace/tests/py/test_trt_intercompatability.py
Reformatting /workspace/tests/py/model_test_case.py
Reformatting /workspace/tests/py/test_multi_gpu.py
Reformatting /workspace/tests/py/test_ptq_to_backend.py
Reformatting /workspace/tests/py/test_api.py
ERROR: Some files do not conform to style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

@narendasan
Copy link
Collaborator

@peri044 should be good to go

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

@narendasan narendasan merged commit b9b0aff into master Jul 22, 2021
@narendasan narendasan deleted the input_type branch July 22, 2021 15:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
component: api [C++] Issues re: C++ API component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: core Issues re: The core compiler component: tests Issues re: Tests documentation Improvements or additions to documentation
Projects
None yet
Development

Successfully merging this pull request may close these issues.

✨[Feature] Setting the input data type of models, such as INT32, is not supported
2 participants