@@ -378,17 +378,62 @@ def forward(self, x):
378
378
379
379
class Slice (nn .Module ):
380
380
381
- def __init__ (self ):
381
+ def __init__ (self , custom_slice = None ):
382
+ self .custom_slice = custom_slice
382
383
super (Slice , self ).__init__ ()
383
384
384
385
def forward (self , x ):
386
+ if self .custom_slice :
387
+ return x [self .custom_slice ]
388
+
385
389
return x [..., 1 :- 1 , 0 :3 ]
386
390
387
391
input = Variable (torch .randn (1 , 2 , 4 , 4 ))
388
392
model = Slice ()
389
393
save_data_and_model ("slice" , input , model )
390
394
save_data_and_model ("slice_opset_11" , input , model , version = 11 )
391
395
396
+ input_2 = Variable (torch .randn (6 , 6 ))
397
+ custom_slice_list = [
398
+ slice (1 , 3 , 1 ),
399
+ slice (0 , 3 , 2 )
400
+ ]
401
+ model_2 = Slice (custom_slice = custom_slice_list )
402
+ save_data_and_model ("slice_opset_11_steps_2d" , input_2 , model_2 , version = 11 )
403
+ postprocess_model ("models/slice_opset_11_steps_2d.onnx" , [['height' , 'width' ]])
404
+
405
+ input_3 = Variable (torch .randn (3 , 6 , 6 ))
406
+ custom_slice_list_3 = [
407
+ slice (None , None , 2 ),
408
+ slice (None , None , 2 ),
409
+ slice (None , None , 2 )
410
+ ]
411
+ model_3 = Slice (custom_slice = custom_slice_list_3 )
412
+ save_data_and_model ("slice_opset_11_steps_3d" , input_3 , model_3 , version = 11 )
413
+ postprocess_model ("models/slice_opset_11_steps_3d.onnx" , [[3 , 'height' , 'width' ]])
414
+
415
+ input_4 = Variable (torch .randn (1 , 3 , 6 , 6 ))
416
+ custom_slice_list_4 = [
417
+ slice (0 , 5 , None ),
418
+ slice (None , None , None ),
419
+ slice (1 , None , 2 ),
420
+ slice (None , None , None )
421
+ ]
422
+ model_4 = Slice (custom_slice = custom_slice_list_4 )
423
+ save_data_and_model ("slice_opset_11_steps_4d" , input_4 , model_4 , version = 11 )
424
+ postprocess_model ("models/slice_opset_11_steps_4d.onnx" , [["batch_size" , 3 , 'height' , 'width' ]])
425
+
426
+ input_5 = Variable (torch .randn (1 , 2 , 3 , 6 , 6 ))
427
+ custom_slice_list_5 = [
428
+ slice (None , None , None ),
429
+ slice (None , None , None ),
430
+ slice (0 , None , 3 ),
431
+ slice (None , None , None ),
432
+ slice (None , None , 2 )
433
+ ]
434
+ model_5 = Slice (custom_slice = custom_slice_list_5 )
435
+ save_data_and_model ("slice_opset_11_steps_5d" , input_5 , model_5 , version = 11 )
436
+
392
437
class Eltwise (nn .Module ):
393
438
394
439
def __init__ (self ):
0 commit comments