Skip to content

Commit f64b67c

Browse files
committed
Merge pull request #798 from sl-sergei:named_params_onnx
2 parents d1242e6 + b4d7f80 commit f64b67c

28 files changed

+118
-0
lines changed
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
256 Bytes
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

testdata/dnn/onnx/generate_onnx_models.py

+97
Original file line numberDiff line numberDiff line change
@@ -1040,3 +1040,100 @@ def forward(self, x):
10401040
x = Variable(torch.zeros([1, 2, 2]))
10411041
model = GatherMultiOutput()
10421042
save_data_and_model("gather_multi_output", x, model)
1043+
1044+
def postprocess_model(model_path, inputs_shapes):
    """Rewrite the input dimensions of a saved ONNX model in place.

    Loads the model at `model_path`, replaces the dims of each graph input
    with the values from `inputs_shapes`, validates the result with the ONNX
    checker, and saves it back to the same path.

    :param model_path: path to an existing .onnx file (overwritten on success)
    :param inputs_shapes: list with one entry per graph input; each entry is a
        list of per-dimension values (int or str), e.g. [[3, 'height', 'width']]
    """
    onnx_model = onnx.load(model_path)

    def update_inputs_dims(model, input_dims):
        """
        This function updates the sizes of dimensions of the model's inputs to the values
        provided in input_dims. if the dim value provided is negative, a unique dim_param
        will be set for that dimension.
        """
        def update_dim(tensor, dim, i, j, dim_param_prefix):
            # j-th dimension of the i-th graph input.
            dim_proto = tensor.type.tensor_type.shape.dim[j]
            if isinstance(dim, int):
                if dim >= 0:
                    # Fixed size: store a concrete dim_value.
                    dim_proto.dim_value = dim
                else:
                    # Negative int: synthesize a unique symbolic name, e.g. 'in_0_2'.
                    dim_proto.dim_param = dim_param_prefix + str(i) + '_' + str(j)
            elif isinstance(dim, str):
                # String: use it directly as the symbolic (dynamic) dim name.
                dim_proto.dim_param = dim
            else:
                raise ValueError('Only int or str is accepted as dimension value, incorrect type: {}'.format(type(dim)))

        for i, input_dim_arr in enumerate(input_dims):
            for j, dim in enumerate(input_dim_arr):
                update_dim(model.graph.input[i], dim, i, j, 'in_')

        # Raises if the edited model is no longer a valid ONNX graph.
        onnx.checker.check_model(model)
        return model

    onnx_model = update_inputs_dims(onnx_model, inputs_shapes)
    onnx.save(onnx_model, model_path)
1075+
class UnsqueezeAndConv(nn.Module):
    """Unsqueeze a 3-D (C, H, W) input into a batch of one, then apply a 1x1 conv.

    Used to generate an ONNX test model exercising Unsqueeze followed by Conv.
    """
    def __init__(self):
        super(UnsqueezeAndConv, self).__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # torch.Tensor.unsqueeze's parameter is named `dim`; `axis` is only a
        # NumPy-compat alias not available in all torch versions, so pass the
        # dimension positionally.
        x = x.unsqueeze(0)
        out = self.conv(x)
        return out
1083+
1084+
# Export UnsqueezeAndConv; channel dim is fixed at 3, spatial dims are dynamic.
x = Variable(torch.randn(3, 10, 10))
model = UnsqueezeAndConv()
save_data_and_model("unsqueeze_and_conv_dynamic_axes", x, model)
postprocess_model("models/unsqueeze_and_conv_dynamic_axes.onnx", [[3, 'height', 'width']])
1088+
1089+
class SqueezeAndConv(nn.Module):
    """Drop every singleton dimension from the input, then run a 1x1 Conv2d."""

    def __init__(self):
        super(SqueezeAndConv, self).__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # NOTE(review): squeeze() with no argument removes ALL size-1 dims,
        # so this presumably relies on the batch dimension being > 1 — confirm
        # against the generator call sites.
        squeezed = x.squeeze()
        return self.conv(squeezed)
1097+
1098+
# Export SqueezeAndConv; the squeezed dim stays fixed at 1, the rest are dynamic.
x = Variable(torch.randn(2, 1, 3, 3, 3))
model = SqueezeAndConv()
save_data_and_model("squeeze_and_conv_dynamic_axes", x, model)
postprocess_model("models/squeeze_and_conv_dynamic_axes.onnx", [["batch_size", 1, "channels", 'height', 'width']])
1102+
1103+
# Export GatherScalar on a 1-D input whose single dimension is dynamic.
x = Variable(torch.randn(2))
model = GatherScalar()
save_data_and_model("gather_scalar_dynamic_axes", x, model)
postprocess_model("models/gather_scalar_dynamic_axes.onnx", [['shape']])
1107+
1108+
# Export Gather with dynamic batch/height/width dims.
# Removed leftover debug prints of the input and of two redundant forward
# passes (model(x) was evaluated twice just to print it) — save_data_and_model
# already runs the model to record reference outputs.
x = Variable(torch.randn(2, 2, 2, 2))
model = Gather()
save_data_and_model("gather_dynamic_axes", x, model)
postprocess_model("models/gather_dynamic_axes.onnx", [["batch_size", 2, 'height', 'width']])
1115+
1116+
# Export Slice at the default opset and at opset 11, then mark the
# batch/height/width input dims as dynamic in both saved models.
input = Variable(torch.randn(1, 2, 4, 4))  # NOTE(review): shadows the `input` builtin
model = Slice()
save_data_and_model("slice_dynamic_axes", input, model)
save_data_and_model("slice_opset_11_dynamic_axes", input, model, version=11)
postprocess_model("models/slice_dynamic_axes.onnx", [["batch_size", 2, 'height', 'width']])
postprocess_model("models/slice_opset_11_dynamic_axes.onnx", [["batch_size", 2, 'height', 'width']])
1122+
1123+
# Export ResizeConv at opset 11 (torch 1.6 export semantics) with dynamic
# batch/height/width dims.
x = Variable(torch.rand(1, 2, 2, 2))
model = ResizeConv(2, 0, 2)
save_data_and_model("resize_opset11_torch1.6_dynamic_axes", x, model, 11)
postprocess_model("models/resize_opset11_torch1.6_dynamic_axes.onnx", [["batch_size", 2, 'height', 'width']])
1127+
1128+
# MaxPool followed by Sigmoid; batch/channels stay fixed (2, 3) while the
# spatial dims are marked dynamic.
maxpooling_sigmoid = nn.Sequential(
          nn.MaxPool2d(kernel_size=4, stride=2, padding=(1, 2), dilation=1),
          nn.Sigmoid()
)
input = Variable(torch.randn(2, 3, 12, 18))
save_data_and_model("maxpooling_sigmoid_dynamic_axes", input, maxpooling_sigmoid)
postprocess_model("models/maxpooling_sigmoid_dynamic_axes.onnx", [[2, 3, 'height', 'width']])
1135+
1136+
# Average pooling export; only the spatial dims are marked dynamic.
ave_pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
input = Variable(torch.randn(1, 3, 7, 5))
save_data_and_model("average_pooling_dynamic_axes", input, ave_pool)
postprocess_model("models/average_pooling_dynamic_axes.onnx", [[1, 3, 'height', 'width']])
Binary file not shown.
211 Bytes
Binary file not shown.
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
pytorch1.6:�
2+
V
3+
01 MaxPool_0"MaxPool*
4+
kernel_shape@@�*
5+
pads@@@@�*
6+
strides@@�
7+

8+
12 Sigmoid_1"Sigmoidtorch-jit-exportZ&
9+
0!
10+

11+

12+

13+
height
14+
widthb
15+
2
16+

17+

18+

19+

20+

21+
B
Binary file not shown.
259 Bytes
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)