Skip to content

Commit a4327e1

Browse files
committed
Fixed Conv3DTranspose with strides for data format channels_first (fixes onnx#1714)
While shape calculations for the input correctly distinguished between channels_first and channels_last, shape calculations for the inputs of the final Slice and Pad nodes always assumed channels_last format. Signed-off-by: fthielke <[email protected]>
1 parent 4245d8d commit a4327e1

File tree

1 file changed

+21
-12
lines changed

1 file changed

+21
-12
lines changed

Diff for: tf2onnx/onnx_opset/nn.py

+21-12
Original file line numberDiff line numberDiff line change
@@ -504,14 +504,15 @@ def version_1(cls, ctx, node, **kwargs):
504504
use_strides_workaround = False
505505
input_shape = ctx.make_node("Cast", [node.input[0]], attr={'to': TensorProto.INT64})
506506
output_shape = ctx.make_node("Shape", [node.output[0]])
507+
sp_index_start = 1 if is_channels_last(node) else 2
507508
output_h = GraphBuilder(ctx).make_slice(
508-
{"data": output_shape.output[0], "ends": [2], "starts": [1], "axes": [0]})
509+
{"data": output_shape.output[0], "ends": [sp_index_start+1], "starts": [sp_index_start], "axes": [0]})
509510
output_w = GraphBuilder(ctx).make_slice(
510-
{"data": output_shape.output[0], "ends": [3], "starts": [2], "axes": [0]})
511+
{"data": output_shape.output[0], "ends": [sp_index_start+2], "starts": [sp_index_start+1], "axes": [0]})
511512
expect_h = GraphBuilder(ctx).make_slice(
512-
{"data": input_shape.output[0], "ends": [2], "starts": [1], "axes": [0]})
513+
{"data": input_shape.output[0], "ends": [sp_index_start+1], "starts": [sp_index_start], "axes": [0]})
513514
expect_w = GraphBuilder(ctx).make_slice(
514-
{"data": input_shape.output[0], "ends": [3], "starts": [2], "axes": [0]})
515+
{"data": input_shape.output[0], "ends": [sp_index_start+2], "starts": [sp_index_start+1], "axes": [0]})
515516
diff_h = ctx.make_node("Sub", [output_h, expect_h])
516517
diff_w = ctx.make_node("Sub", [output_w, expect_w])
517518
nonneg_diff_h = diff_h
@@ -528,10 +529,12 @@ def version_1(cls, ctx, node, **kwargs):
528529
end_h = ctx.make_node("Add", [start_h.output[0], expect_h])
529530
end_w = ctx.make_node("Add", [start_w.output[0], expect_w])
530531
if spatial == 3:
531-
output_d = GraphBuilder(ctx).make_slice(
532-
{"data": output_shape.output[0], "ends": [4], "starts": [3], "axes": [0]})
533-
expect_d = GraphBuilder(ctx).make_slice(
534-
{"data": input_shape.output[0], "ends": [4], "starts": [3], "axes": [0]})
532+
output_d = GraphBuilder(ctx).make_slice({
533+
"data": output_shape.output[0], "ends": [sp_index_start+3], "starts": [sp_index_start+2], "axes": [0]
534+
})
535+
expect_d = GraphBuilder(ctx).make_slice({
536+
"data": input_shape.output[0], "ends": [sp_index_start+3], "starts": [sp_index_start+2], "axes": [0]
537+
})
535538
diff_d = ctx.make_node("Sub", [output_d, expect_d])
536539
nonneg_diff_d = diff_d
537540
if use_strides_workaround:
@@ -543,12 +546,12 @@ def version_1(cls, ctx, node, **kwargs):
543546
attr={"axis": 0})
544547
ends = ctx.make_node("Concat", [end_h.output[0], end_w.output[0], end_d.output[0]], attr={"axis": 0})
545548
slice_axes = ctx.make_const(utils.make_name(node.name + "_const_slice_axes"),
546-
np.array([1, 2, 3], dtype=np.int64))
549+
np.arange(sp_index_start, sp_index_start + 3, dtype=np.int64))
547550
else:
548551
starts = ctx.make_node("Concat", [start_h.output[0], start_w.output[0]], attr={"axis": 0})
549552
ends = ctx.make_node("Concat", [end_h.output[0], end_w.output[0]], attr={"axis": 0})
550553
slice_axes = ctx.make_const(utils.make_name(node.name + "_const_slice_axes"),
551-
np.array([1, 2], dtype=np.int64))
554+
np.arange(sp_index_start, sp_index_start + 2, dtype=np.int64))
552555

553556
slice_node = ctx.make_node("Slice",
554557
[node.output[0], starts.output[0], ends.output[0], slice_axes.output[0]],
@@ -571,10 +574,16 @@ def version_1(cls, ctx, node, **kwargs):
571574
neg_diff_d = ctx.make_node("Neg", [diff_d.output[0]])
572575
shrink_d_by = ctx.make_node("Max", [neg_diff_d.output[0], const_zero.output[0]])
573576
sdb = shrink_d_by.output[0]
574-
pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, cz, shb, swb, sdb, cz], attr={"axis": 0})
577+
if is_channels_last(node):
578+
pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, cz, shb, swb, sdb, cz], attr={"axis": 0})
579+
else:
580+
pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, cz, cz, shb, swb, sdb], attr={"axis": 0})
575581
padded_node = ctx.make_node("Pad", [slice_node.output[0], pads.output[0]])
576582
else:
577-
pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, shb, swb, cz], attr={"axis": 0})
583+
if is_channels_last(node):
584+
pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, shb, swb, cz], attr={"axis": 0})
585+
else:
586+
pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, cz, shb, swb], attr={"axis": 0})
578587
padded_node = ctx.make_node("Pad", [slice_node.output[0], pads.output[0]])
579588

580589
final_node = padded_node

0 commit comments

Comments
 (0)