Skip to content

Commit c79f97b

Browse files
committed
Fixed Conv3DTranspose with strides for data format channels_first (fixes onnx#1714)
While shape calculations for the input correctly distinguished between channels_first and channels_last, shape calculations for the inputs of the final Slice and Pad nodes always assumed channels_last format. Signed-off-by: fthielke <[email protected]>
1 parent 4245d8d commit c79f97b

File tree

1 file changed

+17
-10
lines changed

1 file changed

+17
-10
lines changed

Diff for: tf2onnx/onnx_opset/nn.py

+17-10
Original file line numberDiff line numberDiff line change
@@ -504,14 +504,15 @@ def version_1(cls, ctx, node, **kwargs):
504504
use_strides_workaround = False
505505
input_shape = ctx.make_node("Cast", [node.input[0]], attr={'to': TensorProto.INT64})
506506
output_shape = ctx.make_node("Shape", [node.output[0]])
507+
sp_index_start = 1 if is_channels_last(node) else 2
507508
output_h = GraphBuilder(ctx).make_slice(
508-
{"data": output_shape.output[0], "ends": [2], "starts": [1], "axes": [0]})
509+
{"data": output_shape.output[0], "ends": [sp_index_start+1], "starts": [sp_index_start], "axes": [0]})
509510
output_w = GraphBuilder(ctx).make_slice(
510-
{"data": output_shape.output[0], "ends": [3], "starts": [2], "axes": [0]})
511+
{"data": output_shape.output[0], "ends": [sp_index_start+2], "starts": [sp_index_start+1], "axes": [0]})
511512
expect_h = GraphBuilder(ctx).make_slice(
512-
{"data": input_shape.output[0], "ends": [2], "starts": [1], "axes": [0]})
513+
{"data": input_shape.output[0], "ends": [sp_index_start+1], "starts": [sp_index_start], "axes": [0]})
513514
expect_w = GraphBuilder(ctx).make_slice(
514-
{"data": input_shape.output[0], "ends": [3], "starts": [2], "axes": [0]})
515+
{"data": input_shape.output[0], "ends": [sp_index_start+2], "starts": [sp_index_start+1], "axes": [0]})
515516
diff_h = ctx.make_node("Sub", [output_h, expect_h])
516517
diff_w = ctx.make_node("Sub", [output_w, expect_w])
517518
nonneg_diff_h = diff_h
@@ -529,9 +530,9 @@ def version_1(cls, ctx, node, **kwargs):
529530
end_w = ctx.make_node("Add", [start_w.output[0], expect_w])
530531
if spatial == 3:
531532
output_d = GraphBuilder(ctx).make_slice(
532-
{"data": output_shape.output[0], "ends": [4], "starts": [3], "axes": [0]})
533+
{"data": output_shape.output[0], "ends": [sp_index_start+3], "starts": [sp_index_start+2], "axes": [0]})
533534
expect_d = GraphBuilder(ctx).make_slice(
534-
{"data": input_shape.output[0], "ends": [4], "starts": [3], "axes": [0]})
535+
{"data": input_shape.output[0], "ends": [sp_index_start+3], "starts": [sp_index_start+2], "axes": [0]})
535536
diff_d = ctx.make_node("Sub", [output_d, expect_d])
536537
nonneg_diff_d = diff_d
537538
if use_strides_workaround:
@@ -543,12 +544,12 @@ def version_1(cls, ctx, node, **kwargs):
543544
attr={"axis": 0})
544545
ends = ctx.make_node("Concat", [end_h.output[0], end_w.output[0], end_d.output[0]], attr={"axis": 0})
545546
slice_axes = ctx.make_const(utils.make_name(node.name + "_const_slice_axes"),
546-
np.array([1, 2, 3], dtype=np.int64))
547+
np.arange(sp_index_start, sp_index_start + 3, dtype=np.int64))
547548
else:
548549
starts = ctx.make_node("Concat", [start_h.output[0], start_w.output[0]], attr={"axis": 0})
549550
ends = ctx.make_node("Concat", [end_h.output[0], end_w.output[0]], attr={"axis": 0})
550551
slice_axes = ctx.make_const(utils.make_name(node.name + "_const_slice_axes"),
551-
np.array([1, 2], dtype=np.int64))
552+
np.arange(sp_index_start, sp_index_start + 2, dtype=np.int64))
552553

553554
slice_node = ctx.make_node("Slice",
554555
[node.output[0], starts.output[0], ends.output[0], slice_axes.output[0]],
@@ -571,10 +572,16 @@ def version_1(cls, ctx, node, **kwargs):
571572
neg_diff_d = ctx.make_node("Neg", [diff_d.output[0]])
572573
shrink_d_by = ctx.make_node("Max", [neg_diff_d.output[0], const_zero.output[0]])
573574
sdb = shrink_d_by.output[0]
574-
pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, cz, shb, swb, sdb, cz], attr={"axis": 0})
575+
if is_channels_last(node):
576+
pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, cz, shb, swb, sdb, cz], attr={"axis": 0})
577+
else:
578+
pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, cz, cz, shb, swb, sdb], attr={"axis": 0})
575579
padded_node = ctx.make_node("Pad", [slice_node.output[0], pads.output[0]])
576580
else:
577-
pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, shb, swb, cz], attr={"axis": 0})
581+
if is_channels_last(node):
582+
pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, shb, swb, cz], attr={"axis": 0})
583+
else:
584+
pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, cz, shb, swb], attr={"axis": 0})
578585
padded_node = ctx.make_node("Pad", [slice_node.output[0], pads.output[0]])
579586

580587
final_node = padded_node

0 commit comments

Comments
 (0)