Skip to content

Commit 87f2521

Browse files
authored
Merge pull request #886 from JulieBar:lstm_inside
2 parents 76f5ea2 + 7e94265 commit 87f2521

14 files changed

+27
-0
lines changed

testdata/dnn/layers/lstm.hidden.B.npy

176 Bytes
Binary file not shown.

testdata/dnn/layers/lstm.hidden.R.npy

224 Bytes
Binary file not shown.

testdata/dnn/layers/lstm.hidden.W.npy

272 Bytes
Binary file not shown.
188 Bytes
Binary file not shown.
188 Bytes
Binary file not shown.
288 Bytes
Binary file not shown.
248 Bytes
Binary file not shown.
288 Bytes
Binary file not shown.
288 Bytes
Binary file not shown.
248 Bytes
Binary file not shown.
368 Bytes
Binary file not shown.

testdata/dnn/onnx/generate_onnx_models.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,33 @@ def forward(self, x):
769769
lstm = LSTM(features, hidden, batch, bidirectional=True)
770770
save_data_and_model("lstm_bidirectional", input, lstm)
771771

772+
773+
774+
class HiddenLSTM(nn.Module):
775+
def __init__(self, input_size, hidden_size, num_layers=1, is_bidirectional=False):
776+
super().__init__()
777+
self.hidden_size = hidden_size
778+
self.num_layers = num_layers
779+
self.bi_coeff = 2 if is_bidirectional else 1
780+
self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
781+
num_layers=num_layers, bidirectional=is_bidirectional)
782+
783+
def forward(self, t):
784+
h_0 = torch.ones(self.num_layers * self.bi_coeff, t.size(1),
785+
self.hidden_size)
786+
c_0 = torch.ones(self.num_layers * self.bi_coeff, t.size(1),
787+
self.hidden_size)
788+
return self.lstm(t, (h_0, c_0))[0]
789+
790+
input = torch.randn(seq_len, batch, features)
791+
hidden_lstm = HiddenLSTM(features, hidden, num_layers=3, is_bidirectional=False)
792+
save_data_and_model("hidden_lstm", input, hidden_lstm, version=11, export_params=True)
793+
794+
input = torch.randn(seq_len, batch, features)
795+
hidden_lstm = HiddenLSTM(features, hidden, num_layers=3, is_bidirectional=True)
796+
save_data_and_model("hidden_lstm_bi", input, hidden_lstm, version=11, export_params=True)
797+
798+
772799
class MatMul(nn.Module):
773800
def __init__(self):
774801
super(MatMul, self).__init__()
3.72 KB
Binary file not shown.
5.93 KB
Binary file not shown.

0 commit comments

Comments
 (0)