3
3
#include " core/conversion/tensorcontainer/TensorContainer.h"
4
4
#include " core/util/prelude.h"
5
5
#include " core/util/trt_util.h"
6
- #include " plugins/cumsum_plugin.h"
7
6
#include " torch/torch.h"
8
7
9
8
#include < ATen/ATen.h>
@@ -16,26 +15,10 @@ namespace converters {
16
15
namespace impl {
17
16
namespace {
18
17
19
- void create_plugin (ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, const char * name, int dim) {
20
- LOG_WARNING (" Cumsum layer will be run through ATen, not TensorRT. Performance may be lower than expected" );
21
-
22
- auto creator = new plugins::CumsumPluginCreator ();
23
- auto plugin = creator->createPlugin (name, dim);
24
-
25
- auto cumsum_layer = ctx->net ->addPluginV2 (reinterpret_cast <nvinfer1::ITensor* const *>(&in), 1 , *plugin);
26
- TRTORCH_CHECK (cumsum_layer, " Unable to create cumsum plugin from node" << *n);
27
-
28
- cumsum_layer->setName (util::node_info (n).c_str ());
29
-
30
- auto layer_output = ctx->AssociateValueAndTensor (n->outputs ()[0 ], cumsum_layer->getOutput (0 ));
31
-
32
- LOG_DEBUG (" Output tensor shape: " << layer_output->getDimensions ());
33
- }
34
-
35
18
// Converter for aten::cumsum. TensorRT has no native cumulative-sum layer, so
// the op is lowered to an nvinfer1::ILoop: iterate over the slices of the input
// along `dim`, keep a running elementwise sum in a recurrence, and concatenate
// the per-iteration partial sums back along `dim` to form the output.
auto cumsum_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern(
    {"aten::cumsum(Tensor self, int dim, *, int? dtype=None) -> (Tensor)",
     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
       // ITensorOrFreeze accepts either a live ITensor or a weights constant,
       // freezing the latter into the network (plain ITensor() would reject weights).
       auto in = args[0].ITensorOrFreeze(ctx);
       auto input_dims = in->getDimensions();
       int dim = args[1].unwrapToInt();
       // NOTE(review): the arguments of this check are elided in the diff view;
       // reconstructed here from the upstream file — verify against the repo.
       // It validates that dim lies in [-nbDims, nbDims - 1].
       TRTORCH_CHECK(
           (dim >= 0 && dim < input_dims.nbDims) || (dim < 0 && (input_dims.nbDims + dim >= 0)),
           "Dimension out of range (expected to be in range of [" << -input_dims.nbDims << ", "
                                                                  << input_dims.nbDims - 1 << "], but got " << dim
                                                                  << ")");
       // Normalize a negative axis to its non-negative equivalent.
       if (dim < 0) {
         dim += input_dims.nbDims;
       }

       // Scan through each slice across summation axis and add it to the running sum
       auto loop = ctx->net->addLoop();
       nvinfer1::ITensor* tripLimit = NULL;
       if (input_dims.d[dim] > 0) {
         // Static extent along `dim`: the trip count is known at build time.
         torch::Tensor axis = torch::tensor(input_dims.d[dim], torch::kInt32);
         tripLimit = tensor_to_const(ctx, axis);
       } else {
         // Dynamic extent (dim size is -1): compute the trip count at runtime by
         // taking the input's shape tensor and gathering the entry for `dim`.
         nvinfer1::ITensor* inpShape = ctx->net->addShape(*in)->getOutput(0);
         torch::Tensor dimValue = torch::tensor(dim, torch::kInt32);
         nvinfer1::ITensor* axis = tensor_to_const(ctx, dimValue);
         tripLimit = ctx->net->addGather(*inpShape, *axis, 0)->getOutput(0);
       }

       loop->addTripLimit(*tripLimit, nvinfer1::TripLimit::kCOUNT);

       // The iterator yields one slice of `in` along `dim` per loop iteration
       // (third arg `false` = iterate in forward order).
       auto iterator = loop->addIterator(*in, dim, false);
       auto data = iterator->getOutput(0);
       auto newDims = data->getDimensions();

       // Seed the recurrence with a zero tensor shaped like a single slice.
       // NOTE(review): the zeros are hard-coded to float32, and the schema's
       // optional `dtype` argument is ignored — presumably this assumes a float
       // input; confirm behavior for non-float32 inputs.
       // NOTE(review): util::toVec(newDims) also assumes the slice shape is fully
       // static here; confirm for dynamic-shape inputs.
       torch::Tensor zeroValue = at::full(util::toVec(newDims), 0, torch::kFloat32);
       auto zeroTensor = tensor_to_const(ctx, zeroValue);
       auto runningSum = loop->addRecurrence(*zeroTensor);
       auto runningSumTensor = runningSum->getOutput(0);

       // runningSum_{i+1} = runningSum_i + slice_i; input 1 of the recurrence is
       // the value fed back into the next iteration.
       auto curSum = ctx->net->addElementWise(*data, *runningSumTensor, nvinfer1::ElementWiseOperation::kSUM);
       runningSum->setInput(1, *curSum->getOutput(0));

       // Re-assemble the partial sums along `dim`; input 1 of the loop output is
       // the concatenation length (the trip count).
       nvinfer1::ILoopOutputLayer* loopOut =
           loop->addLoopOutput(*curSum->getOutput(0), nvinfer1::LoopOutput::kCONCATENATE, dim);
       loopOut->setInput(1, *tripLimit);

       // Bind the loop's output tensor to the node's output value.
       auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], loopOut->getOutput(0));

       LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
       return true;
     }});
51
68
0 commit comments