Skip to content

Commit a3e1093

Browse files
abhi-iyernarendasan
authored and committed
fix(): cleaned up logic, added case where bias doesn't exist for LSTM cell converter
Signed-off-by: Abhiram Iyer <[email protected]> Signed-off-by: Abhiram Iyer <[email protected]>
1 parent a88cfaf commit a3e1093

File tree

1 file changed

+44
-50
lines changed

1 file changed

+44
-50
lines changed

Diff for: core/conversion/converters/impl/lstm_cell.cpp

+44-50
Original file line number | Diff line number | Diff line change
@@ -14,22 +14,41 @@ namespace converters {
1414
namespace impl {
1515
namespace {
1616

17+
// Adds bias tensor `b` to tensor `a` with an element-wise SUM layer and returns
// output 0 of that layer.
//
// Parameters:
//   a      - left-hand operand of the sum (the callers in this file pass a
//            matrix-multiply output).
//   b      - bias tensor; replaced by a reshaped view when its dimensions
//            differ from `a`'s.
//   b_name - label for `b`, used only in the debug and error messages.
//   ctx    - conversion context whose `net` receives the new layers.
//   n      - node being converted; used only in error messages.
//
// Steps:
//   1. Logs `b`'s shape and checks util::broadcastable(a_dim, b_dim, false);
//      TRTORCH_CHECK reports the failure if the check does not hold.
//   2. If the two dimension vectors differ, inserts a shuffle layer that
//      reshapes `b` to util::toDimsPad(b_dim, a_dim.nbDims).
//      NOTE(review): presumably this pads `b` with unit dimensions up to `a`'s
//      rank so the element-wise layer accepts it - confirm against
//      util::toDimsPad.
//   3. Adds the element-wise SUM layer over `a` and the (possibly reshaped) `b`.
//
// Both layer creations are guarded by TRTORCH_CHECK on the returned pointer.
nvinfer1::ITensor* add_bias(nvinfer1::ITensor* a, nvinfer1::ITensor* b, std::string b_name, ConversionCtx* ctx, const torch::jit::Node* n) {
18+
auto a_dim = a->getDimensions();
19+
auto b_dim = b->getDimensions();
20+
21+
LOG_DEBUG(b_name << " tensor shape: " << b_dim);
22+
23+
TRTORCH_CHECK(util::broadcastable(a_dim, b_dim, false), "bias " << b_name << " is not broadcastable - can't be added to previous matmul operation.");
24+
25+
if (util::toVec(a_dim) != util::toVec(b_dim)) {
26+
LOG_DEBUG(b_name << "'s dimensions need to be reshaped");
27+
28+
auto shuffle = ctx->net->addShuffle(*b);
29+
TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
30+
shuffle->setReshapeDimensions(util::toDimsPad(util::toVec(b_dim), a_dim.nbDims));
31+
b = shuffle->getOutput(0);
32+
}
33+
34+
auto add = ctx->net->addElementWise(*a, *b, nvinfer1::ElementWiseOperation::kSUM);
35+
TRTORCH_CHECK(add, "Unable to create ElementWise layer from node: " << *n);
36+
37+
return add->getOutput(0);
38+
}
39+
1740
auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
1841
.pattern({
1942
"aten::lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor, Tensor)",
2043
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
2144
auto input = args[0].ITensorOrFreeze(ctx);
2245
auto w_ih = args[2].ITensorOrFreeze(ctx);
2346
auto w_hh = args[3].ITensorOrFreeze(ctx);
24-
auto b_ih = args[4].ITensorOrFreeze(ctx);
25-
auto b_hh = args[5].ITensorOrFreeze(ctx);
2647

2748
LOG_DEBUG("Input tensor shape: " << input->getDimensions());
2849
LOG_DEBUG("w_ih tensor shape: " << w_ih->getDimensions());
2950
LOG_DEBUG("w_hh tensor shape: " << w_hh->getDimensions());
30-
LOG_DEBUG("b_ih tensor shape: " << b_ih->getDimensions());
31-
LOG_DEBUG("b_hh tensor shape: " << b_hh->getDimensions());
32-
51+
3352
std::vector<nvinfer1::ITensor*> state;
3453
auto hx = args[1].IValue()->toListRef();
3554
for (unsigned int i = 0; i < hx.size(); i++) {
@@ -51,81 +70,56 @@ auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
5170
// calculate first half of gates
5271
auto mm1 = ctx->net->addMatrixMultiply(*input, nvinfer1::MatrixOperation::kNONE, *w_ih, nvinfer1::MatrixOperation::kTRANSPOSE);
5372
TRTORCH_CHECK(mm1, "Unable to create matrix multiplication node: " << *n);
54-
5573
auto mm1_out = mm1->getOutput(0);
56-
auto mm1_dim = mm1_out->getDimensions();
57-
auto b_ih_dim = b_ih->getDimensions();
58-
59-
TRTORCH_CHECK(util::broadcastable(mm1_dim, b_ih_dim, false));
6074

61-
if (util::toVec(mm1_dim) != util::toVec(b_ih_dim)) {
62-
LOG_DEBUG("b_ih dimensions need to be reshaped");
63-
64-
auto shuffle = ctx->net->addShuffle(*b_ih);
65-
TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
66-
shuffle->setReshapeDimensions(util::toDimsPad(util::toVec(b_ih_dim), mm1_dim.nbDims));
67-
b_ih = shuffle->getOutput(0);
68-
}
69-
70-
auto add1 = ctx->net->addElementWise(*mm1_out, *b_ih, nvinfer1::ElementWiseOperation::kSUM);
71-
TRTORCH_CHECK(add1, "Unable to create ElementWise layer from node: " << *n);
72-
auto add1_out = add2->getOutput(0);
75+
auto out1 = !args[4].IValue()->isNone() ? add_bias(mm1_out, args[4].ITensorOrFreeze(ctx), "b_ih", ctx, n) : mm1_out;
7376

7477
// calculate second half of gates
75-
auto mm2 = ctx->net->addMatrixMultiply(*state[0], nvinfer1::MatrixOperation::kNONE, *w_hh, nvinfer1::MatrixOperation::kTRANSPOE);
78+
auto mm2 = ctx->net->addMatrixMultiply(*state[0], nvinfer1::MatrixOperation::kNONE, *w_hh, nvinfer1::MatrixOperation::kTRANSPOSE);
7679
TRTORCH_CHECK(mm2, "Unable to create matrix multiplication node: " << *n);
77-
7880
auto mm2_out = mm2->getOutput(0);
79-
auto mm2_dim = mm2_out->getDimensions();
80-
auto b_hh_dim = b_hh->getDimensions();
81-
82-
TRTORCH_CHECK(util::broadcastable(mm2_dim, b_hh_dim, false));
8381

84-
if (util::toVec(mm2_dim) != util::toVec(b_hh_dim)) {
85-
LOG_DEBUG("b_hh dimensions need to be reshaped");
86-
87-
auto shuffle = ctx->net->addShuffle(*b_hh);
88-
TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
89-
shuffle->setReshapeDimensions(util::toDimsPad(util::toVec(b_hh_dim), mm2_dim.nbDims));
90-
b_hh = shuffle->getOutput(0);
91-
}
92-
93-
auto add2 = ctx->net->addElementWise(*mm2_out, *b_ih, nvinfer1::ElementWiseOperation::kSUM);
94-
TRTORCH_CHECK(add2, "Unable to create ElementWise layer from node: " << *n);
95-
auto add2_out = add2->getOutput(0);
82+
auto out2 = !args[5].IValue()->isNone() ? add_bias(mm2_out, args[5].ITensorOrFreeze(ctx), "b_hh", ctx, n) : mm2_out;
9683

9784
// gates
98-
auto add3 = ctx->net->addElementWise(*add1_out, *add2_out, nvinfer1::ElementWiseOperation::kSUM);
99-
TRTORCH_CHECK(add3, "Unable to create ElementWise layer from node: " << *n);
100-
auto add3_out = add3->getOutput(0);
85+
auto add = ctx->net->addElementWise(*out1, *out2, nvinfer1::ElementWiseOperation::kSUM);
86+
TRTORCH_CHECK(add, "Unable to create ElementWise layer from node: " << *n);
87+
auto add_out = add->getOutput(0);
10188

10289
// chunk Tensor into 4 parts and apply activation functions
103-
auto dims = util::toVec(add3_out->getDimensions());
90+
auto dims = util::toVec(add_out->getDimensions());
10491
auto batch = dims[0];
10592
auto hidden = dims[1]/4;
10693

107-
auto size = util::toDims(std::vector<int64_t>({batch, hidden}));
108-
auto stride = util::toDims(std::vector<int64_t>({1, 1}));
94+
std::vector<int64_t> size_vec = {batch, hidden};
95+
std::vector<int64_t> stride_vec = {1, 1};
96+
std::vector<int64_t> offset0 = {0, 0};
97+
std::vector<int64_t> offset1 = {0, hidden};
98+
std::vector<int64_t> offset2 = {0, 2*hidden};
99+
std::vector<int64_t> offset3 = {0, 3*hidden};
100+
101+
auto size = util::toDims(size_vec);
102+
auto stride = util::toDims(stride_vec);
109103

110-
auto slice1 = ctx->net->addSlice(*add3_out, util::toDims(std::vector<int64_t>({0, 0})), size, stride);
104+
auto slice1 = ctx->net->addSlice(*add_out, util::toDims(offset0), size, stride);
111105
TRTORCH_CHECK(slice1, "Unable to create Slice layer from node: " << *n);
112106
auto activ1 = ctx->net->addActivation(*slice1->getOutput(0), nvinfer1::ActivationType::kSIGMOID);
113107
TRTORCH_CHECK(activ1, "Unable to create sigmoid activation layer from node: " << *n);
114108
auto ingate = activ1->getOutput(0);
115109

116-
auto slice2 = ctx->net->addSlice(*add3_out, util::toDims(std::vector<int64_t>({0, hidden})), size, stride);
110+
auto slice2 = ctx->net->addSlice(*add_out, util::toDims(offset1), size, stride);
117111
TRTORCH_CHECK(slice2, "Unable to create Slice layer from node: " << *n);
118112
auto activ2 = ctx->net->addActivation(*slice2->getOutput(0), nvinfer1::ActivationType::kSIGMOID);
119113
TRTORCH_CHECK(activ2, "Unable to create sigmoid activation layer from node: " << *n);
120114
auto forgetgate = activ2->getOutput(0);
121115

122-
auto slice3 = ctx->net->addSlice(*add3_out, util::toDims(std::vector<int64_t>({0, 2*hidden})), size, stride);
116+
auto slice3 = ctx->net->addSlice(*add_out, util::toDims(offset2), size, stride);
123117
TRTORCH_CHECK(slice3, "Unable to create Slice layer from node: " << *n);
124118
auto activ3 = ctx->net->addActivation(*slice3->getOutput(0), nvinfer1::ActivationType::kTANH);
125119
TRTORCH_CHECK(activ3, "Unable to create tanh activation layer from node: " << *n);
126120
auto cellgate = activ3->getOutput(0);
127121

128-
auto slice4 = ctx->net->addSlice(*add3_out, util::toDims(std::vector<int64_t>({0, 3*hidden})), size, stride);
122+
auto slice4 = ctx->net->addSlice(*add_out, util::toDims(offset3), size, stride);
129123
TRTORCH_CHECK(slice4, "Unable to create Slice layer from node: " << *n);
130124
auto activ4 = ctx->net->addActivation(*slice4->getOutput(0), nvinfer1::ActivationType::kSIGMOID);
131125
TRTORCH_CHECK(activ4, "Unable to create sigmoid activation layer from node: " << *n);

0 commit comments

Comments
 (0)