Commit a801506

feat(//core/converters): Add power layer conversion support and minor README edits
Signed-off-by: Dheeraj Peri <[email protected]>
Parent: e4a4574

File tree: 4 files changed, +88 -19 lines

Diff for: README.md (+3 -3)

@@ -115,10 +115,10 @@ then you have two options.
 1. You need to download the tarball distributions of TensorRT and cuDNN from the NVIDIA website.
     - https://developer.nvidia.com/cudnn
     - https://developer.nvidia.com/tensorrt
-2. Place these files in a directory (the directories `third_party/distdir/[x86_64-linux-gnu | aarch64-linux-gnu]` exist for this purpose)
+2. Place these files in a directory (the directories `third_party/dist_dir/[x86_64-linux-gnu | aarch64-linux-gnu]` exist for this purpose)
 3. Compile using:
 ``` shell
-bazel build //:libtrtorch --compilation_mode opt --distdir third_party/distdir/[x86_64-linux-gnu | aarch64-linux-gnu]
+bazel build //:libtrtorch --compilation_mode opt --distdir third_party/dist_dir/[x86_64-linux-gnu | aarch64-linux-gnu]
 ```

 #### 2. Building using locally installed cuDNN & TensorRT
@@ -175,7 +175,7 @@ bazel build //:libtrtorch --compilation_mode=dbg

 ### Native compilation on NVIDIA Jetson AGX
 ``` shell
-bazel build //:libtrtorch --distdir third_party/distdir/aarch64-linux-gnu
+bazel build //:libtrtorch --distdir third_party/dist_dir/aarch64-linux-gnu
 ```
 > Note: Please refer [installation](docs/tutorials/installation.html) instructions for Pre-requisites

Diff for: core/conversion/converters/impl/element_wise.cpp (+45 -0)

@@ -1,3 +1,4 @@
+#include <torch/torch.h>
 #include "core/util/prelude.h"
 #include "core/conversion/converters/converters.h"

@@ -180,6 +181,50 @@ auto element_wise_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns(
       mul->setName(util::node_info(n).c_str());
       auto out = ctx->AssociateValueAndTensor(n->outputs()[0], mul->getOutput(0));

+      LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+      return true;
+    }
+  }).pattern({
+    "aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> (Tensor)",
+    [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+      // TODO: Remove with functionalization
+      auto self = args[0].ITensorOrFreeze(ctx);
+      auto exponent = args[1].ITensorOrFreeze(ctx);
+      auto pow = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPOW, self, exponent, util::node_info(n));
+      TRTORCH_CHECK(pow, "Unable to create Power layer from node: " << *n);
+
+      pow->setName(util::node_info(n).c_str());
+      auto out = ctx->AssociateValueAndTensor(n->outputs()[0], pow->getOutput(0));
+
+      LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+      return true;
+    }
+  }).pattern({
+    "aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> (Tensor)",
+    [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+      auto self = args[0].ITensorOrFreeze(ctx);
+      auto exponentScalar = args[1].unwrapToScalar().to<float>();
+
+      // Calculate size of the input and define an exponent tensor of the same size
+      int volume = 1;
+      for (int i = 0; i < self->getDimensions().nbDims; i++) {
+        volume = volume * (self->getDimensions().d[i]);
+      }
+
+      // Create a torch tensor with constant exponent values
+      LOG_DEBUG("Broadcasting the exponent in power layer");
+      torch::Tensor exponentBlob = torch::full({volume}, exponentScalar);
+
+      // Create a corresponding constant layer in TRT and get the layer output.
+      auto weights = converters::Weights(ctx, exponentBlob);
+      auto exponentTensor = ctx->net->addConstant(self->getDimensions(), weights.data)->getOutput(0);
+
+      auto pow = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPOW, self, exponentTensor, util::node_info(n));
+      TRTORCH_CHECK(pow, "Unable to create Power layer from node: " << *n);
+
+      pow->setName(util::node_info(n).c_str());
+      auto out = ctx->AssociateValueAndTensor(n->outputs()[0], pow->getOutput(0));
+
       LOG_DEBUG("Output tensor shape: " << out->getDimensions());
       return true;
     }
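For orientation, here is a minimal standalone sketch (not the converter code above, which goes through TRTorch's `ConversionCtx` and `converters::Weights` helpers) of the raw TensorRT calls the `aten::pow.Tensor_Scalar` pattern boils down to: replicate the scalar exponent into a constant tensor shaped like the input, then apply an element-wise kPOW. The network pointer, input tensor, caller-owned weight storage, and the assumption of a fully static input shape are all supplied by the surrounding code here.

``` cpp
#include <NvInfer.h>
#include <vector>

// Sketch: raise `self` to a constant scalar power inside a TensorRT network.
// `blob` must outlive engine construction; TensorRT does not copy weight data.
// Assumes `self` has a fully static shape (no -1 dimensions).
nvinfer1::ITensor* pow_with_scalar(nvinfer1::INetworkDefinition* net,
                                   nvinfer1::ITensor* self,
                                   std::vector<float>& blob,
                                   float exponent) {
  auto dims = self->getDimensions();
  // Total element count of the input; the exponent is replicated to match.
  int64_t volume = 1;
  for (int i = 0; i < dims.nbDims; i++) {
    volume *= dims.d[i];
  }
  blob.assign(volume, exponent);
  nvinfer1::Weights w{nvinfer1::DataType::kFLOAT, blob.data(), volume};
  // Constant layer holding the broadcast exponent, shaped like the input.
  auto* exp_tensor = net->addConstant(dims, w)->getOutput(0);
  // Element-wise power: out[i] = self[i] ^ exponent.
  auto* pow = net->addElementWise(*self, *exp_tensor,
                                  nvinfer1::ElementWiseOperation::kPOW);
  return pow ? pow->getOutput(0) : nullptr;
}
```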

Diff for: py/README.md (+1 -1)

@@ -22,7 +22,7 @@ traced_model = torch.jit.trace(model, [data])

 # Compile module
 compiled_trt_model = trtorch.compile(model, {
-    "input_shape": [data.shape],
+    "input_shapes": [data.shape],
     "op_precision": torch.half, # Run in FP16
 })

Diff for: tests/core/converters/test_element_wise.cpp (+39 -15)

@@ -4,32 +4,39 @@
 #include "tests/util/util.h"
 #include "core/compiler.h"

-void pointwise_test_helper(std::string graph_ir) {
+void pointwise_test_helper(std::string graph_ir, bool singleInput) {
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph_ir, &*g);
-
-  auto in0 = at::randint(1, 5, {5}, {at::kCUDA});
-  auto in1 = at::randint(1, 5, {5}, {at::kCUDA});
+
+  // singleInput case is enabled when elementwise operation is performed
+  // with an input and a constant embedded in graph
+  std::vector<at::Tensor> torch_inputs;
+  torch_inputs.push_back(at::randint(1, 5, {5}, {at::kCUDA}));
+  if (!singleInput) {
+    torch_inputs.push_back(at::randint(1, 5, {5}, {at::kCUDA}));
+  }
   auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
-  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in0, in1});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, torch_inputs);
+
+  std::vector<at::Tensor> trt_inputs;
+  for (auto in : torch_inputs) {
+    trt_inputs.push_back(at::clone(in));
+  }

-  in0 = at::clone(in0);
-  in1 = at::clone(in1);
   params = trtorch::core::conversion::get_named_params(g->inputs(), {});
-  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in0, in1});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, trt_inputs);

   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }


-
 TEST(Converters, ATenAddConvertsCorrectly) {
   const auto graph = R"IR(
     graph(%0 : Tensor, %1 : Tensor):
       %2 : int = prim::Constant[value=1]()
       %3 : Tensor = aten::add(%0, %1, %2)
       return (%3))IR";
-  pointwise_test_helper(graph);
+  pointwise_test_helper(graph, false);
 }

@@ -39,7 +46,7 @@ TEST(Converters, ATenAddConvertsCorrectly) {
 //       %2 : int = prim::Constant[value=2]()
 //       %3 : Tensor = aten::add(%0, %1, %2)
 //       return (%3))IR";
-//   pointwise_test_helper(graph);
+//   pointwise_test_helper(graph, false);
 // }

 TEST(Converters, ATenSubConvertsCorrectly) {
@@ -48,7 +55,7 @@ TEST(Converters, ATenSubConvertsCorrectly) {
       %2 : int = prim::Constant[value=1]()
       %3 : Tensor = aten::sub(%0, %1, %2)
       return (%3))IR";
-  pointwise_test_helper(graph);
+  pointwise_test_helper(graph, false);
 }

 // TEST(Converters, ATenSubWithScaleConvertsCorrectly) {
@@ -57,21 +64,38 @@ TEST(Converters, ATenSubConvertsCorrectly) {
 //       %2 : float = prim::Constant[value=0.5]()
 //       %3 : Tensor = aten::add(%0, %1, %2)
 //       return (%3))IR";
-//   pointwise_test_helper(graph);
+//   pointwise_test_helper(graph, false);
 // }

 TEST(Converters, ATenMulConvertsCorrectly) {
   const auto graph = R"IR(
     graph(%0 : Tensor, %1 : Tensor):
       %2 : Tensor = aten::mul(%0, %1)
       return (%2))IR";
-  pointwise_test_helper(graph);
+  pointwise_test_helper(graph, false);
 }

 TEST(Converters, ATenDivConvertsCorrectly) {
   const auto graph = R"IR(
     graph(%0 : Tensor, %1 : Tensor):
       %2 : Tensor = aten::div(%0, %1)
       return (%2))IR";
-  pointwise_test_helper(graph);
+  pointwise_test_helper(graph, false);
+}
+
+TEST(Converters, ATenPowTensorConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor, %x2.1 : Tensor):
+      %3 : Tensor = aten::pow(%x.1, %x2.1)
+      return (%3))IR";
+  pointwise_test_helper(graph, false);
+}
+
+TEST(Converters, ATenPowScalarConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+      %2 : int = prim::Constant[value=2]()
+      %3 : Tensor = aten::pow(%x.1, %2)
+      return (%3))IR";
+  pointwise_test_helper(graph, true);
 }
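The two modes of the updated helper, side by side (illustrative only; `two_input_ir` and `scalar_exp_ir` are hypothetical stand-ins for the `R"IR(...)"` literals in the tests above): two-input graphs such as `aten::pow.Tensor_Tensor` get a second random CUDA input, while single-input graphs such as `aten::pow.Tensor_Scalar`, whose exponent is a constant baked into the IR, run against one input. In both modes the inputs are cloned before the TensorRT run so the JIT and TRT paths see identical data.

``` cpp
// Hypothetical usage sketch of the helper's two call modes.
pointwise_test_helper(two_input_ir,  /*singleInput=*/false); // %0 op %1
pointwise_test_helper(scalar_exp_ir, /*singleInput=*/true);  // %0 op constant
```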
