Skip to content

Commit 9c38839

Browse files
author
Anastasia Murzova
committed
Added Steps support in DNN Slice layer
1 parent 71f2370 commit 9c38839

13 files changed

+46
-1
lines changed
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

testdata/dnn/onnx/generate_onnx_models.py

+46-1
Original file line numberDiff line numberDiff line change
@@ -378,17 +378,62 @@ def forward(self, x):
378378

379379
class Slice(nn.Module):
380380

381-
def __init__(self):
381+
def __init__(self, custom_sliice=None):
382+
self.custom_sliice=custom_sliice
382383
super(Slice, self).__init__()
383384

384385
def forward(self, x):
386+
if self.custom_sliice:
387+
return x[self.custom_sliice]
388+
385389
return x[..., 1:-1, 0:3]
386390

387391
input = Variable(torch.randn(1, 2, 4, 4))
388392
model = Slice()
389393
save_data_and_model("slice", input, model)
390394
save_data_and_model("slice_opset_11", input, model, version=11)
391395

396+
input_2 = Variable(torch.randn(6, 6))
397+
custom_slice_list = [
398+
slice(1, 3, 1),
399+
slice(0, 3, 2)
400+
]
401+
model_2 = Slice(custom_sliice=custom_slice_list)
402+
save_data_and_model("slice_opset_11_steps_2d", input_2, model_2, version=11)
403+
postprocess_model("models/slice_opset_11_steps_2d.onnx", [['height', 'width']])
404+
405+
input_3 = Variable(torch.randn(3, 6, 6))
406+
custom_slice_list_3 = [
407+
slice(None, None, 2),
408+
slice(None, None, 2),
409+
slice(None, None, 2)
410+
]
411+
model_3 = Slice(custom_sliice=custom_slice_list_3)
412+
save_data_and_model("slice_opset_11_steps_3d", input_3, model_3, version=11)
413+
postprocess_model("models/slice_opset_11_steps_3d.onnx", [[3, 'height', 'width']])
414+
415+
input_4 = Variable(torch.randn(1, 3, 6, 6))
416+
custom_slice_list_4 = [
417+
slice(0, 5, None),
418+
slice(None, None, None),
419+
slice(1, None, 2),
420+
slice(None, None, None)
421+
]
422+
model_4 = Slice(custom_sliice=custom_slice_list_4)
423+
save_data_and_model("slice_opset_11_steps_4d", input_4, model_4, version=11)
424+
postprocess_model("models/slice_opset_11_steps_4d.onnx", [["batch_size", 3, 'height', 'width']])
425+
426+
input_5 = Variable(torch.randn(1, 2, 3, 6, 6))
427+
custom_slice_list_5 = [
428+
slice(None, None, None),
429+
slice(None, None, None),
430+
slice(0, None, 3),
431+
slice(None, None, None),
432+
slice(None, None, 2)
433+
]
434+
model_5 = Slice(custom_sliice=custom_slice_list_5)
435+
save_data_and_model("slice_opset_11_steps_5d", input_5, model_5, version=11)
436+
392437
class Eltwise(nn.Module):
393438

394439
def __init__(self):
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)