Skip to content

Commit ae4099b

Browse files
author
Anastasia M
authored
Merge pull request #854 from LupusSanctus:am/slice_steps
* Added Steps support in DNN Slice layer * Added minor code corrections
1 parent 5e02386 commit ae4099b

13 files changed

+46
-1
lines changed
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

testdata/dnn/onnx/generate_onnx_models.py

+46-1
Original file line numberDiff line numberDiff line change
@@ -378,17 +378,62 @@ def forward(self, x):
378378

379379
class Slice(nn.Module):
380380

381-
def __init__(self):
381+
def __init__(self, custom_slice=None):
382+
self.custom_slice=custom_slice
382383
super(Slice, self).__init__()
383384

384385
def forward(self, x):
386+
if self.custom_slice:
387+
return x[self.custom_slice]
388+
385389
return x[..., 1:-1, 0:3]
386390

387391
input = Variable(torch.randn(1, 2, 4, 4))
388392
model = Slice()
389393
save_data_and_model("slice", input, model)
390394
save_data_and_model("slice_opset_11", input, model, version=11)
391395

396+
input_2 = Variable(torch.randn(6, 6))
397+
custom_slice_list = [
398+
slice(1, 3, 1),
399+
slice(0, 3, 2)
400+
]
401+
model_2 = Slice(custom_slice=custom_slice_list)
402+
save_data_and_model("slice_opset_11_steps_2d", input_2, model_2, version=11)
403+
postprocess_model("models/slice_opset_11_steps_2d.onnx", [['height', 'width']])
404+
405+
input_3 = Variable(torch.randn(3, 6, 6))
406+
custom_slice_list_3 = [
407+
slice(None, None, 2),
408+
slice(None, None, 2),
409+
slice(None, None, 2)
410+
]
411+
model_3 = Slice(custom_slice=custom_slice_list_3)
412+
save_data_and_model("slice_opset_11_steps_3d", input_3, model_3, version=11)
413+
postprocess_model("models/slice_opset_11_steps_3d.onnx", [[3, 'height', 'width']])
414+
415+
input_4 = Variable(torch.randn(1, 3, 6, 6))
416+
custom_slice_list_4 = [
417+
slice(0, 5, None),
418+
slice(None, None, None),
419+
slice(1, None, 2),
420+
slice(None, None, None)
421+
]
422+
model_4 = Slice(custom_slice=custom_slice_list_4)
423+
save_data_and_model("slice_opset_11_steps_4d", input_4, model_4, version=11)
424+
postprocess_model("models/slice_opset_11_steps_4d.onnx", [["batch_size", 3, 'height', 'width']])
425+
426+
input_5 = Variable(torch.randn(1, 2, 3, 6, 6))
427+
custom_slice_list_5 = [
428+
slice(None, None, None),
429+
slice(None, None, None),
430+
slice(0, None, 3),
431+
slice(None, None, None),
432+
slice(None, None, 2)
433+
]
434+
model_5 = Slice(custom_slice=custom_slice_list_5)
435+
save_data_and_model("slice_opset_11_steps_5d", input_5, model_5, version=11)
436+
392437
class Eltwise(nn.Module):
393438

394439
def __init__(self):
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)