This repository was archived by the owner on Oct 13, 2021. It is now read-only.

Commit 9316ad6

Better conversion for the subclassed model and code reformat. (#446)
* Better conversion for the subclassed model and code reformat.
* Update the tutorial.
* Fix an annoying layer-name reuse issue.
1 parent c57c2e7 commit 9316ad6

11 files changed: +196 -247 lines
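The theme of the commit is more reliable TF-graph parsing for subclassed tf.keras models. As an illustration only (the model class and names below are invented, not from this repository), the conversion entry point that benefits is keras2onnx.convert_keras, applied to a subclassed model whose layer ops carry auto-generated graph names:

    import numpy as np
    import tensorflow as tf
    import keras2onnx

    class TinyModel(tf.keras.Model):
        # A minimal subclassed model; its ops get auto-generated names
        # like model/dense, model/dense_1, ... in the TF graph.
        def __init__(self):
            super(TinyModel, self).__init__()
            self.dense1 = tf.keras.layers.Dense(8, activation='relu')
            self.dense2 = tf.keras.layers.Dense(2)

        def call(self, inputs):
            return self.dense2(self.dense1(inputs))

    model = TinyModel()
    model.predict(np.zeros((1, 4), dtype=np.float32))  # run once so the graph exists
    onnx_model = keras2onnx.convert_keras(model, model.name)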

.gitignore (+3)

@@ -36,10 +36,13 @@ dist
 # test generated files
 .pytest_cache/
 .cache
+.ipynb_checkpoints/
 htmlcov
 coverage.xml
 .coverage
 docs/auto_examples/*
 docs/examples/imagenet*.json
 docs/examples/*.onnx
 docs/examples/tiny_yolov2*
+tutorial/*.onnx
+tutorial/.ipynb_checkpoints/*

keras2onnx/_parse_tf.py (+38 -30)

@@ -175,33 +175,25 @@ def _layer_name_dict(tf_utils, layer, prefix, parent=None):
     return output_dict


-def _to_tf_ops(layer_name, graph, fstr):
+def _to_tf_ops(layer_name, fstr, ops_table):
     ops = []
     op_name = fstr.format(layer_name) if fstr is not None else None
     if op_name is None:
         return ops

-    try:
-        ops[0:] = [graph.get_operation_by_name(op_name)]
-        idx = 1
-        if not re.match(r".+_\d+", layer_name):  # skip if the layer name is already numbered.
-            while True:  # break out by exception.
-                op_name = fstr.format("%s_%d" % (layer_name, idx))
-                ops[idx:] = [graph.get_operation_by_name(op_name)]
-                idx += 1
-    except KeyError:
-        pass
-
-    return ops
-
+    if re.match(r".+_\d+$", layer_name):  # skip layer names that are already numbered.
+        return ops

-def _advance_output_node_if_successor(graph, layer, output):
-    for op_ in graph.get_operations():
-        if op_.name.find(layer) == 0:
-            if output in [ts_.op.name for ts_ in op_.inputs]:
-                return op_.name
+    idx = 0
+    while True:
+        op_name = fstr.format("%s_%d" % (layer_name, idx + 1))
+        if op_name in ops_table:
+            ops.append(ops_table[op_name])
+        else:
+            break
+        idx += 1

-    return output
+    return ops


 def build_layer_outputs(model, graph, outputs):
@@ -211,23 +203,39 @@ def build_layer_outputs(model, graph, outputs):
     output_dict = {}
     layer_dict = _layer_name_dict(tf_utils, model, model.name)

+    ops_table = {op_.name: op_ for op_ in graph.get_operations()}
+
+    def add_output_node(graph, op, fx_list, layer_name):
+        output_node_name = op.name
+        if len(fx_list) > 1:
+            # fx_list[1] is the output-node redirect function.
+            output_node_name = fx_list[1](lobj, op)
+        assert graph.get_operation_by_name(output_node_name) is not None, "Parsing layer({}) failed.".format(lobj)
+        if output_node_name not in output_dict:  # do not overwrite an entry from an earlier layer of the same kind.
+            output_dict[output_node_name] = layer_dict[layer_name]
+
+    for ln_, layer_info_ in layer_dict.items():
+        lobj = layer_info_[0]
+        fstr_list, fx_list = keras_layer_spec(type(lobj))
+        if fstr_list is None:
+            continue
+
+        for fstr_ in fstr_list:
+            op_name = fstr_.format(ln_)
+            if op_name not in ops_table:
+                continue
+            add_output_node(graph, ops_table[op_name], fx_list, ln_)
+
+    # Now handle the case where a layer is re-used several times in one model.
     for ln_, layer_info_ in layer_dict.items():
         lobj = layer_info_[0]
         fstr_list, fx_list = keras_layer_spec(type(lobj))
         if fstr_list is None:
             continue

         for fstr_ in fstr_list:
-            for op_ in _to_tf_ops(ln_, graph, fstr_):
-                if len(fx_list) <= 1:
-                    output_dict[op_.name] = layer_dict[ln_]
-                else:
-                    # fx_[1] is output node redirect function.
-                    output_node = fx_list[1](lobj, op_)
-                    output_node = _advance_output_node_if_successor(graph, ln_, output_node)
-                    assert graph.get_operation_by_name(output_node) is not None, "Parsing layer({}) failed.".format(
-                        lobj)
-                    output_dict[output_node] = layer_dict[ln_]
+            for op_ in _to_tf_ops(ln_, fstr_, ops_table):
+                add_output_node(graph, op_, fx_list, ln_)

     return output_dict
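When a layer object is applied more than once, TensorFlow uniquifies the generated op names with numeric suffixes (dense, dense_1, dense_2, ...). The rewritten _to_tf_ops probes those suffixes against a prebuilt name-to-op table instead of catching KeyError from repeated graph.get_operation_by_name calls. A standalone sketch of that lookup, with a plain dict standing in for ops_table and purely illustrative op names:

    ops_table = {
        'model/dense/BiasAdd': 'op-a',    # first application of the layer
        'model/dense_1/BiasAdd': 'op-b',  # second application
        'model/dense_2/BiasAdd': 'op-c',  # third application
    }

    def collect_reused_ops(layer_name, fstr, ops_table):
        # Probe layer_name_1, layer_name_2, ... until a name is missing.
        ops, idx = [], 0
        while True:
            op_name = fstr.format("%s_%d" % (layer_name, idx + 1))
            if op_name not in ops_table:
                break
            ops.append(ops_table[op_name])
            idx += 1
        return ops

    print(collect_reused_ops('model/dense', '{}/BiasAdd', ops_table))
    # -> ['op-b', 'op-c']; the un-suffixed op is handled by the first loop
    # in build_layer_outputs above.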

keras2onnx/ke2onnx/layer_spec.py (+21 -10)

@@ -27,13 +27,7 @@ def _simple_layer_name_extractor(fstr_list, node_name):


 def _conv_layer_spec_outputs(layer, node):
-    if type(layer) == _layer.DepthwiseConv2D:
-        if not layer.use_bias:
-            return node.name
-        else:
-            ri = node.name.rindex('/')
-            return node.name[:ri + 1] + 'BiasAdd'
-    elif type(layer) == _layer.Conv1D:
+    if type(layer) == _layer.Conv1D:
         return node.name + '/Squeeze'

     activation_map = {
@@ -42,15 +36,26 @@ def _conv_layer_spec_outputs(layer, node):
         tf.nn.softmax: 'Softmax',
         tf.nn.relu: 'Relu',
         tf.nn.elu: 'Elu',
-        tf.nn.tanh: 'Tanh'}
+        tf.nn.tanh: 'Tanh',
+        tf.nn.swish: 'mul'}

     node_act = activation_map.get(layer.activation, None)
+    if node_act is None:
+        actname_map = {a_.__name__: a_ for a_ in activation_map}
+        act_trans = actname_map.get(layer.activation.__name__, None)
+        if act_trans is not None:
+            node_act = activation_map.get(act_trans)
+
     assert node_act is not None, "Unsupported activation in the layer({})".format(layer.activation)
     if node_act:
         ri = node.name.rindex('/')
         return node.name[:ri + 1] + node_act
     else:
-        return node.name
+        if not layer.use_bias:
+            return node.name
+        else:
+            ri = node.name.rindex('/')
+            return node.name[:ri + 1] + 'BiasAdd'


 def _relu_like_spec_outputs(layer, node):
@@ -69,11 +74,17 @@ def _relu_like_spec_outputs(layer, node):
     _layer.MaxPooling1D: (["{}/MaxPool"], [_default_layer_name_extractor]),
     _layer.MaxPooling2D: (["{}/MaxPool"], [_default_layer_name_extractor]),
     _layer.MaxPooling3D: (["{}/MaxPool"], [_default_layer_name_extractor]),
+
     _layer.Conv1D: (["{}/conv1d"], [_simple_layer_name_extractor, _conv_layer_spec_outputs]),
+    _layer.Conv2D: (["{}/Conv2D"], [_simple_layer_name_extractor, _conv_layer_spec_outputs]),
+
     _layer.Conv2DTranspose: (["{}/conv2d_transpose"], [_simple_layer_name_extractor, _conv_layer_spec_outputs]),
     _layer.DepthwiseConv2D: (["{}/depthwise"], [_simple_layer_name_extractor, _conv_layer_spec_outputs]),
+
     _layer.LeakyReLU: (["{}/LeakyRelu"], [_default_layer_name_extractor]),
-    _adv_activations.PReLU: (["{}/Relu"], [_simple_layer_name_extractor, _relu_like_spec_outputs])
+    _adv_activations.PReLU: (["{}/Relu"], [_simple_layer_name_extractor, _relu_like_spec_outputs]),
+
+    _layer.Reshape: (["{}/Reshape"], [_default_layer_name_extractor])
 }

 if not is_keras_older_than('2.2.0'):
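The new by-name fallback matters because keras and tf.keras can expose distinct function objects for the same activation, so an identity lookup in activation_map can miss. Matching on __name__ recovers those cases. A self-contained sketch of the same idea (not this module's exact code path; the activation subset is trimmed for brevity):

    import tensorflow as tf

    activation_map = {tf.nn.relu: 'Relu', tf.nn.tanh: 'Tanh'}

    def resolve_act_node(activation):
        node_act = activation_map.get(activation, None)
        if node_act is None:
            # Fall back to matching by function name rather than identity.
            actname_map = {a_.__name__: a_ for a_ in activation_map}
            act_trans = actname_map.get(activation.__name__, None)
            if act_trans is not None:
                node_act = activation_map.get(act_trans)
        return node_act

    def relu(x):  # a different object that happens to share the name 'relu'
        return tf.maximum(x, 0)

    print(resolve_act_node(tf.nn.relu))  # 'Relu' -- direct identity hit
    print(resolve_act_node(relu))        # 'Relu' -- matched via __name__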

keras2onnx/topology.py (+32 -31)

@@ -299,37 +299,38 @@ def convert_topology(topology, model_name, doc_string, target_opset, channel_fir
     # enable the ONNX optimizations
     graph = None
     nodes = container.nodes
-    try:
-        import onnxconverter_common
-        origin_node_number = len(container.nodes)
-        if target_opset < 9:
-            nodes = onnxconverter_common.optimizer.optimize_onnx(nodes, nchw_inputs=nchw_inputs,
-                                                                 inputs=container.inputs + extra_inputs,
-                                                                 outputs=container.outputs)
-            node_number = len(nodes)
-        else:
-            graph = onnxconverter_common.optimizer.optimize_onnx_graph(nodes, nchw_inputs=nchw_inputs,
-                                                                       inputs=container.inputs,
-                                                                       outputs=container.outputs,
-                                                                       initializers=container.initializers,
-                                                                       model_value_info=container.value_info,
-                                                                       model_name=model_name,
-                                                                       target_opset=container.target_opset)
-            node_number = len(graph.node)
-        k2o_logger().info("The node number after optimization: {} -> {}".format(origin_node_number, node_number))
-    except ImportError:
-        onnx_not_imported = 'onnxconverter_common is not imported,'
-        if nchw_inputs:
-            raise Exception(
-                '{} nchw_inputs does not make effect. Please set nchw_inputs to empty.'.format(onnx_not_imported))
-        k2o_logger().warning('{} so the convertor optimizer is not enabled.'.format(onnx_not_imported))
-    except Exception as e:
-        # either optimizer issue or converter issue, we just let it go to diagnose the issue from the converted model.
-        k2o_logger().warning('There is an error({}) happened during optimizing on the converted model!'.format(type(e)))
-        k2o_logger().warning(str(e))
-        import traceback
-        tb = traceback.format_exc()
-        k2o_logger().warning(tb)
+    if not topology.debug_mode:
+        try:
+            import onnxconverter_common
+            origin_node_number = len(container.nodes)
+            if target_opset < 9:
+                nodes = onnxconverter_common.optimizer.optimize_onnx(nodes, nchw_inputs=nchw_inputs,
+                                                                     inputs=container.inputs + extra_inputs,
+                                                                     outputs=container.outputs)
+                node_number = len(nodes)
+            else:
+                graph = onnxconverter_common.optimizer.optimize_onnx_graph(nodes, nchw_inputs=nchw_inputs,
+                                                                           inputs=container.inputs,
+                                                                           outputs=container.outputs,
+                                                                           initializers=container.initializers,
+                                                                           model_value_info=container.value_info,
+                                                                           model_name=model_name,
+                                                                           target_opset=container.target_opset)
+                node_number = len(graph.node)
+            k2o_logger().info("The node number after optimization: {} -> {}".format(origin_node_number, node_number))
+        except ImportError:
+            onnx_not_imported = 'onnxconverter_common is not imported,'
+            if nchw_inputs:
+                raise Exception(
+                    '{} nchw_inputs does not make effect. Please set nchw_inputs to empty.'.format(onnx_not_imported))
+            k2o_logger().warning('{} so the convertor optimizer is not enabled.'.format(onnx_not_imported))
+        except Exception as e:  # noqa
+            # either optimizer issue or converter issue, we just let it go to diagnose the issue from the converted model.
+            k2o_logger().warning('There is an error({}) happened during optimizing on the converted model!'.format(type(e)))
+            k2o_logger().warning(str(e))
+            import traceback
+            tb = traceback.format_exc()
+            k2o_logger().warning(tb)

     if graph is None:
         # Create a graph from its main components
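The optimizer pass is now skipped entirely when the topology was built in debug mode, so a failing model can be inspected exactly as the per-layer converters emitted it. Assuming the flag is threaded through from the public API of this era (hedged: parameter name taken from keras2onnx's convert_keras signature as I recall it, not from this diff):

    import keras2onnx

    # debug_mode=True leaves the converted graph un-optimized for inspection.
    onnx_model = keras2onnx.convert_keras(model, model.name, debug_mode=True)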

tests/data/panda.jpg (binary image, 181 KB)

tests/test_cgan.py (+1 -3)

@@ -8,12 +8,10 @@
 import keras2onnx
 import numpy as np
 from keras2onnx.proto import keras, is_tf_keras
-from test_utils import run_onnx_runtime
 from distutils.version import StrictVersion

 Activation = keras.layers.Activation
 BatchNormalization = keras.layers.BatchNormalization
-Conv2D = keras.layers.Conv2D
 Dense = keras.layers.Dense
 Dropout = keras.layers.Dropout
 Embedding = keras.layers.Embedding
@@ -121,7 +119,7 @@ def build_discriminator(self):

 @pytest.mark.skipif(keras2onnx.proto.tfcompat.is_tf2 and is_tf_keras, reason="Tensorflow 1.x only tests.")
 @pytest.mark.skipif(is_tf_keras and StrictVersion(tf.__version__.split('-')[0]) < StrictVersion("1.14.0"),
-                   reason="Not supported before tensorflow 1.14.0 for tf_keras")
+                    reason="Not supported before tensorflow 1.14.0 for tf_keras")
 def test_CGAN(runner):
     keras_model = CGAN().combined
     batch = 5

tests/test_layers.py (+15 -15)

@@ -72,7 +72,7 @@
 RNN_CLASSES = [SimpleRNN, GRU, LSTM]


-def asarray(*a):
+def _asarray(*a):
     return np.array([a], dtype='f')


@@ -758,7 +758,7 @@ def test_dense(runner):
     model.compile('sgd', 'mse')
     onnx_model = keras2onnx.convert_keras(model, model.name)

-    data = asarray(1, 0, 0, 1)
+    data = _asarray(1, 0, 0, 1)
     expected = model.predict(data)
     assert runner('dense', onnx_model, data, expected)

@@ -775,7 +775,7 @@ def test_dense_add(runner):
     model.compile('sgd', 'mse')
     onnx_model = keras2onnx.convert_keras(model, model.name)

-    data = [asarray(1.2, 2.4, -2, 1), asarray(-1, -2, 0, 1, 2), asarray(0.5, 1.5, -3.14159)]
+    data = [_asarray(1.2, 2.4, -2, 1), _asarray(-1, -2, 0, 1, 2), _asarray(0.5, 1.5, -3.14159)]
     expected = model.predict(data)
     assert runner('onnx_dense_add', onnx_model, data, expected)

@@ -796,7 +796,7 @@ def test_conv_add(runner):


 def test_dense_softmax(runner):
-    data = asarray(1, 2, 3, 4)
+    data = _asarray(1, 2, 3, 4)
     model = Sequential()
     model.add(Dense(5, input_shape=(4,), activation='softmax'))
     model.add(Dense(3, input_shape=(5,), use_bias=True))
@@ -831,7 +831,7 @@ def test_dense_softmax(runner):
     (lambda: Concatenate(2), ([[1, 2], [3, 4]], [[4, 5], [6, 7]])),
 ])
 def test_merge_layer(runner, layer_type, data):
-    data2 = [asarray(*d) for d in data]
+    data2 = [_asarray(*d) for d in data]
     inputs = [Input(shape=d.shape[1:]) for d in data2]
     layer = layer_type()(inputs)
     model = keras.models.Model(inputs=inputs, outputs=layer)
@@ -1043,7 +1043,7 @@ def test_repeat_vector(runner):
     model.add(keras.layers.core.RepeatVector(3, input_shape=(4,)))
     onnx_model = keras2onnx.convert_keras(model, model.name)

-    data = asarray(1, 2, 3, 4)
+    data = _asarray(1, 2, 3, 4)

     expected = model.predict(data)
     assert runner('repeat_vector', onnx_model, data, expected)
@@ -1134,7 +1134,7 @@ def test_pooling_global(pooling_runner):
     keras.activations.linear,
 ])
 def test_activation_layer(runner, layer):
-    data = asarray(-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5)
+    data = _asarray(-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5)
     layer = Activation(layer, input_shape=(data.size,))

     model = keras.Sequential()
@@ -1177,13 +1177,13 @@ def test_selu(runner):


 def test_LeakyReLU(advanced_activation_runner):
-    data = asarray(-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5)
+    data = _asarray(-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5)
     layer = advanced_activations.LeakyReLU(alpha=0.1, input_shape=(data.size,))
     advanced_activation_runner(layer, data)


 def test_ThresholdedReLU(advanced_activation_runner):
-    data = asarray(-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5)
+    data = _asarray(-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5)
     layer = advanced_activations.ThresholdedReLU(theta=1.0, input_shape=(data.size,))
     advanced_activation_runner(layer, data, op_version=8)

@@ -1192,13 +1192,13 @@ def test_ThresholdedReLU(advanced_activation_runner):


 def test_ELU(advanced_activation_runner):
-    data = asarray(-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5)
+    data = _asarray(-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5)
     layer = advanced_activations.ELU(alpha=1.0, input_shape=(data.size,))
     advanced_activation_runner(layer, data)


 def test_PReLU(advanced_activation_runner):
-    data = asarray(-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5)
+    data = _asarray(-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5)
     layer = advanced_activations.PReLU(alpha_initializer='zeros', input_shape=(data.size,))
     advanced_activation_runner(layer, data)
     layer = advanced_activations.PReLU(alpha_initializer='ones', input_shape=(data.size,))
@@ -1208,7 +1208,7 @@ def test_PReLU(advanced_activation_runner):


 def test_Softmax(advanced_activation_runner):
-    data = asarray(-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5)
+    data = _asarray(-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5)
     layer = advanced_activations.Softmax(axis=-1, input_shape=(data.size,))
     advanced_activation_runner(layer, data)

@@ -1327,8 +1327,8 @@ def func(l2Normalize, input1, input2):


 def test_dot(dot_runner):
-    dot_runner(False, asarray(1, 2, 3), asarray(4, 5, 6))
-    dot_runner(True, asarray(1, 2, 3), asarray(4, 5, 6))
+    dot_runner(False, _asarray(1, 2, 3), _asarray(4, 5, 6))
+    dot_runner(True, _asarray(1, 2, 3), _asarray(4, 5, 6))


 def test_dot2(runner):
@@ -1397,7 +1397,7 @@ def func(data, gamma, beta, scale, center, axis):


 def test_batch_normalization(batch_norm_runner):
-    data = asarray([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
+    data = _asarray([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
     batch_norm_runner(data, 'ones', 'zeros', True, True, 3)
     batch_norm_runner(data, 'ones', 'ones', True, True, 3)
     # The CPU implementation of FusedBatchNorm only supports NHWC tensor format in tf keras
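The rename is mechanical: the underscore prefix marks the helper as private to the test module and keeps it from being confused with numpy's own asarray. Its behavior, taken straight from the diff above, is to wrap scalars into a single-row float32 batch:

    import numpy as np

    def _asarray(*a):
        # wrap the scalar arguments into a single-row float32 array
        return np.array([a], dtype='f')

    print(_asarray(1, 2, 3).shape)  # (1, 3)
    print(_asarray(1, 2, 3).dtype)  # float32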
