Skip to content

Commit 693cb9e

Browse files
author
Anastasia Murzova
committed
Added Steps support in DNN Slice layer
1 parent 5e02386 commit 693cb9e

13 files changed

+46
-1
lines changed
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

testdata/dnn/onnx/generate_onnx_models.py

+46-1
Original file line numberDiff line numberDiff line change
@@ -378,17 +378,62 @@ def forward(self, x):
378378

379379
class Slice(nn.Module):
380380

381-
def __init__(self):
381+
def __init__(self, custom_sliice=None):
382+
self.custom_sliice=custom_sliice
382383
super(Slice, self).__init__()
383384

384385
def forward(self, x):
386+
if self.custom_sliice:
387+
return x[self.custom_sliice]
388+
385389
return x[..., 1:-1, 0:3]
386390

387391
input = Variable(torch.randn(1, 2, 4, 4))
388392
model = Slice()
389393
save_data_and_model("slice", input, model)
390394
save_data_and_model("slice_opset_11", input, model, version=11)
391395

396+
input_2 = Variable(torch.randn(6, 6))
397+
custom_slice_list = [
398+
slice(1, 3, 1),
399+
slice(0, 3, 2)
400+
]
401+
model_2 = Slice(custom_sliice=custom_slice_list)
402+
save_data_and_model("slice_opset_11_steps_2d", input_2, model_2, version=11)
403+
postprocess_model("models/slice_opset_11_steps_2d.onnx", [['height', 'width']])
404+
405+
input_3 = Variable(torch.randn(3, 6, 6))
406+
custom_slice_list_3 = [
407+
slice(None, None, 2),
408+
slice(None, None, 2),
409+
slice(None, None, 2)
410+
]
411+
model_3 = Slice(custom_sliice=custom_slice_list_3)
412+
save_data_and_model("slice_opset_11_steps_3d", input_3, model_3, version=11)
413+
postprocess_model("models/slice_opset_11_steps_3d.onnx", [[3, 'height', 'width']])
414+
415+
input_4 = Variable(torch.randn(1, 3, 6, 6))
416+
custom_slice_list_4 = [
417+
slice(0, 5, None),
418+
slice(None, None, None),
419+
slice(1, None, 2),
420+
slice(None, None, None)
421+
]
422+
model_4 = Slice(custom_sliice=custom_slice_list_4)
423+
save_data_and_model("slice_opset_11_steps_4d", input_4, model_4, version=11)
424+
postprocess_model("models/slice_opset_11_steps_4d.onnx", [["batch_size", 3, 'height', 'width']])
425+
426+
input_5 = Variable(torch.randn(1, 2, 3, 6, 6))
427+
custom_slice_list_5 = [
428+
slice(None, None, None),
429+
slice(None, None, None),
430+
slice(0, None, 3),
431+
slice(None, None, None),
432+
slice(None, None, 2)
433+
]
434+
model_5 = Slice(custom_sliice=custom_slice_list_5)
435+
save_data_and_model("slice_opset_11_steps_5d", input_5, model_5, version=11)
436+
392437
class Eltwise(nn.Module):
393438

394439
def __init__(self):
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)