@@ -769,6 +769,33 @@ def forward(self, x):
 lstm = LSTM(features, hidden, batch, bidirectional=True)
 save_data_and_model("lstm_bidirectional", input, lstm)
 
+
+
+class HiddenLSTM(nn.Module):
+    def __init__(self, input_size, hidden_size, num_layers=1, is_bidirectional=False):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.bi_coeff = 2 if is_bidirectional else 1  # two directions double the state tensors
+        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
+                            num_layers=num_layers, bidirectional=is_bidirectional)
+
+    def forward(self, t):
+        h_0 = torch.ones(self.num_layers * self.bi_coeff, t.size(1),  # non-zero initial hidden state
+                         self.hidden_size)
+        c_0 = torch.ones(self.num_layers * self.bi_coeff, t.size(1),  # non-zero initial cell state
+                         self.hidden_size)
+        return self.lstm(t, (h_0, c_0))[0]  # keep only the output sequence
+
+input = torch.randn(seq_len, batch, features)
+hidden_lstm = HiddenLSTM(features, hidden, num_layers=3, is_bidirectional=False)
+save_data_and_model("hidden_lstm", input, hidden_lstm, version=11, export_params=True)
+
+input = torch.randn(seq_len, batch, features)
+hidden_lstm = HiddenLSTM(features, hidden, num_layers=3, is_bidirectional=True)
+save_data_and_model("hidden_lstm_bi", input, hidden_lstm, version=11, export_params=True)
+
+
 class MatMul(nn.Module):
     def __init__(self):
         super(MatMul, self).__init__()
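
For context, save_data_and_model is this generator script's helper that runs the module on the sample input, stores the input/output blobs, and exports the model to ONNX. A minimal sketch of what the calls above presumably do, assuming the usual data/ and models/ layout of such a fixtures repository (the exact helper body may differ):

    import numpy as np
    import torch

    def save_data_and_model_sketch(name, input, model, version=11, export_params=True):
        model.eval()
        with torch.no_grad():
            output = model(input)
        # Reference blobs for the dnn tests (file paths are an assumption).
        np.save("data/input_" + name + ".npy", input.numpy())
        np.save("data/output_" + name + ".npy", output.numpy())
        # version=11 requests opset 11, which covers the ONNX LSTM node
        # with explicit initial_h/initial_c inputs.
        torch.onnx.export(model, input, "models/" + name + ".onnx",
                          export_params=export_params, opset_version=version)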
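
Since these fixtures are meant to be consumed by OpenCV's ONNX importer, a quick round-trip check can load the exported model back with cv2.dnn and compare it against the saved reference output. A hypothetical check (file paths mirror the layout assumed above; tolerances are a guess):

    import cv2
    import numpy as np

    net = cv2.dnn.readNetFromONNX("models/hidden_lstm.onnx")
    inp = np.load("data/input_hidden_lstm.npy")
    ref = np.load("data/output_hidden_lstm.npy")

    net.setInput(inp)
    out = net.forward()
    # Blob layouts can differ slightly between frameworks, so compare flat values.
    np.testing.assert_allclose(out.reshape(-1), ref.reshape(-1), rtol=1e-4, atol=1e-5)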