Skip to content

Commit bba676f

Browse files
authored
Merge pull request #961 from chinhuang007/support-gru-14
Add batchwise support for gru, lstm, rnn
2 parents b13a282 + 8d304f5 commit bba676f

File tree

5 files changed

+61
-8
lines changed

5 files changed

+61
-8
lines changed

doc/support_status.md

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# ONNX-Tensorflow Support Status
22
|||
33
|-:|:-|
4-
|ONNX-Tensorflow Version|Master ( commit id: 9ab9b934c2c8494b6309d20f15acabcb3abd126d )|
4+
|ONNX-Tensorflow Version|Master ( commit id: a1005fbd2a95699a34f83d5d25fe20d5213860d3 )|
55
|ONNX Version|Master ( commit id: 1f63dcb7fcc3a8bf5c3c8e326867ecd6f5c43f35 )|
66
|Tensorflow Version|v2.5.0|
77

@@ -62,7 +62,7 @@ Notes:
6262
|EyeLike|-|-|-|-|-|-|-|-|**9**|9|9|9|9|9|9|EyeLike|
6363
|Flatten|**1**|1|1|1|1|1|1|1|**9**|9|**11**|11|**13**|13|13|Flatten|
6464
|Floor|**1**|1|1|1|1|**6**|6|6|6|6|6|6|**13**|13|13|Floor|
65-
|GRU|**1**:small_orange_diamond:|1:small_orange_diamond:|**3**:small_orange_diamond:|3:small_orange_diamond:|3:small_orange_diamond:|3:small_orange_diamond:|**7**:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|**14**:small_red_triangle:|14:small_red_triangle:|GRU|
65+
|GRU|**1**:small_orange_diamond:|1:small_orange_diamond:|**3**:small_orange_diamond:|3:small_orange_diamond:|3:small_orange_diamond:|3:small_orange_diamond:|**7**:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|**14**:small_orange_diamond:|14:small_orange_diamond:|GRU|
6666
|Gather|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**|13|13|Gather|
6767
|GatherElements|-|-|-|-|-|-|-|-|-|-|**11**|11|**13**|13|13|GatherElements|
6868
|GatherND|-|-|-|-|-|-|-|-|-|-|**11**|**12**|**13**|13|13|GatherND|
@@ -81,7 +81,7 @@ Notes:
8181
|IsInf|-|-|-|-|-|-|-|-|-|**10**|10|10|10|10|10|IsInf|
8282
|IsNaN|-|-|-|-|-|-|-|-|**9**|9|9|9|**13**|13|13|IsNaN|
8383
|LRN|**1**|1|1|1|1|1|1|1|1|1|1|1|**13**|13|13|LRN|
84-
|LSTM|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**7**:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|**14**:small_red_triangle:|14:small_red_triangle:|LSTM|
84+
|LSTM|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**7**:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|**14**:small_orange_diamond:|14:small_orange_diamond:|LSTM|
8585
|LeakyRelu|**1**|1|1|1|1|**6**|6|6|6|6|6|6|6|6|6|LeakyRelu|
8686
|Less|**1**|1|1|1|1|1|**7**|7|**9**|9|9|9|**13**|13|13|Less|
8787
|LessOrEqual|-|-|-|-|-|-|-|-|-|-|-|**12**|12|12|12|LessOrEqual|
@@ -118,7 +118,7 @@ Notes:
118118
|QLinearConv|-|-|-|-|-|-|-|-|-|**10**|10|10|10|10|10|QLinearConv|
119119
|QLinearMatMul|-|-|-|-|-|-|-|-|-|**10**|10|10|10|10|10|QLinearMatMul|
120120
|QuantizeLinear|-|-|-|-|-|-|-|-|-|**10**|10|10|**13**|13|13|QuantizeLinear|
121-
|RNN|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**7**:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|**14**:small_red_triangle:|14:small_red_triangle:|RNN|
121+
|RNN|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**7**:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|**14**:small_orange_diamond:|14:small_orange_diamond:|RNN|
122122
|RandomNormal|**1**|1|1|1|1|1|1|1|1|1|1|1|1|1|1|RandomNormal|
123123
|RandomNormalLike|**1**|1|1|1|1|1|1|1|1|1|1|1|1|1|1|RandomNormalLike|
124124
|RandomUniform|**1**|1|1|1|1|1|1|1|1|1|1|1|1|1|1|RandomUniform|
@@ -186,7 +186,7 @@ Notes:
186186
|Where|-|-|-|-|-|-|-|-|**9**|9|9|9|9|9|9|Where|
187187
|Xor|**1**|1|1|1|1|1|**7**|7|7|7|7|7|7|7|7|Xor|
188188

189-
ONNX-TF Supported Operators / ONNX Operators: 151 / 169
189+
ONNX-TF Supported Operators / ONNX Operators: 154 / 169
190190

191191
Notes:
192192
1. BatchNormalization: BatchNormalization with training_mode=1 is not supported in Tensorflow converte.

onnx_tf/handlers/backend/gru.py

+17
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,11 @@ def _common(cls, node, **kwargs):
161161
hidden_size = node.attrs["hidden_size"]
162162
direction = node.attrs.get("direction", "forward")
163163
num_directions = 2 if direction == "bidirectional" else 1
164+
layout = node.attrs.get("layout", 0)
165+
166+
# Need transpose for batchwise
167+
if layout == 1:
168+
x = tf.transpose(x, perm=[1, 0, 2])
164169

165170
# removed from version 7, default is 0
166171
output_sequence = node.attrs.get("output_sequence", 0)
@@ -207,6 +212,9 @@ def _common(cls, node, **kwargs):
207212
if input_size == 6:
208213
initial_h = tensor_dict.get(node.inputs[5], None)
209214
if initial_h is not None:
215+
# Need transpose for batchwise
216+
if layout == 1:
217+
initial_h = tf.transpose(initial_h, perm=[1, 0, 2])
210218
initial_state = (initial_h[0],)
211219
if num_directions == 2:
212220
initial_state_bw = (initial_h[1],)
@@ -241,6 +249,11 @@ def _common(cls, node, **kwargs):
241249
output_bw = tf.expand_dims(output_bw, 1)
242250
output = tf.concat((output_fw, output_bw), axis=1)
243251

252+
# Need transpose for batchwise
253+
if layout == 1:
254+
output = tf.transpose(output, perm=[2, 0, 1, 3])
255+
h = tf.transpose(h, perm=[1, 0, 2])
256+
244257
return [output, h] if output_sequence == 0 else [h]
245258

246259
@classmethod
@@ -254,3 +267,7 @@ def version_3(cls, node, **kwargs):
254267
@classmethod
255268
def version_7(cls, node, **kwargs):
256269
return cls._common(node, **kwargs)
270+
271+
@classmethod
272+
def version_14(cls, node, **kwargs):
273+
return cls._common(node, **kwargs)

onnx_tf/handlers/backend/lstm.py

+19
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,11 @@ def _common(cls, node, **kwargs):
168168
hidden_size = node.attrs["hidden_size"]
169169
direction = node.attrs.get("direction", "forward")
170170
num_directions = 2 if direction == "bidirectional" else 1
171+
layout = node.attrs.get("layout", 0)
172+
173+
# Need transpose for batchwise
174+
if layout == 1:
175+
x = tf.transpose(x, perm=[1, 0, 2])
171176

172177
# removed from version 7, default is 0
173178
output_sequence = node.attrs.get("output_sequence", 0)
@@ -220,6 +225,10 @@ def _common(cls, node, **kwargs):
220225
initial_c = tensor_dict.get(
221226
node.inputs[6],
222227
None) if input_size >= 7 else tf.zeros_like(initial_h)
228+
# Need transpose for batchwise
229+
if layout == 1:
230+
initial_h = tf.transpose(initial_h, perm=[1, 0, 2])
231+
initial_c = tf.transpose(initial_c, perm=[1, 0, 2])
223232
if initial_h is not None and initial_c is not None:
224233
initial_state = (tf.compat.v1.nn.rnn_cell.LSTMStateTuple(
225234
initial_c[0], initial_h[0]),)
@@ -261,6 +270,12 @@ def _common(cls, node, **kwargs):
261270
output_bw = tf.expand_dims(output_bw, 1)
262271
output = tf.concat((output_fw, output_bw), axis=1)
263272

273+
# Need transpose for batchwise
274+
if layout == 1:
275+
output = tf.transpose(output, perm=[2, 0, 1, 3])
276+
h = tf.transpose(h, perm=[1, 0, 2])
277+
c = tf.transpose(c, perm=[1, 0, 2])
278+
264279
return [output, h, c] if output_sequence == 0 else [h, c]
265280

266281
@classmethod
@@ -270,3 +285,7 @@ def version_1(cls, node, **kwargs):
270285
@classmethod
271286
def version_7(cls, node, **kwargs):
272287
return cls._common(node, **kwargs)
288+
289+
@classmethod
290+
def version_14(cls, node, **kwargs):
291+
return cls._common(node, **kwargs)

onnx_tf/handlers/backend/rnn.py

+17
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,11 @@ def _common(cls, node, **kwargs):
104104
direction = node.attrs.get("direction", "forward")
105105
num_directions = 2 if direction == "bidirectional" else 1
106106
output_sequence = node.attrs.get("output_sequence", 0)
107+
layout = node.attrs.get("layout", 0)
108+
109+
# Need transpose for batchwise
110+
if layout == 1:
111+
x = tf.transpose(x, perm=[1, 0, 2])
107112

108113
# TODO(fumihwh): check if prev node is one of RNN
109114
# process input if it comes from other previous cell
@@ -145,6 +150,9 @@ def _common(cls, node, **kwargs):
145150
if input_size == 6:
146151
initial_h = tensor_dict.get(node.inputs[5], None)
147152
if initial_h is not None:
153+
# Need transpose for batchwise
154+
if layout == 1:
155+
initial_h = tf.transpose(initial_h, perm=[1, 0, 2])
148156
initial_state = (initial_h[0],)
149157
if num_directions == 2:
150158
initial_state_bw = (initial_h[1],)
@@ -179,6 +187,11 @@ def _common(cls, node, **kwargs):
179187
output_bw = tf.expand_dims(output_bw, 1)
180188
output = tf.concat((output_fw, output_bw), axis=1)
181189

190+
# Need transpose for batchwise
191+
if layout == 1:
192+
output = tf.transpose(output, perm=[2, 0, 1, 3])
193+
h = tf.transpose(h, perm=[1, 0, 2])
194+
182195
return [output, h] if output_sequence == 0 else [h]
183196

184197
@classmethod
@@ -188,3 +201,7 @@ def version_1(cls, node, **kwargs):
188201
@classmethod
189202
def version_7(cls, node, **kwargs):
190203
return cls._common(node, **kwargs)
204+
205+
@classmethod
206+
def version_14(cls, node, **kwargs):
207+
return cls._common(node, **kwargs)

onnx_tf/opset_version.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
'FeatureVectorizer': [],
5555
'Flatten': [1, 9, 11, 13],
5656
'Floor': [1, 6, 13],
57-
'GRU': [1, 3, 7],
57+
'GRU': [1, 3, 7, 14],
5858
'Gather': [1, 11, 13],
5959
'GatherElements': [11, 13],
6060
'GatherND': [11, 12, 13],
@@ -76,7 +76,7 @@
7676
'IsInf': [10],
7777
'IsNaN': [9, 13],
7878
'LRN': [1, 13],
79-
'LSTM': [1, 7],
79+
'LSTM': [1, 7, 14],
8080
'LabelEncoder': [],
8181
'LeakyRelu': [1, 6],
8282
'Less': [1, 7, 9, 13],
@@ -119,7 +119,7 @@
119119
'QLinearConv': [10],
120120
'QLinearMatMul': [10],
121121
'QuantizeLinear': [10, 13],
122-
'RNN': [1, 7],
122+
'RNN': [1, 7, 14],
123123
'RandomNormal': [1],
124124
'RandomNormalLike': [1],
125125
'RandomUniform': [1],

0 commit comments

Comments
 (0)