Skip to content

Commit 7dee77f

Browse files
committed
Add testdata for networks with parametrized input dims
1 parent 9cbde3f commit 7dee77f

4 files changed

+45
-0
lines changed
Binary file not shown.
Binary file not shown.

testdata/dnn/onnx/generate_onnx_models.py

+45
Original file line numberDiff line numberDiff line change
@@ -887,3 +887,48 @@ def forward(self, x, kernel, bias):
887887
x = Variable(torch.randn(1, 2, 2))
888888
model = Expand(shape=[2, -1, -1, -1])
889889
save_data_and_model("expand_neg_batch", x, model)
890+
891+
def postprocess_model(model_path):
892+
onnx_model = onnx.load(model_path)
893+
894+
def update_inputs_dims(model, input_dims):
895+
"""
896+
This function updates the sizes of dimensions of the model's inputs to the values
897+
provided in input_dims. if the dim value provided is negative, a unique dim_param
898+
will be set for that dimension.
899+
"""
900+
def update_dim(tensor, dim, i, j, dim_param_prefix):
901+
dim_proto = tensor.type.tensor_type.shape.dim[j]
902+
if isinstance(dim, int):
903+
if dim >= 0:
904+
dim_proto.dim_value = dim
905+
else:
906+
dim_proto.dim_param = dim_param_prefix + str(i) + '_' + str(j)
907+
elif isinstance(dim, str):
908+
dim_proto.dim_param = dim
909+
else:
910+
raise ValueError('Only int or str is accepted as dimension value, incorrect type: {}'.format(type(dim)))
911+
912+
for i, input_dim_arr in enumerate(input_dims):
913+
for j, dim in enumerate(input_dim_arr):
914+
update_dim(model.graph.input[i], dim, i, j, 'in_')
915+
916+
onnx.checker.check_model(model)
917+
return model
918+
919+
onnx_model = update_inputs_dims(onnx_model, [[3, 'height', 'width']])
920+
onnx.save(onnx_model, model_path)
921+
922+
class ReshapeAndConv(nn.Module):
923+
def __init__(self):
924+
super(ReshapeAndConv, self).__init__()
925+
self.conv = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding=0)
926+
def forward(self, x):
927+
x = x.unsqueeze(axis=0)
928+
out = self.conv(x)
929+
return out
930+
931+
x = Variable(torch.randn(3, 10, 10))
932+
model = ReshapeAndConv()
933+
save_data_and_model("reshape_and_conv_parameter_dims", x, model)
934+
postprocess_model("models/reshape_and_conv_parameter_dims.onnx")
Binary file not shown.

0 commit comments

Comments
 (0)