Skip to content

Commit a88cfaf

Browse files
abhi-iyernarendasan
authored andcommitted
feat(): finished logic for LSTM cell, now to test
Signed-off-by: Abhiram Iyer <[email protected]> Signed-off-by: Abhiram Iyer <[email protected]>
1 parent 546d790 commit a88cfaf

File tree

1 file changed

+50
-9
lines changed

1 file changed

+50
-9
lines changed

Diff for: core/conversion/converters/impl/lstm_cell.cpp

+50-9
Original file line numberDiff line numberDiff line change
@@ -99,16 +99,57 @@ auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
9999
TRTORCH_CHECK(add3, "Unable to create ElementWise layer from node: " << *n);
100100
auto add3_out = add3->getOutput(0);
101101

102-
103-
104-
105-
106-
auto mm_layer = ctx->net->addMatrixMultiply(*self, nvinfer1::MatrixOperation::kNONE, *other, nvinfer1::MatrixOperation::kNONE);
107-
TRTORCH_CHECK(mm_layer, "Unable to create matrix multiplication node: " << *n);
108-
mm_layer->setName(util::node_info(n).c_str());
109-
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));
102+
// chunk Tensor into 4 parts and apply activation functions
103+
auto dims = util::toVec(add3_out->getDimensions());
104+
auto batch = dims[0];
105+
auto hidden = dims[1]/4;
106+
107+
auto size = util::toDims(std::vector<int64_t>({batch, hidden}));
108+
auto stride = util::toDims(std::vector<int64_t>({1, 1}));
109+
110+
auto slice1 = ctx->net->addSlice(*add3_out, util::toDims(std::vector<int64_t>({0, 0})), size, stride);
111+
TRTORCH_CHECK(slice1, "Unable to create Slice layer from node: " << *n);
112+
auto activ1 = ctx->net->addActivation(*slice1->getOutput(0), nvinfer1::ActivationType::kSIGMOID);
113+
TRTORCH_CHECK(activ1, "Unable to create sigmoid activation layer from node: " << *n);
114+
auto ingate = activ1->getOutput(0);
115+
116+
auto slice2 = ctx->net->addSlice(*add3_out, util::toDims(std::vector<int64_t>({0, hidden})), size, stride);
117+
TRTORCH_CHECK(slice2, "Unable to create Slice layer from node: " << *n);
118+
auto activ2 = ctx->net->addActivation(*slice2->getOutput(0), nvinfer1::ActivationType::kSIGMOID);
119+
TRTORCH_CHECK(activ2, "Unable to create sigmoid activation layer from node: " << *n);
120+
auto forgetgate = activ2->getOutput(0);
121+
122+
auto slice3 = ctx->net->addSlice(*add3_out, util::toDims(std::vector<int64_t>({0, 2*hidden})), size, stride);
123+
TRTORCH_CHECK(slice3, "Unable to create Slice layer from node: " << *n);
124+
auto activ3 = ctx->net->addActivation(*slice3->getOutput(0), nvinfer1::ActivationType::kTANH);
125+
TRTORCH_CHECK(activ3, "Unable to create tanh activation layer from node: " << *n);
126+
auto cellgate = activ3->getOutput(0);
127+
128+
auto slice4 = ctx->net->addSlice(*add3_out, util::toDims(std::vector<int64_t>({0, 3*hidden})), size, stride);
129+
TRTORCH_CHECK(slice4, "Unable to create Slice layer from node: " << *n);
130+
auto activ4 = ctx->net->addActivation(*slice4->getOutput(0), nvinfer1::ActivationType::kSIGMOID);
131+
TRTORCH_CHECK(activ4, "Unable to create sigmoid activation layer from node: " << *n);
132+
auto outgate = activ4->getOutput(0);
133+
134+
// compute cy
135+
auto forget_cx = ctx->net->addElementWise(*forgetgate, *state[1], nvinfer1::ElementWiseOperation::kPROD);
136+
TRTORCH_CHECK(forget_cx, "Unable to create ElementWise layer from node: " << *n);
137+
auto in_cell = ctx->net->addElementWise(*ingate, *cellgate, nvinfer1::ElementWiseOperation::kPROD);
138+
TRTORCH_CHECK(in_cell, "Unable to create ElementWise layer from node: " << *n);
139+
auto cy = ctx->net->addElementWise(*forget_cx->getOutput(0), *in_cell->getOutput(0), nvinfer1::ElementWiseOperation::kPROD);
140+
TRTORCH_CHECK(cy, "Unable to create ElementWise layer from node: " << *n);
141+
auto cy_out = ctx->AssociateValueAndTensor(n->outputs()[1], cy->getOutput(0));
142+
143+
// compute hy
144+
auto cy_tanh = ctx->net->addActivation(*cy_out, nvinfer1::ActivationType::kTANH);
145+
TRTORCH_CHECK(cy_tanh, "Unable to create tanh activation layer from node: " << *n);
146+
auto hy = ctx->net->addElementWise(*outgate, *cy_tanh->getOutput(0), nvinfer1::ElementWiseOperation::kPROD);
147+
TRTORCH_CHECK(hy, "Unable to create ElementWise layer from node: " << *n);
148+
auto hy_out = ctx->AssociateValueAndTensor(n->outputs()[0], hy->getOutput(0));
149+
150+
LOG_DEBUG("Output tensor [hy] shape: " << hy_out->getDimensions());
151+
LOG_DEBUG("Output tensor [cy] shape: " << cy_out->getDimensions());
110152

111-
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
112153
return true;
113154
}
114155
});

0 commit comments

Comments
 (0)