@@ -14,22 +14,41 @@ namespace converters {
14
14
namespace impl {
15
15
namespace {
16
16
17
+ nvinfer1::ITensor* add_bias (nvinfer1::ITensor* a, nvinfer1::ITensor* b, std::string b_name, ConversionCtx* ctx, const torch::jit::Node* n) {
18
+ auto a_dim = a->getDimensions ();
19
+ auto b_dim = b->getDimensions ();
20
+
21
+ LOG_DEBUG (b_name << " tensor shape: " << b_dim);
22
+
23
+ TRTORCH_CHECK (util::broadcastable (a_dim, b_dim, false ), " bias " << b_name << " is not broadcastable - can't be added to previous matmul operation." );
24
+
25
+ if (util::toVec (a_dim) != util::toVec (b_dim)) {
26
+ LOG_DEBUG (b_name << " 's dimensions need to be reshaped" );
27
+
28
+ auto shuffle = ctx->net ->addShuffle (*b);
29
+ TRTORCH_CHECK (shuffle, " Unable to create shuffle layer from node: " << *n);
30
+ shuffle->setReshapeDimensions (util::toDimsPad (util::toVec (b_dim), a_dim.nbDims ));
31
+ b = shuffle->getOutput (0 );
32
+ }
33
+
34
+ auto add = ctx->net ->addElementWise (*a, *b, nvinfer1::ElementWiseOperation::kSUM );
35
+ TRTORCH_CHECK (add, " Unable to create ElementWise layer from node: " << *n);
36
+
37
+ return add->getOutput (0 );
38
+ }
39
+
17
40
auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
18
41
.pattern({
19
42
" aten::lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor, Tensor)" ,
20
43
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
21
44
auto input = args[0 ].ITensorOrFreeze (ctx);
22
45
auto w_ih = args[2 ].ITensorOrFreeze (ctx);
23
46
auto w_hh = args[3 ].ITensorOrFreeze (ctx);
24
- auto b_ih = args[4 ].ITensorOrFreeze (ctx);
25
- auto b_hh = args[5 ].ITensorOrFreeze (ctx);
26
47
27
48
LOG_DEBUG (" Input tensor shape: " << input->getDimensions ());
28
49
LOG_DEBUG (" w_ih tensor shape: " << w_ih->getDimensions ());
29
50
LOG_DEBUG (" w_hh tensor shape: " << w_hh->getDimensions ());
30
- LOG_DEBUG (" b_ih tensor shape: " << b_ih->getDimensions ());
31
- LOG_DEBUG (" b_hh tensor shape: " << b_hh->getDimensions ());
32
-
51
+
33
52
std::vector<nvinfer1::ITensor*> state;
34
53
auto hx = args[1 ].IValue ()->toListRef ();
35
54
for (unsigned int i = 0 ; i < hx.size (); i++) {
@@ -51,81 +70,56 @@ auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
51
70
// calculate first half of gates
52
71
auto mm1 = ctx->net ->addMatrixMultiply (*input, nvinfer1::MatrixOperation::kNONE , *w_ih, nvinfer1::MatrixOperation::kTRANSPOSE );
53
72
TRTORCH_CHECK (mm1, " Unable to create matrix multiplication node: " << *n);
54
-
55
73
auto mm1_out = mm1->getOutput (0 );
56
- auto mm1_dim = mm1_out->getDimensions ();
57
- auto b_ih_dim = b_ih->getDimensions ();
58
-
59
- TRTORCH_CHECK (util::broadcastable (mm1_dim, b_ih_dim, false ));
60
74
61
- if (util::toVec (mm1_dim) != util::toVec (b_ih_dim)) {
62
- LOG_DEBUG (" b_ih dimensions need to be reshaped" );
63
-
64
- auto shuffle = ctx->net ->addShuffle (*b_ih);
65
- TRTORCH_CHECK (shuffle, " Unable to create shuffle layer from node: " << *n);
66
- shuffle->setReshapeDimensions (util::toDimsPad (util::toVec (b_ih_dim), mm1_dim.nbDims ));
67
- b_ih = shuffle->getOutput (0 );
68
- }
69
-
70
- auto add1 = ctx->net ->addElementWise (*mm1_out, *b_ih, nvinfer1::ElementWiseOperation::kSUM );
71
- TRTORCH_CHECK (add1, " Unable to create ElementWise layer from node: " << *n);
72
- auto add1_out = add2->getOutput (0 );
75
+ auto out1 = !args[4 ].IValue ()->isNone () ? add_bias (mm1_out, args[4 ].ITensorOrFreeze (ctx), " b_ih" , ctx, n) : mm1_out;
73
76
74
77
// calculate second half of gates
75
- auto mm2 = ctx->net ->addMatrixMultiply (*state[0 ], nvinfer1::MatrixOperation::kNONE , *w_hh, nvinfer1::MatrixOperation::kTRANSPOE );
78
+ auto mm2 = ctx->net ->addMatrixMultiply (*state[0 ], nvinfer1::MatrixOperation::kNONE , *w_hh, nvinfer1::MatrixOperation::kTRANSPOSE );
76
79
TRTORCH_CHECK (mm2, " Unable to create matrix multiplication node: " << *n);
77
-
78
80
auto mm2_out = mm2->getOutput (0 );
79
- auto mm2_dim = mm2_out->getDimensions ();
80
- auto b_hh_dim = b_hh->getDimensions ();
81
-
82
- TRTORCH_CHECK (util::broadcastable (mm2_dim, b_hh_dim, false ));
83
81
84
- if (util::toVec (mm2_dim) != util::toVec (b_hh_dim)) {
85
- LOG_DEBUG (" b_hh dimensions need to be reshaped" );
86
-
87
- auto shuffle = ctx->net ->addShuffle (*b_hh);
88
- TRTORCH_CHECK (shuffle, " Unable to create shuffle layer from node: " << *n);
89
- shuffle->setReshapeDimensions (util::toDimsPad (util::toVec (b_hh_dim), mm2_dim.nbDims ));
90
- b_hh = shuffle->getOutput (0 );
91
- }
92
-
93
- auto add2 = ctx->net ->addElementWise (*mm2_out, *b_ih, nvinfer1::ElementWiseOperation::kSUM );
94
- TRTORCH_CHECK (add2, " Unable to create ElementWise layer from node: " << *n);
95
- auto add2_out = add2->getOutput (0 );
82
+ auto out2 = !args[5 ].IValue ()->isNone () ? add_bias (mm2_out, args[5 ].ITensorOrFreeze (ctx), " b_hh" , ctx, n) : mm2_out;
96
83
97
84
// gates
98
- auto add3 = ctx->net ->addElementWise (*add1_out , *add2_out , nvinfer1::ElementWiseOperation::kSUM );
99
- TRTORCH_CHECK (add3 , " Unable to create ElementWise layer from node: " << *n);
100
- auto add3_out = add3 ->getOutput (0 );
85
+ auto add = ctx->net ->addElementWise (*out1 , *out2 , nvinfer1::ElementWiseOperation::kSUM );
86
+ TRTORCH_CHECK (add , " Unable to create ElementWise layer from node: " << *n);
87
+ auto add_out = add ->getOutput (0 );
101
88
102
89
// chunk Tensor into 4 parts and apply activation functions
103
- auto dims = util::toVec (add3_out ->getDimensions ());
90
+ auto dims = util::toVec (add_out ->getDimensions ());
104
91
auto batch = dims[0 ];
105
92
auto hidden = dims[1 ]/4 ;
106
93
107
- auto size = util::toDims (std::vector<int64_t >({batch, hidden}));
108
- auto stride = util::toDims (std::vector<int64_t >({1 , 1 }));
94
+ std::vector<int64_t > size_vec = {batch, hidden};
95
+ std::vector<int64_t > stride_vec = {1 , 1 };
96
+ std::vector<int64_t > offset0 = {0 , 0 };
97
+ std::vector<int64_t > offset1 = {0 , hidden};
98
+ std::vector<int64_t > offset2 = {0 , 2 *hidden};
99
+ std::vector<int64_t > offset3 = {0 , 3 *hidden};
100
+
101
+ auto size = util::toDims (size_vec);
102
+ auto stride = util::toDims (stride_vec);
109
103
110
- auto slice1 = ctx->net ->addSlice (*add3_out , util::toDims (std::vector< int64_t >({ 0 , 0 }) ), size, stride);
104
+ auto slice1 = ctx->net ->addSlice (*add_out , util::toDims (offset0 ), size, stride);
111
105
TRTORCH_CHECK (slice1, " Unable to create Slice layer from node: " << *n);
112
106
auto activ1 = ctx->net ->addActivation (*slice1->getOutput (0 ), nvinfer1::ActivationType::kSIGMOID );
113
107
TRTORCH_CHECK (activ1, " Unable to create sigmoid activation layer from node: " << *n);
114
108
auto ingate = activ1->getOutput (0 );
115
109
116
- auto slice2 = ctx->net ->addSlice (*add3_out , util::toDims (std::vector< int64_t >({ 0 , hidden}) ), size, stride);
110
+ auto slice2 = ctx->net ->addSlice (*add_out , util::toDims (offset1 ), size, stride);
117
111
TRTORCH_CHECK (slice2, " Unable to create Slice layer from node: " << *n);
118
112
auto activ2 = ctx->net ->addActivation (*slice2->getOutput (0 ), nvinfer1::ActivationType::kSIGMOID );
119
113
TRTORCH_CHECK (activ2, " Unable to create sigmoid activation layer from node: " << *n);
120
114
auto forgetgate = activ2->getOutput (0 );
121
115
122
- auto slice3 = ctx->net ->addSlice (*add3_out , util::toDims (std::vector< int64_t >({ 0 , 2 *hidden}) ), size, stride);
116
+ auto slice3 = ctx->net ->addSlice (*add_out , util::toDims (offset2 ), size, stride);
123
117
TRTORCH_CHECK (slice3, " Unable to create Slice layer from node: " << *n);
124
118
auto activ3 = ctx->net ->addActivation (*slice3->getOutput (0 ), nvinfer1::ActivationType::kTANH );
125
119
TRTORCH_CHECK (activ3, " Unable to create tanh activation layer from node: " << *n);
126
120
auto cellgate = activ3->getOutput (0 );
127
121
128
- auto slice4 = ctx->net ->addSlice (*add3_out , util::toDims (std::vector< int64_t >({ 0 , 3 *hidden}) ), size, stride);
122
+ auto slice4 = ctx->net ->addSlice (*add_out , util::toDims (offset3 ), size, stride);
129
123
TRTORCH_CHECK (slice4, " Unable to create Slice layer from node: " << *n);
130
124
auto activ4 = ctx->net ->addActivation (*slice4->getOutput (0 ), nvinfer1::ActivationType::kSIGMOID );
131
125
TRTORCH_CHECK (activ4, " Unable to create sigmoid activation layer from node: " << *n);
0 commit comments