Skip to content

Commit 7b71dd9

Browse files
committed
replace cumsum_plugin with native TRT
Signed-off-by: Ruoqian Guo <[email protected]>
1 parent 5167c47 commit 7b71dd9

File tree

5 files changed

+83
-396
lines changed

5 files changed

+83
-396
lines changed

core/conversion/converters/impl/cumsum.cpp

+36-19
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#include "core/conversion/tensorcontainer/TensorContainer.h"
44
#include "core/util/prelude.h"
55
#include "core/util/trt_util.h"
6-
#include "plugins/cumsum_plugin.h"
76
#include "torch/torch.h"
87

98
#include <ATen/ATen.h>
@@ -16,26 +15,10 @@ namespace converters {
1615
namespace impl {
1716
namespace {
1817

19-
void create_plugin(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, const char* name, int dim) {
20-
LOG_WARNING("Cumsum layer will be run through ATen, not TensorRT. Performance may be lower than expected");
21-
22-
auto creator = new plugins::CumsumPluginCreator();
23-
auto plugin = creator->createPlugin(name, dim);
24-
25-
auto cumsum_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin);
26-
TRTORCH_CHECK(cumsum_layer, "Unable to create cumsum plugin from node" << *n);
27-
28-
cumsum_layer->setName(util::node_info(n).c_str());
29-
30-
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], cumsum_layer->getOutput(0));
31-
32-
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
33-
}
34-
3518
auto cumsum_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern(
3619
{"aten::cumsum(Tensor self, int dim, *, int? dtype=None) -> (Tensor)",
3720
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
38-
auto in = args[0].ITensor();
21+
auto in = args[0].ITensorOrFreeze(ctx);
3922
auto input_dims = in->getDimensions();
4023
int dim = args[1].unwrapToInt();
4124
TRTORCH_CHECK(
@@ -45,7 +28,41 @@ auto cumsum_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().patt
4528
if (dim < 0) {
4629
dim += input_dims.nbDims;
4730
}
48-
create_plugin(ctx, n, in, "Cumsum", dim);
31+
32+
// Scan through each slice along the summation axis and add it to the running sum
33+
auto loop = ctx->net->addLoop();
34+
nvinfer1::ITensor* tripLimit = NULL;
35+
if (input_dims.d[dim] > 0) {
36+
torch::Tensor axis = torch::tensor(input_dims.d[dim], torch::kInt32);
37+
tripLimit = tensor_to_const(ctx, axis);
38+
} else {
39+
nvinfer1::ITensor* inpShape = ctx->net->addShape(*in)->getOutput(0);
40+
torch::Tensor dimValue = torch::tensor(dim, torch::kInt32);
41+
nvinfer1::ITensor* axis = tensor_to_const(ctx, dimValue);
42+
tripLimit = ctx->net->addGather(*inpShape, *axis, 0)->getOutput(0);
43+
}
44+
45+
loop->addTripLimit(*tripLimit, nvinfer1::TripLimit::kCOUNT);
46+
47+
auto iterator = loop->addIterator(*in, dim, false);
48+
auto data = iterator->getOutput(0);
49+
auto newDims = data->getDimensions();
50+
51+
torch::Tensor zeroValue = at::full(util::toVec(newDims), 0, torch::kFloat32);
52+
auto zeroTensor = tensor_to_const(ctx, zeroValue);
53+
auto runningSum = loop->addRecurrence(*zeroTensor);
54+
auto runningSumTensor = runningSum->getOutput(0);
55+
56+
auto curSum = ctx->net->addElementWise(*data, *runningSumTensor, nvinfer1::ElementWiseOperation::kSUM);
57+
runningSum->setInput(1, *curSum->getOutput(0));
58+
59+
nvinfer1::ILoopOutputLayer* loopOut =
60+
loop->addLoopOutput(*curSum->getOutput(0), nvinfer1::LoopOutput::kCONCATENATE, dim);
61+
loopOut->setInput(1, *tripLimit);
62+
63+
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], loopOut->getOutput(0));
64+
65+
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
4966
return true;
5067
}});
5168

core/conversion/converters/impl/plugins/BUILD

+3-5
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,10 @@ config_setting(
1010
cc_library(
1111
name = "plugins",
1212
hdrs = [
13-
"interpolate_plugin.h",
14-
"cumsum_plugin.h"
13+
"interpolate_plugin.h"
1514
],
1615
srcs = [
17-
"interpolate_plugin.cpp",
18-
"cumsum_plugin.cpp"
16+
"interpolate_plugin.cpp"
1917
],
2018
deps = [
2119
"@tensorrt//:nvinfer",
@@ -39,5 +37,5 @@ load("@rules_pkg//:pkg.bzl", "pkg_tar")
3937
pkg_tar(
4038
name = "include",
4139
package_dir = "core/conversion/converters/impl/plugins",
42-
srcs = ["interpolate_plugin.h", "cumsum_plugin.h"],
40+
srcs = ["interpolate_plugin.h"],
4341
)

core/conversion/converters/impl/plugins/cumsum_plugin.cpp

-237
This file was deleted.

0 commit comments

Comments
 (0)