Skip to content

Commit 7474473

Browse files
committed
Add testdata for networks with parametrized input dims
1 parent c3d6342 commit 7474473

4 files changed

+45
-0
lines changed
Binary file not shown.
Binary file not shown.

testdata/dnn/onnx/generate_onnx_models.py

+45
Original file line numberDiff line numberDiff line change
@@ -986,3 +986,48 @@ def forward(self, x):
986986
x = Variable(torch.randn(1, 3, 2, 2))
987987
model = Scale()
988988
save_data_and_model("scale", x, model)
989+
990+
def postprocess_model(model_path):
991+
onnx_model = onnx.load(model_path)
992+
993+
def update_inputs_dims(model, input_dims):
994+
"""
995+
This function updates the sizes of dimensions of the model's inputs to the values
996+
provided in input_dims. if the dim value provided is negative, a unique dim_param
997+
will be set for that dimension.
998+
"""
999+
def update_dim(tensor, dim, i, j, dim_param_prefix):
1000+
dim_proto = tensor.type.tensor_type.shape.dim[j]
1001+
if isinstance(dim, int):
1002+
if dim >= 0:
1003+
dim_proto.dim_value = dim
1004+
else:
1005+
dim_proto.dim_param = dim_param_prefix + str(i) + '_' + str(j)
1006+
elif isinstance(dim, str):
1007+
dim_proto.dim_param = dim
1008+
else:
1009+
raise ValueError('Only int or str is accepted as dimension value, incorrect type: {}'.format(type(dim)))
1010+
1011+
for i, input_dim_arr in enumerate(input_dims):
1012+
for j, dim in enumerate(input_dim_arr):
1013+
update_dim(model.graph.input[i], dim, i, j, 'in_')
1014+
1015+
onnx.checker.check_model(model)
1016+
return model
1017+
1018+
onnx_model = update_inputs_dims(onnx_model, [[3, 'height', 'width']])
1019+
onnx.save(onnx_model, model_path)
1020+
1021+
class ReshapeAndConv(nn.Module):
1022+
def __init__(self):
1023+
super(ReshapeAndConv, self).__init__()
1024+
self.conv = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding=0)
1025+
def forward(self, x):
1026+
x = x.unsqueeze(axis=0)
1027+
out = self.conv(x)
1028+
return out
1029+
1030+
x = Variable(torch.randn(3, 10, 10))
1031+
model = ReshapeAndConv()
1032+
save_data_and_model("reshape_and_conv_parameter_dims", x, model)
1033+
postprocess_model("models/reshape_and_conv_parameter_dims.onnx")
Binary file not shown.

0 commit comments

Comments
 (0)