@@ -99,16 +99,57 @@ auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
99
99
TRTORCH_CHECK (add3, " Unable to create ElementWise layer from node: " << *n);
100
100
auto add3_out = add3->getOutput (0 );
101
101
102
-
103
-
104
-
105
-
106
- auto mm_layer = ctx->net ->addMatrixMultiply (*self, nvinfer1::MatrixOperation::kNONE , *other, nvinfer1::MatrixOperation::kNONE );
107
- TRTORCH_CHECK (mm_layer, " Unable to create matrix multiplication node: " << *n);
108
- mm_layer->setName (util::node_info (n).c_str ());
109
- auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], mm_layer->getOutput (0 ));
102
+ // chunk Tensor into 4 parts and apply activation functions
103
+ auto dims = util::toVec (add3_out->getDimensions ());
104
+ auto batch = dims[0 ];
105
+ auto hidden = dims[1 ]/4 ;
106
+
107
+ auto size = util::toDims (std::vector<int64_t >({batch, hidden}));
108
+ auto stride = util::toDims (std::vector<int64_t >({1 , 1 }));
109
+
110
+ auto slice1 = ctx->net ->addSlice (*add3_out, util::toDims (std::vector<int64_t >({0 , 0 })), size, stride);
111
+ TRTORCH_CHECK (slice1, " Unable to create Slice layer from node: " << *n);
112
+ auto activ1 = ctx->net ->addActivation (*slice1->getOutput (0 ), nvinfer1::ActivationType::kSIGMOID );
113
+ TRTORCH_CHECK (activ1, " Unable to create sigmoid activation layer from node: " << *n);
114
+ auto ingate = activ1->getOutput (0 );
115
+
116
+ auto slice2 = ctx->net ->addSlice (*add3_out, util::toDims (std::vector<int64_t >({0 , hidden})), size, stride);
117
+ TRTORCH_CHECK (slice2, " Unable to create Slice layer from node: " << *n);
118
+ auto activ2 = ctx->net ->addActivation (*slice2->getOutput (0 ), nvinfer1::ActivationType::kSIGMOID );
119
+ TRTORCH_CHECK (activ2, " Unable to create sigmoid activation layer from node: " << *n);
120
+ auto forgetgate = activ2->getOutput (0 );
121
+
122
+ auto slice3 = ctx->net ->addSlice (*add3_out, util::toDims (std::vector<int64_t >({0 , 2 *hidden})), size, stride);
123
+ TRTORCH_CHECK (slice3, " Unable to create Slice layer from node: " << *n);
124
+ auto activ3 = ctx->net ->addActivation (*slice3->getOutput (0 ), nvinfer1::ActivationType::kTANH );
125
+ TRTORCH_CHECK (activ3, " Unable to create tanh activation layer from node: " << *n);
126
+ auto cellgate = activ3->getOutput (0 );
127
+
128
+ auto slice4 = ctx->net ->addSlice (*add3_out, util::toDims (std::vector<int64_t >({0 , 3 *hidden})), size, stride);
129
+ TRTORCH_CHECK (slice4, " Unable to create Slice layer from node: " << *n);
130
+ auto activ4 = ctx->net ->addActivation (*slice4->getOutput (0 ), nvinfer1::ActivationType::kSIGMOID );
131
+ TRTORCH_CHECK (activ4, " Unable to create sigmoid activation layer from node: " << *n);
132
+ auto outgate = activ4->getOutput (0 );
133
+
134
+ // compute cy
135
+ auto forget_cx = ctx->net ->addElementWise (*forgetgate, *state[1 ], nvinfer1::ElementWiseOperation::kPROD );
136
+ TRTORCH_CHECK (forget_cx, " Unable to create ElementWise layer from node: " << *n);
137
+ auto in_cell = ctx->net ->addElementWise (*ingate, *cellgate, nvinfer1::ElementWiseOperation::kPROD );
138
+ TRTORCH_CHECK (in_cell, " Unable to create ElementWise layer from node: " << *n);
139
+ auto cy = ctx->net ->addElementWise (*forget_cx->getOutput (0 ), *in_cell->getOutput (0 ), nvinfer1::ElementWiseOperation::kPROD );
140
+ TRTORCH_CHECK (cy, " Unable to create ElementWise layer from node: " << *n);
141
+ auto cy_out = ctx->AssociateValueAndTensor (n->outputs ()[1 ], cy->getOutput (0 ));
142
+
143
+ // compute hy
144
+ auto cy_tanh = ctx->net ->addActivation (*cy_out, nvinfer1::ActivationType::kTANH );
145
+ TRTORCH_CHECK (cy_tanh, " Unable to create tanh activation layer from node: " << *n);
146
+ auto hy = ctx->net ->addElementWise (*outgate, *cy_tanh->getOutput (0 ), nvinfer1::ElementWiseOperation::kPROD );
147
+ TRTORCH_CHECK (hy, " Unable to create ElementWise layer from node: " << *n);
148
+ auto hy_out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], hy->getOutput (0 ));
149
+
150
+ LOG_DEBUG (" Output tensor [hy] shape: " << hy_out->getDimensions ());
151
+ LOG_DEBUG (" Output tensor [cy] shape: " << cy_out->getDimensions ());
110
152
111
- LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
112
153
return true ;
113
154
}
114
155
});
0 commit comments