Commit d4e16b5

Author: Yu Cong (committed)

Add masked LSTM support

Signed-off-by: Yu Cong <[email protected]>

1 parent bc677a1 | commit d4e16b5

File tree: 3 files changed, +260 -25 lines changed

Diff for: tests/test_lstm.py

+159
@@ -793,5 +793,164 @@ def func(x):
             return tf.identity(y[0], name="output")
         self.run_test_case(func, {"input:0": x_val}, [], ["output:0"], rtol=1e-05, atol=1e-06)
 
+    @check_tf_min_version("2.0")
+    @skip_tf_versions("2.1", "Bug in TF 2.1")
+    def test_keras_masked_lstm_embedding_unidirectional(self):
+        for go_backwards in [True, False]:
+            timesteps = 4
+            # Note: masked LSTM only supports post-padded input after conversion
+            # test case sequence_lens = [4, 2, 0]
+            x_val = np.array([
+                [1, 2, 3, 4],
+                [5, 6, 0, 0],
+                [0, 0, 0, 0]
+            ], dtype=np.int32)
+
+            model_in = tf.keras.layers.Input((timesteps,), dtype="int32")
+            x_embedding = tf.keras.layers.Embedding(
+                input_dim=10,
+                output_dim=5,
+                mask_zero=True,
+                embeddings_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=41),
+            )(model_in)
+
+            # RNN layer inherits the mask propagated from the embedding layer above
+            model_out = tf.keras.layers.LSTM(
+                units=5,
+                go_backwards=go_backwards,
+                return_state=True,
+                kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
+                bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43),
+                recurrent_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
+            )(x_embedding)
+            model = tf.keras.models.Model(inputs=model_in, outputs=model_out)
+
+            def func(x):
+                y = model(x)
+                # skipping output Y: https://github.com/microsoft/onnxruntime/issues/12492
+                return (tf.identity(y[1], name="output_yh"),
+                        tf.identity(y[2], name="output_yc"))
+
+            output_list = ["output_yh:0", "output_yc:0"]
+            self.run_test_case(func, {"input:0": x_val}, [], output_list, rtol=1e-05, atol=1e-06)
+
+    @check_tf_min_version("2.0")
+    @skip_tf_versions("2.1", "Bug in TF 2.1")
+    def test_keras_masked_lstm_embedding_bidirectional(self):
+        timesteps = 4
+        # Note: masked LSTM only supports post-padded input after conversion
+        # test case sequence_lens = [4, 2, 0]
+        x_val = np.array([
+            [1, 2, 3, 4],
+            [5, 6, 0, 0],
+            [0, 0, 0, 0]
+        ], dtype=np.int32)
+
+        model_in = tf.keras.layers.Input((timesteps,), dtype="int32")
+        x_embedding = tf.keras.layers.Embedding(
+            input_dim=10,
+            output_dim=5,
+            mask_zero=True,
+            embeddings_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=41),
+        )(model_in)
+
+        # RNN layer inherits the mask propagated from the embedding layer above
+        lstm_layer = tf.keras.layers.LSTM(
+            units=5,
+            go_backwards=False,
+            return_state=True,
+            kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
+            bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43),
+            recurrent_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
+        )
+        model_out = tf.keras.layers.Bidirectional(lstm_layer)(x_embedding)
+        model = tf.keras.models.Model(inputs=model_in, outputs=model_out)
+
+        def func(x):
+            y = model(x)
+            # skipping output Y: https://github.com/microsoft/onnxruntime/issues/12492
+            return (tf.identity(y[1], name="output_yh_f"),
+                    tf.identity(y[2], name="output_yc_f"),
+                    tf.identity(y[3], name="output_yh_r"),
+                    tf.identity(y[4], name="output_yc_r"))
+
+        output_list = ["output_yh_f:0", "output_yc_f:0", "output_yh_r:0", "output_yc_r:0"]
+        self.run_test_case(func, {"input:0": x_val}, [], output_list, rtol=1e-05, atol=1e-06,
+                           require_lstm_count=2)
+
+    @check_tf_min_version("2.0")
+    @skip_tf_versions("2.1", "Bug in TF 2.1")
+    def test_keras_masked_lstm_unidirectional(self):
+        for go_backwards in [True, False]:
+            batch_size, timesteps, feat = 3, 4, 5
+            in_shape = (timesteps, feat)
+            x_val = np.random.uniform(size=[batch_size, timesteps, feat]).astype(np.float32)
+            # Note: masked LSTM only supports post-padded input after conversion
+            # test case sequence_lens = [4, 2, 0]
+            x_val[1, 2:, :] = 0.
+            x_val[2, :, :] = 0.
+
+            model_in = tf.keras.layers.Input(shape=in_shape, dtype="float32")
+            x_masked = tf.keras.layers.Masking(mask_value=0.)(model_in)
+
+            # RNN layer inherits the mask propagated from the masking layer above
+            model_out = tf.keras.layers.LSTM(
+                units=5,
+                go_backwards=go_backwards,
+                return_state=True,
+                kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
+                bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43),
+                recurrent_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
+            )(x_masked)
+            model = tf.keras.models.Model(inputs=model_in, outputs=model_out)
+
+            def func(x):
+                y = model(x)
+                # skipping output Y: https://github.com/microsoft/onnxruntime/issues/12492
+                return (tf.identity(y[1], name="output_yh"),
+                        tf.identity(y[2], name="output_yc"))
+
+            output_list = ["output_yh:0", "output_yc:0"]
+            self.run_test_case(func, {"input:0": x_val}, [], output_list, rtol=1e-05, atol=1e-06)
+
+    @check_tf_min_version("2.0")
+    @skip_tf_versions("2.1", "Bug in TF 2.1")
+    def test_keras_masked_lstm_bidirectional(self):
+        batch_size, timesteps, feat = 3, 4, 5
+        in_shape = (timesteps, feat)
+        x_val = np.random.uniform(size=[batch_size, timesteps, feat]).astype(np.float32)
+        # Note: masked LSTM only supports post-padded input after conversion
+        # test case sequence_lens = [4, 2, 0]
+        x_val[1, 2:, :] = 0.
+        x_val[2, :, :] = 0.
+
+        model_in = tf.keras.layers.Input(shape=in_shape, dtype="float32")
+        x_masked = tf.keras.layers.Masking(mask_value=0.)(model_in)
+
+        # RNN layer inherits the mask propagated from the masking layer above
+        lstm_layer = tf.keras.layers.LSTM(
+            units=5,
+            go_backwards=False,
+            return_state=True,
+            kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
+            bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43),
+            recurrent_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
+        )
+        model_out = tf.keras.layers.Bidirectional(lstm_layer)(x_masked)
+        model = tf.keras.models.Model(inputs=model_in, outputs=model_out)
+
+        def func(x):
+            y = model(x)
+            # skipping output Y: https://github.com/microsoft/onnxruntime/issues/12492
+            return (tf.identity(y[1], name="output_yh_f"),
+                    tf.identity(y[2], name="output_yc_f"),
+                    tf.identity(y[3], name="output_yh_r"),
+                    tf.identity(y[4], name="output_yc_r"))
+
+        output_list = ["output_yh_f:0", "output_yc_f:0", "output_yh_r:0", "output_yc_r:0"]
+        self.run_test_case(func, {"input:0": x_val}, [], output_list, rtol=1e-05, atol=1e-06,
+                           require_lstm_count=2)
+
+
 if __name__ == '__main__':
     unittest_main()
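For reference, a minimal end-to-end sketch of how one of the masked Keras LSTM models above could be exported and checked against onnxruntime. This is not part of the commit; the opset choice, the "input" tensor name (taken from the TensorSpec), and the output ordering are assumptions for illustration only.

```python
# Hypothetical usage sketch (not part of this commit): export a masked Keras LSTM
# and compare the final h/c states with onnxruntime. Assumes tf2onnx (with this
# change) and onnxruntime are installed; the "input" name is set by the TensorSpec.
import numpy as np
import tensorflow as tf
import tf2onnx
import onnxruntime as ort

timesteps = 4
model_in = tf.keras.layers.Input((timesteps,), dtype="int32")
x = tf.keras.layers.Embedding(input_dim=10, output_dim=5, mask_zero=True)(model_in)
outputs = tf.keras.layers.LSTM(units=5, return_state=True)(x)
model = tf.keras.models.Model(inputs=model_in, outputs=outputs)

# Post-padded batch, i.e. sequence_lens = [4, 2, 0] as in the tests above
x_val = np.array([[1, 2, 3, 4], [5, 6, 0, 0], [0, 0, 0, 0]], dtype=np.int32)

spec = (tf.TensorSpec((None, timesteps), tf.int32, name="input"),)
model_proto, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=13)

sess = ort.InferenceSession(model_proto.SerializeToString())
onnx_y, onnx_yh, onnx_yc = sess.run(None, {"input": x_val})
keras_y, keras_yh, keras_yc = model(x_val)
# Y itself is skipped in the tests (see the onnxruntime issue linked above);
# only the final hidden and cell states are compared here.
np.testing.assert_allclose(keras_yh.numpy(), onnx_yh, rtol=1e-5, atol=1e-6)
np.testing.assert_allclose(keras_yc.numpy(), onnx_yc, rtol=1e-5, atol=1e-6)
```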

Diff for: tf2onnx/onnx_opset/tensor.py

+18-9
@@ -2260,15 +2260,24 @@ def version_10(cls, ctx, node, **kwargs):
             const_axis_name = utils.make_name(f'const_{axis}')
             const_axis = ctx.make_const(name=const_axis_name, np_val=np.array([axis], dtype=np.int64))
 
-            # Add a Constant node (seq_len) for ReverseSequence.
-            # Index 1 for the shape should not return 0, since rank(input) >=2
-            input_shape = ctx.make_node("Shape", [inputs[-1]], op_name_scope=rv2_node_name)
-            batch_size = ctx.make_node("Gather", [input_shape.output[0], const_one.output[0]],
-                                       op_name_scope=rv2_node_name)
-            axis_dim = ctx.make_node("Gather", [input_shape_node.output[0], const_axis.output[0]],
-                                     op_name_scope=rv2_node_name)
-            seq_array = ctx.make_node("Expand", [axis_dim.output[0], batch_size.output[0]])
-            inputs.append(seq_array.output[0])
+            # Add sequence_lens as ReverseSequence input
+            has_sequence_lens = node.get_attr_value("has_sequence_lens", False)
+            if not has_sequence_lens:
+                # Add a Constant node (seq_len) for ReverseSequence.
+                # Index 1 for the shape should not return 0, since rank(input) >=2
+                input_shape = ctx.make_node("Shape", [inputs[-1]], op_name_scope=rv2_node_name)
+                batch_size = ctx.make_node("Gather", [input_shape.output[0], const_one.output[0]],
+                                           op_name_scope=rv2_node_name)
+                axis_dim = ctx.make_node("Gather", [input_shape_node.output[0], const_axis.output[0]],
+                                         op_name_scope=rv2_node_name)
+                seq_array = ctx.make_node("Expand", [axis_dim.output[0], batch_size.output[0]])
+                inputs.append(seq_array.output[0])
+            else:
+                # masked backward LSTM:
+                # sequence_lens is appended to ReverseV2's input by lstm_tf2_rewriter
+                # to keep tensor post-padded after reverse
+                seq_lens_casted = ctx.make_node("Cast", [node.input[-1]], attr={'to': TensorProto.INT64}).output[0]
+                inputs.append(seq_lens_casted)
 
             # Add a ReverseSequence node.
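The reason sequence_lens matters here: ONNX ReverseSequence reverses only the first sequence_lens[b] timesteps of each batch entry, so a post-padded tensor stays post-padded, whereas TF's ReverseV2 flips the whole time axis. A rough numpy sketch of that semantic (illustrative only, assuming the time-major layout with time_axis=0 and batch_axis=1; this is not converter code):

```python
import numpy as np

def reverse_sequence(x, seq_lens):
    # x: (timesteps, batch, feat); reverse only the valid prefix of each batch entry
    y = x.copy()
    for b, n in enumerate(seq_lens):
        y[:n, b, :] = x[:n, b, :][::-1]
    return y

x = np.zeros((4, 2, 1), dtype=np.float32)
x[:, 0, 0] = [1, 2, 3, 4]   # full-length sequence
x[:2, 1, 0] = [5, 6]        # post-padded sequence, valid length 2
y = reverse_sequence(x, [4, 2])
print(y[:, 0, 0])  # [4. 3. 2. 1.]
print(y[:, 1, 0])  # [6. 5. 0. 0.]  -- padding stays at the tail; a plain ReverseV2 would give [0. 0. 6. 5.]
```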

Diff for: tf2onnx/rewriter/lstm_tf2_rewriter.py

+83-16
@@ -4,8 +4,10 @@
 """
 tf2onnx.rewriter.lstm_tf2_rewriter - Rewrites LSTM pattern used by tf2.
 """
-
+import logging
 import numpy as np
+from onnx import onnx_pb
+
 from tf2onnx.graph_matcher import GraphMatcher
 from tf2onnx.rewriter.rnn_utils import make_lstm_pattern
 from tf2onnx.tf_loader import find_function
@@ -79,21 +81,35 @@ def rewriter_lstm_tf2(g, ops):
         # extract output h_t
         ht_mul = match_result.get_op("ht")
         final_consumers = g.find_output_consumers(ht_mul.output[0])
-        select_ops = [n for n in final_consumers if n.type == "Select"]
+        select_ops = [n for n in final_consumers if n.type == "Select" or n.type == "SelectV2"]
         def has_tensor_list_consumer(n):
             return any(c.type == "TensorListSetItem" for c in g.find_output_consumers(n.output[0]))
         select_ops = [n for n in select_ops if has_tensor_list_consumer(n)]
+
+        # extract sequence length
+        seq_len_idx, mask_idx = None, None
         if len(select_ops) == 1:
-            greater_eq = select_ops[0].inputs[0]
-            if greater_eq.type != "GreaterEqual":
-                continue
-            seq_len = greater_eq.inputs[1]
-            if not seq_len.is_graph_input():
+            select_op_condition = select_ops[0].inputs[0]
+            while select_op_condition.type == "Identity":
+                select_op_condition = select_op_condition.inputs[0]
+
+            # skip timesteps based on a specific sequence length
+            if select_op_condition.type == "GreaterEqual":
+                seq_len = select_op_condition.inputs[1]
+                if not seq_len.is_graph_input():
+                    continue
+                seq_len_idx = g.input_names.index(seq_len.output[0])
+
+            # masked LSTM: skip timesteps based on a dynamically-computed boolean mask tensor
+            elif select_op_condition.type == "TensorListGetItem":
+                mask = select_op_condition.inputs[0]
+                if not mask.is_graph_input():
+                    continue
+                mask_idx = g.input_names.index(mask.output[0])
+            else:
                 continue
-            seq_len_idx = g.input_names.index(seq_len.output[0])
+
             final_consumers = g.find_output_consumers(select_ops[0].output[0])
-        else:
-            seq_len_idx = None
 
         tensor_set_items = [n for n in final_consumers if n.type == "TensorListSetItem"]
         if len(tensor_set_items) != 1:
@@ -209,6 +225,7 @@ def has_tensor_list_consumer(n):
             # Keras
             "w_idx": gk_idx,
             "r_idx": hk_idx,
+            "mask_idx": mask_idx,
         }
 
     for op in ops:
@@ -276,15 +293,63 @@ def has_tensor_list_consumer(n):
         tensor_array_inp = op.inputs[body_context["x_idx"]]
         if not tensor_array_inp.type == "TensorListFromTensor":
             continue
+        context.onnx_input_ids[0]["X"] = tensor_array_inp.input[0]
 
-        final_consumers = g.find_output_consumers(op.output[body_context["out_idx"]])
-        output_ys = [n.output[0] for n in final_consumers if n.type == "TensorListStack"]
+        # parse sequence length
+        seq_len_idx = body_context["seq_len_idx"]
+        mask_idx = body_context["mask_idx"]
+        if seq_len_idx:
+            context.onnx_input_ids[0]["sequence_lens"] = op.input[seq_len_idx]
+        elif mask_idx:
+            logging.warning(
+                "Found mask-enabled LSTM. Converted ONNX model will only support post-padded LSTM input. "
+                "If input is pre- or randomly-padded, masked timesteps will not be correctly skipped.")
+
+            # parse the boolean mask tensor
+            tensor_array_mask = op.inputs[body_context["mask_idx"]]
+            if not tensor_array_mask.type == "TensorListFromTensor":
+                continue
+            mask_mat = tensor_array_mask.input[0]
+            mask_mat_node = g.get_node_by_output(mask_mat)
+            is_mask_reverse = mask_mat_node.type == "ReverseV2"
+            # no need to reverse the mask sequence:
+            # the positions of skipped timesteps per batch are irrelevant assuming post-padded input
+            if is_mask_reverse:
+                mask_mat = mask_mat_node.input[0]
+
+            # reduce mask tensor to sequence_lens assuming post-padded input:
+            # transpose (1, 0, 2)  -> boolean mask tensor (N, timesteps, 1)
+            # squeeze on dim(-1)   -> boolean mask matrix (N, timesteps)
+            # reduceSum on dim(-1) -> sequence_lens (N)
+            mask_transpose_node = g.make_node(op_type="Transpose", inputs=[mask_mat], attr={"perm": [1, 0, 2]})
+            mask_squeeze = GraphBuilder(g).make_squeeze({"data": mask_transpose_node.output[0], "axes": [-1]})
+            mask_cast_node = g.make_node(op_type="Cast", inputs=[mask_squeeze],
+                                         attr={"to": onnx_pb.TensorProto.INT32})
+            sequence_lens = GraphBuilder(g).make_reduce_sum({"data": mask_cast_node.output[0],
+                                                             "axes": [-1], "keepdims": 0})
+            context.onnx_input_ids[0]["sequence_lens"] = sequence_lens
+
+            # handle backward LSTM
+            tensor_array_inp_producer = tensor_array_inp.inputs[0]
+            is_input_reverse = tensor_array_inp_producer.type == "ReverseV2"
+            # backward LSTM is identified by the reverses of both the input and the mask tensor pre-LSTM
+            if is_mask_reverse != is_input_reverse:
+                continue
+            if is_input_reverse:
+                # TF uses a plain "ReverseV2" to reverse the input tensor with no assumption on padding
+                # position, because the reversed mask with shape (batch_size, timesteps) is explicit
+                # per timestep. ONNX needs a "ReverseSequence" to keep the reversed input tensor
+                # post-padded, because the mask is implied by sequence_lens. This requires passing
+                # sequence_lens to that "ReverseSequence" op.
+
+                # Note: tensor op conversions run after rewriters. Appending sequence_lens as a
+                # "ReverseV2" input signals the alternative behavior in the "ReverseV2" conversion
+                # in onnx_opset/tensor.py.
+                tensor_array_inp_producer.set_attr("has_sequence_lens", True)
+                inp_reverse_inputs = tensor_array_inp_producer.input
+                inp_reverse_inputs.append(sequence_lens)
 
-        context.onnx_input_ids[0]["X"] = tensor_array_inp.input[0]
-        if body_context["seq_len_idx"] is None:
-            context.onnx_input_ids[0]["sequence_lens"] = ""
         else:
-            context.onnx_input_ids[0]["sequence_lens"] = op.input[body_context["seq_len_idx"]]
+            context.onnx_input_ids[0]["sequence_lens"] = ""
+
         context.onnx_input_ids[0]["initial_c"] = initial_c
         context.onnx_input_ids[0]["initial_h"] = initial_h

@@ -295,6 +360,8 @@ def has_tensor_list_consumer(n):
         lstm_node = lstm_rewriter.create_rnn_node(context)[0]
 
         squeeze_output = GraphBuilder(g).make_squeeze({"data": lstm_node.output[0], "axes": [1]})
+        final_consumers = g.find_output_consumers(op.output[body_context["out_idx"]])
+        output_ys = [n.output[0] for n in final_consumers if n.type == "TensorListStack"]
         for output in output_ys:
             g.replace_all_inputs(output, squeeze_output)
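To make the mask handling above concrete: assuming a post-padded input, the rewriter collapses the time-major boolean mask into per-batch sequence lengths with the Transpose -> Squeeze -> Cast -> ReduceSum chain it builds in the graph. A small numpy sketch of the same reduction (illustrative only, not the graph-building code):

```python
import numpy as np

# Time-major boolean mask (timesteps, batch, 1), post-padded, matching the tests:
# batch entry 0 has 4 valid steps, entry 1 has 2, entry 2 has 0.
mask = np.array([
    [[True], [True],  [False]],
    [[True], [True],  [False]],
    [[True], [False], [False]],
    [[True], [False], [False]],
])

m = np.transpose(mask, (1, 0, 2))            # -> (batch, timesteps, 1)
m = np.squeeze(m, axis=-1)                   # -> (batch, timesteps)
seq_lens = m.astype(np.int32).sum(axis=-1)   # -> (batch,)
print(seq_lens)                              # [4 2 0]
```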
