@@ -504,14 +504,15 @@ def version_1(cls, ctx, node, **kwargs):
504
504
use_strides_workaround = False
505
505
input_shape = ctx .make_node ("Cast" , [node .input [0 ]], attr = {'to' : TensorProto .INT64 })
506
506
output_shape = ctx .make_node ("Shape" , [node .output [0 ]])
507
+ sp_index_start = 1 if is_channels_last (node ) else 2
507
508
output_h = GraphBuilder (ctx ).make_slice (
508
- {"data" : output_shape .output [0 ], "ends" : [2 ], "starts" : [1 ], "axes" : [0 ]})
509
+ {"data" : output_shape .output [0 ], "ends" : [sp_index_start + 1 ], "starts" : [sp_index_start ], "axes" : [0 ]})
509
510
output_w = GraphBuilder (ctx ).make_slice (
510
- {"data" : output_shape .output [0 ], "ends" : [3 ], "starts" : [2 ], "axes" : [0 ]})
511
+ {"data" : output_shape .output [0 ], "ends" : [sp_index_start + 2 ], "starts" : [sp_index_start + 1 ], "axes" : [0 ]})
511
512
expect_h = GraphBuilder (ctx ).make_slice (
512
- {"data" : input_shape .output [0 ], "ends" : [2 ], "starts" : [1 ], "axes" : [0 ]})
513
+ {"data" : input_shape .output [0 ], "ends" : [sp_index_start + 1 ], "starts" : [sp_index_start ], "axes" : [0 ]})
513
514
expect_w = GraphBuilder (ctx ).make_slice (
514
- {"data" : input_shape .output [0 ], "ends" : [3 ], "starts" : [2 ], "axes" : [0 ]})
515
+ {"data" : input_shape .output [0 ], "ends" : [sp_index_start + 2 ], "starts" : [sp_index_start + 1 ], "axes" : [0 ]})
515
516
diff_h = ctx .make_node ("Sub" , [output_h , expect_h ])
516
517
diff_w = ctx .make_node ("Sub" , [output_w , expect_w ])
517
518
nonneg_diff_h = diff_h
@@ -528,10 +529,12 @@ def version_1(cls, ctx, node, **kwargs):
528
529
end_h = ctx .make_node ("Add" , [start_h .output [0 ], expect_h ])
529
530
end_w = ctx .make_node ("Add" , [start_w .output [0 ], expect_w ])
530
531
if spatial == 3 :
531
- output_d = GraphBuilder (ctx ).make_slice (
532
- {"data" : output_shape .output [0 ], "ends" : [4 ], "starts" : [3 ], "axes" : [0 ]})
533
- expect_d = GraphBuilder (ctx ).make_slice (
534
- {"data" : input_shape .output [0 ], "ends" : [4 ], "starts" : [3 ], "axes" : [0 ]})
532
+ output_d = GraphBuilder (ctx ).make_slice ({
533
+ "data" : output_shape .output [0 ], "ends" : [sp_index_start + 3 ], "starts" : [sp_index_start + 2 ], "axes" : [0 ]
534
+ })
535
+ expect_d = GraphBuilder (ctx ).make_slice ({
536
+ "data" : input_shape .output [0 ], "ends" : [sp_index_start + 3 ], "starts" : [sp_index_start + 2 ], "axes" : [0 ]
537
+ })
535
538
diff_d = ctx .make_node ("Sub" , [output_d , expect_d ])
536
539
nonneg_diff_d = diff_d
537
540
if use_strides_workaround :
@@ -543,12 +546,12 @@ def version_1(cls, ctx, node, **kwargs):
543
546
attr = {"axis" : 0 })
544
547
ends = ctx .make_node ("Concat" , [end_h .output [0 ], end_w .output [0 ], end_d .output [0 ]], attr = {"axis" : 0 })
545
548
slice_axes = ctx .make_const (utils .make_name (node .name + "_const_slice_axes" ),
546
- np .array ([ 1 , 2 , 3 ] , dtype = np .int64 ))
549
+ np .arange ( sp_index_start , sp_index_start + 3 , dtype = np .int64 ))
547
550
else :
548
551
starts = ctx .make_node ("Concat" , [start_h .output [0 ], start_w .output [0 ]], attr = {"axis" : 0 })
549
552
ends = ctx .make_node ("Concat" , [end_h .output [0 ], end_w .output [0 ]], attr = {"axis" : 0 })
550
553
slice_axes = ctx .make_const (utils .make_name (node .name + "_const_slice_axes" ),
551
- np .array ([ 1 , 2 ] , dtype = np .int64 ))
554
+ np .arange ( sp_index_start , sp_index_start + 2 , dtype = np .int64 ))
552
555
553
556
slice_node = ctx .make_node ("Slice" ,
554
557
[node .output [0 ], starts .output [0 ], ends .output [0 ], slice_axes .output [0 ]],
@@ -571,10 +574,16 @@ def version_1(cls, ctx, node, **kwargs):
571
574
neg_diff_d = ctx .make_node ("Neg" , [diff_d .output [0 ]])
572
575
shrink_d_by = ctx .make_node ("Max" , [neg_diff_d .output [0 ], const_zero .output [0 ]])
573
576
sdb = shrink_d_by .output [0 ]
574
- pads = ctx .make_node ("Concat" , [cz , cz , cz , cz , cz , cz , shb , swb , sdb , cz ], attr = {"axis" : 0 })
577
+ if is_channels_last (node ):
578
+ pads = ctx .make_node ("Concat" , [cz , cz , cz , cz , cz , cz , shb , swb , sdb , cz ], attr = {"axis" : 0 })
579
+ else :
580
+ pads = ctx .make_node ("Concat" , [cz , cz , cz , cz , cz , cz , cz , shb , swb , sdb ], attr = {"axis" : 0 })
575
581
padded_node = ctx .make_node ("Pad" , [slice_node .output [0 ], pads .output [0 ]])
576
582
else :
577
- pads = ctx .make_node ("Concat" , [cz , cz , cz , cz , cz , shb , swb , cz ], attr = {"axis" : 0 })
583
+ if is_channels_last (node ):
584
+ pads = ctx .make_node ("Concat" , [cz , cz , cz , cz , cz , shb , swb , cz ], attr = {"axis" : 0 })
585
+ else :
586
+ pads = ctx .make_node ("Concat" , [cz , cz , cz , cz , cz , cz , shb , swb ], attr = {"axis" : 0 })
578
587
padded_node = ctx .make_node ("Pad" , [slice_node .output [0 ], pads .output [0 ]])
579
588
580
589
final_node = padded_node
0 commit comments