-
Notifications
You must be signed in to change notification settings - Fork 438
/
Copy pathcustom_op_via_python.py
36 lines (30 loc) · 1.27 KB
/
custom_op_via_python.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# SPDX-License-Identifier: Apache-2.0
"""
A simple example of how to map a custom op in Python.
"""
import tensorflow as tf
import tf2onnx
from onnx import helper
# Custom ONNX domain that the rewritten Print node is placed in; it is also
# registered as an extra opset when the graph is converted.
_TENSORFLOW_DOMAIN = "ai.onnx.converters.tensorflow"


def print_handler(ctx, node, name, args):
    """Rewrite a tf.Print node in place as a single-input Identity node.

    Replace tf.Print() with Identity:
        T output = Print(T input, data, @list(type) U, @string message, @int first_n, @int summarize)
    becomes:
        T output = Identity(T input)

    Args:
        ctx: Not used by this handler.
        node: The node being converted. Its ``type``, ``domain`` and
            ``input`` attributes are modified in place.
        name: Not used by this handler.
        args: Not used by this handler. Presumably this is the extra-argument
            list registered alongside the handler (``[]`` in this file);
            TODO confirm against tf2onnx.

    Returns:
        The same ``node`` object, after modification.
    """
    node.type = "Identity"
    node.domain = _TENSORFLOW_DOMAIN
    # Identity takes exactly one input. Keep the tensor being passed through
    # and drop Print's extra "data" inputs.
    del node.input[1:]
    return node
# Build a small TF1 graph that contains a tf.Print op, convert it to ONNX with
# print_handler registered for "Print", and serialize the resulting model.
with tf.Session() as sess:
    x = tf.placeholder(tf.float32, [2, 3], name="input")
    x_ = tf.add(x, x)
    # This Print op is the node that print_handler rewrites during conversion.
    x_ = tf.Print(x_, [x_], "hello")
    _ = tf.identity(x_, name="output")
    onnx_graph = tf2onnx.tfonnx.process_tf_graph(
        sess.graph,
        custom_op_handlers={"Print": (print_handler, [])},
        # print_handler moves the node into _TENSORFLOW_DOMAIN, so that
        # domain is declared as an extra opset (version 1) on the model.
        extra_opset=[helper.make_opsetid(_TENSORFLOW_DOMAIN, 1)],
        input_names=["input:0"],
        output_names=["output:0"],
    )
    model_proto = onnx_graph.make_model("test")
    # NOTE(review): hard-coded POSIX path; it will not resolve on Windows.
    with open("/tmp/model.onnx", "wb") as f:
        f.write(model_proto.SerializeToString())