@@ -168,6 +168,11 @@ def _common(cls, node, **kwargs):
     hidden_size = node.attrs["hidden_size"]
     direction = node.attrs.get("direction", "forward")
     num_directions = 2 if direction == "bidirectional" else 1
+    layout = node.attrs.get("layout", 0)
+
+    # Need transpose for batchwise
+    if layout == 1:
+      x = tf.transpose(x, perm=[1, 0, 2])
 
     # removed from version 7, default is 0
     output_sequence = node.attrs.get("output_sequence", 0)
@@ -220,6 +225,10 @@ def _common(cls, node, **kwargs):
     initial_c = tensor_dict.get(
         node.inputs[6],
         None) if input_size >= 7 else tf.zeros_like(initial_h)
+    # Need transpose for batchwise
+    if layout == 1:
+      initial_h = tf.transpose(initial_h, perm=[1, 0, 2])
+      initial_c = tf.transpose(initial_c, perm=[1, 0, 2])
     if initial_h is not None and initial_c is not None:
       initial_state = (tf.compat.v1.nn.rnn_cell.LSTMStateTuple(
           initial_c[0], initial_h[0]),)
@@ -261,6 +270,12 @@ def _common(cls, node, **kwargs):
       output_bw = tf.expand_dims(output_bw, 1)
       output = tf.concat((output_fw, output_bw), axis=1)
 
+    # Need transpose for batchwise
+    if layout == 1:
+      output = tf.transpose(output, perm=[2, 0, 1, 3])
+      h = tf.transpose(h, perm=[1, 0, 2])
+      c = tf.transpose(c, perm=[1, 0, 2])
+
     return [output, h, c] if output_sequence == 0 else [h, c]
 
   @classmethod
@@ -270,3 +285,7 @@ def version_1(cls, node, **kwargs):
   @classmethod
   def version_7(cls, node, **kwargs):
     return cls._common(node, **kwargs)
+
+  @classmethod
+  def version_14(cls, node, **kwargs):
+    return cls._common(node, **kwargs)
0 commit comments