-
Notifications
You must be signed in to change notification settings - Fork 365
feat: rewriting param to a Constant if it's a introduced input #1298
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Signed-off-by: Bo Wang <[email protected]>
@bowang007 seems like linting is failing, can you set up the pre-commit system? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a test case we can write for this?
core/compiler.cpp
Outdated
@@ -434,6 +451,9 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) | |||
(!(cfg.lower_info.forced_fallback_modules.size() == 0 && | |||
cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible) || | |||
outputIsCollection)) { | |||
if (!static_params.empty()) { | |||
RewriteInputsWithParams(g, params); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be a lowering pass?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It could be a lowering pass. However, I'm not so sure about what happens with conversion stage if the input are stored in params. Any thought on this? @narendasan
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At this point static params is usually empty, it is just legacy behavior we still exploit for test so it should be able to handle it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cool let me change it to a lowering pass.
Signed-off-by: Bo Wang <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to C++ style guidelines:
diff --git a/home/runner/work/TensorRT/TensorRT/core/conversion/converters/impl/select.cpp b/tmp/changes.txt
index 20a03f6..6c8f5dd 100644
--- a/home/runner/work/TensorRT/TensorRT/core/conversion/converters/impl/select.cpp
+++ b/tmp/changes.txt
@@ -484,7 +484,7 @@ auto select_registrations TORCHTRT_UNUSED =
auto layer = ctx->net->addScatter(*self, *index, *value_tensor, nvinfer1::ScatterMode::kELEMENT);
layer->setAxis(dim);
-
+
TORCHTRT_CHECK(layer, "Unable to create layer for aten::scatter.value");
layer->setName(util::node_info(n).c_str());
@@ -503,7 +503,7 @@ auto select_registrations TORCHTRT_UNUSED =
auto layer = ctx->net->addScatter(*self, *index, *src, nvinfer1::ScatterMode::kELEMENT);
layer->setAxis(dim);
-
+
TORCHTRT_CHECK(layer, "Unable to create layer for aten::scatter.src");
layer->setName(util::node_info(n).c_str());
diff --git a/home/runner/work/TensorRT/TensorRT/core/conversion/converters/impl/element_wise.cpp b/tmp/changes.txt
index 0ad4c12..32c7050 100644
--- a/home/runner/work/TensorRT/TensorRT/core/conversion/converters/impl/element_wise.cpp
+++ b/tmp/changes.txt
@@ -25,8 +25,6 @@ nvinfer1::ITensor* clamp_util(
return clamp_layer_out;
}
-
-
auto element_wise_registrations TORCHTRT_UNUSED =
RegisterNodeConversionPatterns()
.pattern(
ERROR: Some files do not conform to style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
Signed-off-by: Bo Wang <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to C++ style guidelines:
diff --git a/home/runner/work/TensorRT/TensorRT/core/lowering/passes/rewrite_inputs_with_params.cpp b/tmp/changes.txt
index a05a905..c06fcd6 100644
--- a/home/runner/work/TensorRT/TensorRT/core/lowering/passes/rewrite_inputs_with_params.cpp
+++ b/tmp/changes.txt
@@ -1,13 +1,11 @@
-#include "torch/csrc/jit/ir/constants.h"
#include "core/util/prelude.h"
-
+#include "torch/csrc/jit/ir/constants.h"
namespace torch_tensorrt {
namespace core {
namespace lowering {
namespace passes {
-
void RewriteInputsWithParams(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params) {
auto input_size = g->inputs().size();
auto param_it = params.rbegin();
diff --git a/home/runner/work/TensorRT/TensorRT/core/conversion/converters/impl/select.cpp b/tmp/changes.txt
index 20a03f6..6c8f5dd 100644
--- a/home/runner/work/TensorRT/TensorRT/core/conversion/converters/impl/select.cpp
+++ b/tmp/changes.txt
@@ -484,7 +484,7 @@ auto select_registrations TORCHTRT_UNUSED =
auto layer = ctx->net->addScatter(*self, *index, *value_tensor, nvinfer1::ScatterMode::kELEMENT);
layer->setAxis(dim);
-
+
TORCHTRT_CHECK(layer, "Unable to create layer for aten::scatter.value");
layer->setName(util::node_info(n).c_str());
@@ -503,7 +503,7 @@ auto select_registrations TORCHTRT_UNUSED =
auto layer = ctx->net->addScatter(*self, *index, *src, nvinfer1::ScatterMode::kELEMENT);
layer->setAxis(dim);
-
+
TORCHTRT_CHECK(layer, "Unable to create layer for aten::scatter.src");
layer->setName(util::node_info(n).c_str());
diff --git a/home/runner/work/TensorRT/TensorRT/core/conversion/converters/impl/element_wise.cpp b/tmp/changes.txt
index 0ad4c12..32c7050 100644
--- a/home/runner/work/TensorRT/TensorRT/core/conversion/converters/impl/element_wise.cpp
+++ b/tmp/changes.txt
@@ -25,8 +25,6 @@ nvinfer1::ITensor* clamp_util(
return clamp_layer_out;
}
-
-
auto element_wise_registrations TORCHTRT_UNUSED =
RegisterNodeConversionPatterns()
.pattern(
ERROR: Some files do not conform to style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
Signed-off-by: Bo Wang <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to C++ style guidelines:
diff --git a/home/runner/work/TensorRT/TensorRT/core/lowering/passes/rewrite_inputs_with_params.cpp b/tmp/changes.txt
index a05a905..c06fcd6 100644
--- a/home/runner/work/TensorRT/TensorRT/core/lowering/passes/rewrite_inputs_with_params.cpp
+++ b/tmp/changes.txt
@@ -1,13 +1,11 @@
-#include "torch/csrc/jit/ir/constants.h"
#include "core/util/prelude.h"
-
+#include "torch/csrc/jit/ir/constants.h"
namespace torch_tensorrt {
namespace core {
namespace lowering {
namespace passes {
-
void RewriteInputsWithParams(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params) {
auto input_size = g->inputs().size();
auto param_it = params.rbegin();
ERROR: Some files do not conform to style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
Signed-off-by: Bo Wang <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to C++ style guidelines:
diff --git a/home/runner/work/TensorRT/TensorRT/tests/core/lowering/test_rewrite_inputs_with_params.cpp b/tmp/changes.txt
index 2f0341c..b0d8045 100644
--- a/home/runner/work/TensorRT/TensorRT/tests/core/lowering/test_rewrite_inputs_with_params.cpp
+++ b/tmp/changes.txt
@@ -25,7 +25,7 @@ TEST(LoweringPasses, RewriteInputsWithParamsCorrectly) {
torch::jit::IValue param0 = torch::jit::IValue(0);
std::vector<torch::jit::IValue> params{param0};
torch_tensorrt::core::lowering::passes::RewriteInputsWithParams(sg, params);
-
+
auto tg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(target_graph, &*tg);
ERROR: Some files do not conform to style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to C++ style guidelines:
diff --git a/home/runner/work/TensorRT/TensorRT/tests/core/lowering/test_rewrite_inputs_with_params.cpp b/tmp/changes.txt
index 2f0341c..b0d8045 100644
--- a/home/runner/work/TensorRT/TensorRT/tests/core/lowering/test_rewrite_inputs_with_params.cpp
+++ b/tmp/changes.txt
@@ -25,7 +25,7 @@ TEST(LoweringPasses, RewriteInputsWithParamsCorrectly) {
torch::jit::IValue param0 = torch::jit::IValue(0);
std::vector<torch::jit::IValue> params{param0};
torch_tensorrt::core::lowering::passes::RewriteInputsWithParams(sg, params);
-
+
auto tg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(target_graph, &*tg);
ERROR: Some files do not conform to style guidelines
Signed-off-by: Bo Wang <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
This PR should be good to merge. @narendasan @peri044 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
Signed-off-by: Bo Wang [email protected]
Description
In some cases, some parameters are graphs' input. We can rewrite them into constants to make sure it works.
Fixes #1190
Type of change
Checklist: