@@ -504,14 +504,15 @@ def version_1(cls, ctx, node, **kwargs):
504
504
use_strides_workaround = False
505
505
input_shape = ctx .make_node ("Cast" , [node .input [0 ]], attr = {'to' : TensorProto .INT64 })
506
506
output_shape = ctx .make_node ("Shape" , [node .output [0 ]])
507
+ sp_index_start = 1 if is_channels_last (node ) else 2
507
508
output_h = GraphBuilder (ctx ).make_slice (
508
- {"data" : output_shape .output [0 ], "ends" : [2 ], "starts" : [1 ], "axes" : [0 ]})
509
+ {"data" : output_shape .output [0 ], "ends" : [sp_index_start + 1 ], "starts" : [sp_index_start ], "axes" : [0 ]})
509
510
output_w = GraphBuilder (ctx ).make_slice (
510
- {"data" : output_shape .output [0 ], "ends" : [3 ], "starts" : [2 ], "axes" : [0 ]})
511
+ {"data" : output_shape .output [0 ], "ends" : [sp_index_start + 2 ], "starts" : [sp_index_start + 1 ], "axes" : [0 ]})
511
512
expect_h = GraphBuilder (ctx ).make_slice (
512
- {"data" : input_shape .output [0 ], "ends" : [2 ], "starts" : [1 ], "axes" : [0 ]})
513
+ {"data" : input_shape .output [0 ], "ends" : [sp_index_start + 1 ], "starts" : [sp_index_start ], "axes" : [0 ]})
513
514
expect_w = GraphBuilder (ctx ).make_slice (
514
- {"data" : input_shape .output [0 ], "ends" : [3 ], "starts" : [2 ], "axes" : [0 ]})
515
+ {"data" : input_shape .output [0 ], "ends" : [sp_index_start + 2 ], "starts" : [sp_index_start + 1 ], "axes" : [0 ]})
515
516
diff_h = ctx .make_node ("Sub" , [output_h , expect_h ])
516
517
diff_w = ctx .make_node ("Sub" , [output_w , expect_w ])
517
518
nonneg_diff_h = diff_h
@@ -529,9 +530,9 @@ def version_1(cls, ctx, node, **kwargs):
529
530
end_w = ctx .make_node ("Add" , [start_w .output [0 ], expect_w ])
530
531
if spatial == 3 :
531
532
output_d = GraphBuilder (ctx ).make_slice (
532
- {"data" : output_shape .output [0 ], "ends" : [4 ], "starts" : [3 ], "axes" : [0 ]})
533
+ {"data" : output_shape .output [0 ], "ends" : [sp_index_start + 3 ], "starts" : [sp_index_start + 2 ], "axes" : [0 ]})
533
534
expect_d = GraphBuilder (ctx ).make_slice (
534
- {"data" : input_shape .output [0 ], "ends" : [4 ], "starts" : [3 ], "axes" : [0 ]})
535
+ {"data" : input_shape .output [0 ], "ends" : [sp_index_start + 3 ], "starts" : [sp_index_start + 2 ], "axes" : [0 ]})
535
536
diff_d = ctx .make_node ("Sub" , [output_d , expect_d ])
536
537
nonneg_diff_d = diff_d
537
538
if use_strides_workaround :
@@ -543,12 +544,12 @@ def version_1(cls, ctx, node, **kwargs):
543
544
attr = {"axis" : 0 })
544
545
ends = ctx .make_node ("Concat" , [end_h .output [0 ], end_w .output [0 ], end_d .output [0 ]], attr = {"axis" : 0 })
545
546
slice_axes = ctx .make_const (utils .make_name (node .name + "_const_slice_axes" ),
546
- np .array ([ 1 , 2 , 3 ] , dtype = np .int64 ))
547
+ np .arange ( sp_index_start , sp_index_start + 3 , dtype = np .int64 ))
547
548
else :
548
549
starts = ctx .make_node ("Concat" , [start_h .output [0 ], start_w .output [0 ]], attr = {"axis" : 0 })
549
550
ends = ctx .make_node ("Concat" , [end_h .output [0 ], end_w .output [0 ]], attr = {"axis" : 0 })
550
551
slice_axes = ctx .make_const (utils .make_name (node .name + "_const_slice_axes" ),
551
- np .array ([ 1 , 2 ] , dtype = np .int64 ))
552
+ np .arange ( sp_index_start , sp_index_start + 2 , dtype = np .int64 ))
552
553
553
554
slice_node = ctx .make_node ("Slice" ,
554
555
[node .output [0 ], starts .output [0 ], ends .output [0 ], slice_axes .output [0 ]],
@@ -571,10 +572,16 @@ def version_1(cls, ctx, node, **kwargs):
571
572
neg_diff_d = ctx .make_node ("Neg" , [diff_d .output [0 ]])
572
573
shrink_d_by = ctx .make_node ("Max" , [neg_diff_d .output [0 ], const_zero .output [0 ]])
573
574
sdb = shrink_d_by .output [0 ]
574
- pads = ctx .make_node ("Concat" , [cz , cz , cz , cz , cz , cz , shb , swb , sdb , cz ], attr = {"axis" : 0 })
575
+ if is_channels_last (node ):
576
+ pads = ctx .make_node ("Concat" , [cz , cz , cz , cz , cz , cz , shb , swb , sdb , cz ], attr = {"axis" : 0 })
577
+ else :
578
+ pads = ctx .make_node ("Concat" , [cz , cz , cz , cz , cz , cz , cz , shb , swb , sdb ], attr = {"axis" : 0 })
575
579
padded_node = ctx .make_node ("Pad" , [slice_node .output [0 ], pads .output [0 ]])
576
580
else :
577
- pads = ctx .make_node ("Concat" , [cz , cz , cz , cz , cz , shb , swb , cz ], attr = {"axis" : 0 })
581
+ if is_channels_last (node ):
582
+ pads = ctx .make_node ("Concat" , [cz , cz , cz , cz , cz , shb , swb , cz ], attr = {"axis" : 0 })
583
+ else :
584
+ pads = ctx .make_node ("Concat" , [cz , cz , cz , cz , cz , cz , shb , swb ], attr = {"axis" : 0 })
578
585
padded_node = ctx .make_node ("Pad" , [slice_node .output [0 ], pads .output [0 ]])
579
586
580
587
final_node = padded_node
0 commit comments