1
+ #include " torch/torch.h"
2
+ #include " NvInfer.h"
3
+ #include " core/util/prelude.h"
4
+ #include " core/conversion/converters/converters.h"
5
+ #include " core/conversion/tensorcontainer/TensorContainer.h"
6
+
7
+ #include < ATen/ATen.h>
8
+ #include < vector>
9
+
10
+ namespace trtorch {
11
+ namespace core {
12
+ namespace conversion {
13
+ namespace converters {
14
+ namespace impl {
15
+ namespace {
16
+
17
+ auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
18
+ .pattern({
19
+ " aten::lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor, Tensor)" ,
20
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
21
+ auto input = args[0 ].ITensorOrFreeze (ctx);
22
+ auto w_ih = args[2 ].ITensorOrFreeze (ctx);
23
+ auto w_hh = args[3 ].ITensorOrFreeze (ctx);
24
+ auto b_ih = args[4 ].ITensorOrFreeze (ctx);
25
+ auto b_hh = args[5 ].ITensorOrFreeze (ctx);
26
+
27
+ LOG_DEBUG (" Input tensor shape: " << input->getDimensions ());
28
+ LOG_DEBUG (" w_ih tensor shape: " << w_ih->getDimensions ());
29
+ LOG_DEBUG (" w_hh tensor shape: " << w_hh->getDimensions ());
30
+ LOG_DEBUG (" b_ih tensor shape: " << b_ih->getDimensions ());
31
+ LOG_DEBUG (" b_hh tensor shape: " << b_hh->getDimensions ());
32
+
33
+ std::vector<nvinfer1::ITensor*> state;
34
+ auto hx = args[1 ].IValue ()->toListRef ();
35
+ for (unsigned int i = 0 ; i < hx.size (); i++) {
36
+ auto t = hx[i];
37
+
38
+ nvinfer1::ITensor* itensor;
39
+
40
+ if (t.isTensor ()) {
41
+ itensor = tensor_to_const (ctx, t.toTensor ());
42
+ } else {
43
+ auto cont = t.toCustomClass <TensorContainer>();
44
+ itensor = cont->tensor ();
45
+ }
46
+
47
+ LOG_DEBUG (" State tensor " << i << " shape: " << itensor->getDimensions ());
48
+ state.push_back (itensor);
49
+ }
50
+
51
+ // calculate first half of gates
52
+ auto mm1 = ctx->net ->addMatrixMultiply (*input, nvinfer1::MatrixOperation::kNONE , *w_ih, nvinfer1::MatrixOperation::kTRANSPOSE );
53
+ TRTORCH_CHECK (mm1, " Unable to create matrix multiplication node: " << *n);
54
+
55
+ auto mm1_out = mm1->getOutput (0 );
56
+ auto mm1_dim = mm1_out->getDimensions ();
57
+ auto b_ih_dim = b_ih->getDimensions ();
58
+
59
+ TRTORCH_CHECK (util::broadcastable (mm1_dim, b_ih_dim, false ));
60
+
61
+ if (util::toVec (mm1_dim) != util::toVec (b_ih_dim)) {
62
+ LOG_DEBUG (" b_ih dimensions need to be reshaped" );
63
+
64
+ auto shuffle = ctx->net ->addShuffle (*b_ih);
65
+ TRTORCH_CHECK (shuffle, " Unable to create shuffle layer from node: " << *n);
66
+ shuffle->setReshapeDimensions (util::toDimsPad (util::toVec (b_ih_dim), mm1_dim.nbDims ));
67
+ b_ih = shuffle->getOutput (0 );
68
+ }
69
+
70
+ auto add1 = ctx->net ->addElementWise (*mm1_out, *b_ih, nvinfer1::ElementWiseOperation::kSUM );
71
+ TRTORCH_CHECK (add1, " Unable to create ElementWise layer from node: " << *n);
72
+ auto add1_out = add2->getOutput (0 );
73
+
74
+ // calculate second half of gates
75
+ auto mm2 = ctx->net ->addMatrixMultiply (*state[0 ], nvinfer1::MatrixOperation::kNONE , *w_hh, nvinfer1::MatrixOperation::kTRANSPOE );
76
+ TRTORCH_CHECK (mm2, " Unable to create matrix multiplication node: " << *n);
77
+
78
+ auto mm2_out = mm2->getOutput (0 );
79
+ auto mm2_dim = mm2_out->getDimensions ();
80
+ auto b_hh_dim = b_hh->getDimensions ();
81
+
82
+ TRTORCH_CHECK (util::broadcastable (mm2_dim, b_hh_dim, false ));
83
+
84
+ if (util::toVec (mm2_dim) != util::toVec (b_hh_dim)) {
85
+ LOG_DEBUG (" b_hh dimensions need to be reshaped" );
86
+
87
+ auto shuffle = ctx->net ->addShuffle (*b_hh);
88
+ TRTORCH_CHECK (shuffle, " Unable to create shuffle layer from node: " << *n);
89
+ shuffle->setReshapeDimensions (util::toDimsPad (util::toVec (b_hh_dim), mm2_dim.nbDims ));
90
+ b_hh = shuffle->getOutput (0 );
91
+ }
92
+
93
+ auto add2 = ctx->net ->addElementWise (*mm2_out, *b_ih, nvinfer1::ElementWiseOperation::kSUM );
94
+ TRTORCH_CHECK (add2, " Unable to create ElementWise layer from node: " << *n);
95
+ auto add2_out = add2->getOutput (0 );
96
+
97
+ // gates
98
+ auto add3 = ctx->net ->addElementWise (*add1_out, *add2_out, nvinfer1::ElementWiseOperation::kSUM );
99
+ TRTORCH_CHECK (add3, " Unable to create ElementWise layer from node: " << *n);
100
+ auto add3_out = add3->getOutput (0 );
101
+
102
+
103
+
104
+
105
+
106
+ auto mm_layer = ctx->net ->addMatrixMultiply (*self, nvinfer1::MatrixOperation::kNONE , *other, nvinfer1::MatrixOperation::kNONE );
107
+ TRTORCH_CHECK (mm_layer, " Unable to create matrix multiplication node: " << *n);
108
+ mm_layer->setName (util::node_info (n).c_str ());
109
+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], mm_layer->getOutput (0 ));
110
+
111
+ LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
112
+ return true ;
113
+ }
114
+ });
115
+ } // namespace
116
+ } // namespace impl
117
+ } // namespace converters
118
+ } // namespace conversion
119
+ } // namespace core
120
+ } // namespace trtorch
0 commit comments