Skip to content

Commit 7be368f

Browse files
authored
Merge pull request #372 from guoruoqian/cumsum
support cumsum converter
2 parents c06db43 + 2432fb8 commit 7be368f

File tree

5 files changed

+175
-1
lines changed

5 files changed

+175
-1
lines changed

Diff for: core/conversion/converters/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ cc_library(
3636
"impl/concat.cpp",
3737
"impl/constant.cpp",
3838
"impl/conv_deconv.cpp",
39+
"impl/cumsum.cpp",
3940
"impl/element_wise.cpp",
4041
"impl/expand.cpp",
4142
"impl/interpolate.cpp",

Diff for: core/conversion/converters/impl/cumsum.cpp

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
#include "NvInfer.h"
#include "core/conversion/converters/converters.h"
#include "core/conversion/tensorcontainer/TensorContainer.h"
#include "core/util/prelude.h"
#include "core/util/trt_util.h"
#include "torch/torch.h"

#include <ATen/ATen.h>
#include <vector>

namespace trtorch {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace {

// Converts aten::cumsum(Tensor self, int dim, *, int? dtype=None) into a
// TensorRT ILoop: iterate over the slices of `self` along `dim`, keep a
// recurrent running elementwise sum, and concatenate the partial sums back
// along `dim` to form the output.
auto cumsum_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern(
    {"aten::cumsum(Tensor self, int dim, *, int? dtype=None) -> (Tensor)",
     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
       auto in = args[0].ITensorOrFreeze(ctx);
       auto input_dims = in->getDimensions();
       int dim = args[1].unwrapToInt();
       // Accept dims in [-nbDims, nbDims - 1], mirroring PyTorch's convention,
       // then normalize negative dims to their positive form.
       TRTORCH_CHECK(
           (dim >= 0 && dim < input_dims.nbDims) || (dim < 0 && (input_dims.nbDims + dim >= 0)),
           "Dimension out of range (expected to be in range of [" << -input_dims.nbDims << ", " << input_dims.nbDims - 1
                                                                  << "], but got " << dim << ")");
       if (dim < 0) {
         dim += input_dims.nbDims;
       }

       // The optional dtype argument is not implemented by this converter;
       // warn instead of silently diverging from TorchScript behavior (the
       // output keeps the input tensor's type).
       if (args[2].isIValue() && !args[2].IValue()->isNone()) {
         LOG_WARNING("aten::cumsum converter ignores the dtype argument; output will have the input tensor's type");
       }

       // Scan through each slice across the summation axis and add it to the
       // running sum.
       auto loop = ctx->net->addLoop();
       nvinfer1::ITensor* tripLimit = nullptr;
       if (input_dims.d[dim] > 0) {
         // Static extent: bake the trip count in as a constant.
         torch::Tensor axis = torch::tensor(input_dims.d[dim], torch::kInt32);
         tripLimit = tensor_to_const(ctx, axis);
       } else {
         // Dynamic extent: gather the runtime size of `dim` out of the
         // input's shape tensor.
         nvinfer1::ITensor* inpShape = ctx->net->addShape(*in)->getOutput(0);
         torch::Tensor dimValue = torch::tensor(dim, torch::kInt32);
         nvinfer1::ITensor* axis = tensor_to_const(ctx, dimValue);
         tripLimit = ctx->net->addGather(*inpShape, *axis, 0)->getOutput(0);
       }

       loop->addTripLimit(*tripLimit, nvinfer1::TripLimit::kCOUNT);

       // The iterator yields one slice of `in` along `dim` per loop iteration.
       auto iterator = loop->addIterator(*in, dim, false);
       auto data = iterator->getOutput(0);
       auto newDims = data->getDimensions();

       // Zero-initialized recurrence accumulator shaped like a single slice.
       // NOTE(review): hard-coded to kFloat32 — for integer or half inputs the
       // elementwise kSUM below mixes element types; confirm behavior with
       // non-float inputs and consider deriving the type from `in`.
       torch::Tensor zeroValue = at::full(util::toVec(newDims), 0, torch::kFloat32);
       auto zeroTensor = tensor_to_const(ctx, zeroValue);
       auto runningSum = loop->addRecurrence(*zeroTensor);
       auto runningSumTensor = runningSum->getOutput(0);

       // current partial sum = this slice + previous partial sum; feed it back
       // into the recurrence for the next iteration.
       auto curSum = ctx->net->addElementWise(*data, *runningSumTensor, nvinfer1::ElementWiseOperation::kSUM);
       runningSum->setInput(1, *curSum->getOutput(0));

       // Re-assemble the per-iteration partial sums along `dim`.
       nvinfer1::ILoopOutputLayer* loopOut =
           loop->addLoopOutput(*curSum->getOutput(0), nvinfer1::LoopOutput::kCONCATENATE, dim);
       loopOut->setInput(1, *tripLimit);

       auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], loopOut->getOutput(0));

       LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
       return true;
     }});

} // namespace
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace trtorch

Diff for: tests/core/conversion/converters/BUILD

+6
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ converter_test(
2323
name = "test_conv_deconv",
2424
)
2525

26+
converter_test(
27+
name = "test_cumsum"
28+
)
29+
2630
converter_test(
2731
name = "test_element_wise",
2832
)
@@ -96,7 +100,9 @@ test_suite(
96100
tests = [
97101
":test_activation",
98102
":test_batch_norm",
103+
":test_concat",
99104
":test_conv_deconv",
105+
":test_cumsum",
100106
":test_element_wise",
101107
":test_expand",
102108
":test_interpolate",

Diff for: tests/core/conversion/converters/test_concat.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ TEST(Converters, ATenCatPureTensorConvertsCorrectly) {
3131
TEST(Converters, ATenCatDiffTensorConvertsCorrectly) {
3232
const auto graph = R"IR(
3333
graph(%0 : Tensor,
34-
%1 : Float(5:1)):
34+
%1 : Float(5)):
3535
%2 : Tensor[] = prim::ListConstruct(%0, %1)
3636
%3 : int = prim::Constant[value=0]()
3737
%4 : Tensor = aten::cat(%2, %3)

Diff for: tests/core/conversion/converters/test_cumsum.cpp

+93
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
#include <string>
#include "core/compiler.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"

namespace {

// Builds an aten::cumsum graph over `dim`, evaluates it with both TorchScript
// and the TRTorch-converted TensorRT engine (static- or dynamic-shape path),
// and asserts the two results agree. Shared by all four tests below, which
// previously duplicated this body verbatim and differed only in the dim
// constant and the engine-run call.
void RunCumsumTest(int dim, bool dynamic_input) {
  // Splice the dim constant into the IR so one driver covers every case.
  const std::string graph = R"IR(
      graph(%0 : Tensor):
        %1 : int = prim::Constant[value=)IR" +
      std::to_string(dim) + R"IR(]()
        %2 : None = prim::Constant()
        %3 : Tensor = aten::cumsum(%0, %1, %2)
        return (%3))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, &*g);

  auto in = at::randint(-5, 5, {2, 3, 4}, {at::kCUDA});
  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

  in = at::clone(in);
  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto trt_results = dynamic_input ? trtorch::tests::util::RunGraphEngineDynamic(g, params, {in})
                                   : trtorch::tests::util::RunGraphEngine(g, params, {in});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

} // namespace

TEST(Converters, ATenCumsumConvertsCorrectly) {
  RunCumsumTest(/*dim=*/1, /*dynamic_input=*/false);
}

TEST(Converters, ATenCumsumConvertsCorrectlyWithDynamicInput) {
  RunCumsumTest(/*dim=*/1, /*dynamic_input=*/true);
}

TEST(Converters, ATenCumsumNegativeDimConvertsCorrectly) {
  RunCumsumTest(/*dim=*/-2, /*dynamic_input=*/false);
}

TEST(Converters, ATenCumsumNegativeDimConvertsCorrectlyWithDynamicInput) {
  RunCumsumTest(/*dim=*/-2, /*dynamic_input=*/true);
}

0 commit comments

Comments
 (0)