Skip to content

Commit 546d790

Browse files
abhi-iyernarendasan
authored andcommitted
feat(): started working on lstm_cell converter
Signed-off-by: Abhiram Iyer <[email protected]> Signed-off-by: Abhiram Iyer <[email protected]>
1 parent 5a105c6 commit 546d790

File tree

2 files changed

+122
-1
lines changed

2 files changed

+122
-1
lines changed

Diff for: core/conversion/converters/BUILD

+2-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ cc_library(
4848
"impl/unary.cpp",
4949
"impl/interpolate.cpp",
5050
"impl/select.cpp",
51-
"impl/stack.cpp"
51+
"impl/stack.cpp",
52+
"impl/lstm_cell.cpp"
5253
],
5354
deps = [
5455
"@tensorrt//:nvinfer",

Diff for: core/conversion/converters/impl/lstm_cell.cpp

+120
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
#include "torch/torch.h"
2+
#include "NvInfer.h"
3+
#include "core/util/prelude.h"
4+
#include "core/conversion/converters/converters.h"
5+
#include "core/conversion/tensorcontainer/TensorContainer.h"
6+
7+
#include <ATen/ATen.h>
8+
#include <vector>
9+
10+
namespace trtorch {
11+
namespace core {
12+
namespace conversion {
13+
namespace converters {
14+
namespace impl {
15+
namespace {
16+
17+
auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
18+
.pattern({
19+
"aten::lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor, Tensor)",
20+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
21+
auto input = args[0].ITensorOrFreeze(ctx);
22+
auto w_ih = args[2].ITensorOrFreeze(ctx);
23+
auto w_hh = args[3].ITensorOrFreeze(ctx);
24+
auto b_ih = args[4].ITensorOrFreeze(ctx);
25+
auto b_hh = args[5].ITensorOrFreeze(ctx);
26+
27+
LOG_DEBUG("Input tensor shape: " << input->getDimensions());
28+
LOG_DEBUG("w_ih tensor shape: " << w_ih->getDimensions());
29+
LOG_DEBUG("w_hh tensor shape: " << w_hh->getDimensions());
30+
LOG_DEBUG("b_ih tensor shape: " << b_ih->getDimensions());
31+
LOG_DEBUG("b_hh tensor shape: " << b_hh->getDimensions());
32+
33+
std::vector<nvinfer1::ITensor*> state;
34+
auto hx = args[1].IValue()->toListRef();
35+
for (unsigned int i = 0; i < hx.size(); i++) {
36+
auto t = hx[i];
37+
38+
nvinfer1::ITensor* itensor;
39+
40+
if (t.isTensor()) {
41+
itensor = tensor_to_const(ctx, t.toTensor());
42+
} else {
43+
auto cont = t.toCustomClass<TensorContainer>();
44+
itensor = cont->tensor();
45+
}
46+
47+
LOG_DEBUG("State tensor " << i << " shape: " << itensor->getDimensions());
48+
state.push_back(itensor);
49+
}
50+
51+
// calculate first half of gates
52+
auto mm1 = ctx->net->addMatrixMultiply(*input, nvinfer1::MatrixOperation::kNONE, *w_ih, nvinfer1::MatrixOperation::kTRANSPOSE);
53+
TRTORCH_CHECK(mm1, "Unable to create matrix multiplication node: " << *n);
54+
55+
auto mm1_out = mm1->getOutput(0);
56+
auto mm1_dim = mm1_out->getDimensions();
57+
auto b_ih_dim = b_ih->getDimensions();
58+
59+
TRTORCH_CHECK(util::broadcastable(mm1_dim, b_ih_dim, false));
60+
61+
if (util::toVec(mm1_dim) != util::toVec(b_ih_dim)) {
62+
LOG_DEBUG("b_ih dimensions need to be reshaped");
63+
64+
auto shuffle = ctx->net->addShuffle(*b_ih);
65+
TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
66+
shuffle->setReshapeDimensions(util::toDimsPad(util::toVec(b_ih_dim), mm1_dim.nbDims));
67+
b_ih = shuffle->getOutput(0);
68+
}
69+
70+
auto add1 = ctx->net->addElementWise(*mm1_out, *b_ih, nvinfer1::ElementWiseOperation::kSUM);
71+
TRTORCH_CHECK(add1, "Unable to create ElementWise layer from node: " << *n);
72+
auto add1_out = add2->getOutput(0);
73+
74+
// calculate second half of gates
75+
auto mm2 = ctx->net->addMatrixMultiply(*state[0], nvinfer1::MatrixOperation::kNONE, *w_hh, nvinfer1::MatrixOperation::kTRANSPOE);
76+
TRTORCH_CHECK(mm2, "Unable to create matrix multiplication node: " << *n);
77+
78+
auto mm2_out = mm2->getOutput(0);
79+
auto mm2_dim = mm2_out->getDimensions();
80+
auto b_hh_dim = b_hh->getDimensions();
81+
82+
TRTORCH_CHECK(util::broadcastable(mm2_dim, b_hh_dim, false));
83+
84+
if (util::toVec(mm2_dim) != util::toVec(b_hh_dim)) {
85+
LOG_DEBUG("b_hh dimensions need to be reshaped");
86+
87+
auto shuffle = ctx->net->addShuffle(*b_hh);
88+
TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
89+
shuffle->setReshapeDimensions(util::toDimsPad(util::toVec(b_hh_dim), mm2_dim.nbDims));
90+
b_hh = shuffle->getOutput(0);
91+
}
92+
93+
auto add2 = ctx->net->addElementWise(*mm2_out, *b_ih, nvinfer1::ElementWiseOperation::kSUM);
94+
TRTORCH_CHECK(add2, "Unable to create ElementWise layer from node: " << *n);
95+
auto add2_out = add2->getOutput(0);
96+
97+
// gates
98+
auto add3 = ctx->net->addElementWise(*add1_out, *add2_out, nvinfer1::ElementWiseOperation::kSUM);
99+
TRTORCH_CHECK(add3, "Unable to create ElementWise layer from node: " << *n);
100+
auto add3_out = add3->getOutput(0);
101+
102+
103+
104+
105+
106+
auto mm_layer = ctx->net->addMatrixMultiply(*self, nvinfer1::MatrixOperation::kNONE, *other, nvinfer1::MatrixOperation::kNONE);
107+
TRTORCH_CHECK(mm_layer, "Unable to create matrix multiplication node: " << *n);
108+
mm_layer->setName(util::node_info(n).c_str());
109+
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));
110+
111+
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
112+
return true;
113+
}
114+
});
115+
} // namespace
116+
} // namespace impl
117+
} // namespace converters
118+
} // namespace conversion
119+
} // namespace core
120+
} // namespace trtorch

0 commit comments

Comments
 (0)