@@ -1040,3 +1040,100 @@ def forward(self, x):
1040
1040
x = Variable (torch .zeros ([1 , 2 , 2 ]))
1041
1041
model = GatherMultiOutput ()
1042
1042
save_data_and_model ("gather_multi_output" , x , model )
1043
+
1044
+ def postprocess_model (model_path , inputs_shapes ):
1045
+ onnx_model = onnx .load (model_path )
1046
+
1047
+ def update_inputs_dims (model , input_dims ):
1048
+ """
1049
+ This function updates the sizes of dimensions of the model's inputs to the values
1050
+ provided in input_dims. if the dim value provided is negative, a unique dim_param
1051
+ will be set for that dimension.
1052
+ """
1053
+ def update_dim (tensor , dim , i , j , dim_param_prefix ):
1054
+ dim_proto = tensor .type .tensor_type .shape .dim [j ]
1055
+ if isinstance (dim , int ):
1056
+ if dim >= 0 :
1057
+ dim_proto .dim_value = dim
1058
+ else :
1059
+ dim_proto .dim_param = dim_param_prefix + str (i ) + '_' + str (j )
1060
+ elif isinstance (dim , str ):
1061
+ dim_proto .dim_param = dim
1062
+ else :
1063
+ raise ValueError ('Only int or str is accepted as dimension value, incorrect type: {}' .format (type (dim )))
1064
+
1065
+ for i , input_dim_arr in enumerate (input_dims ):
1066
+ for j , dim in enumerate (input_dim_arr ):
1067
+ update_dim (model .graph .input [i ], dim , i , j , 'in_' )
1068
+
1069
+ onnx .checker .check_model (model )
1070
+ return model
1071
+
1072
+ onnx_model = update_inputs_dims (onnx_model , inputs_shapes )
1073
+ onnx .save (onnx_model , model_path )
1074
+
1075
+ class UnsqueezeAndConv (nn .Module ):
1076
+ def __init__ (self ):
1077
+ super (UnsqueezeAndConv , self ).__init__ ()
1078
+ self .conv = nn .Conv2d (3 , 3 , kernel_size = 1 , stride = 1 , padding = 0 )
1079
+ def forward (self , x ):
1080
+ x = x .unsqueeze (axis = 0 )
1081
+ out = self .conv (x )
1082
+ return out
1083
+
1084
+ x = Variable (torch .randn (3 , 10 , 10 ))
1085
+ model = UnsqueezeAndConv ()
1086
+ save_data_and_model ("unsqueeze_and_conv_dynamic_axes" , x , model )
1087
+ postprocess_model ("models/unsqueeze_and_conv_dynamic_axes.onnx" , [[3 , 'height' , 'width' ]])
1088
+
1089
+ class SqueezeAndConv (nn .Module ):
1090
+ def __init__ (self ):
1091
+ super (SqueezeAndConv , self ).__init__ ()
1092
+ self .conv = nn .Conv2d (3 , 3 , kernel_size = 1 , stride = 1 , padding = 0 )
1093
+ def forward (self , x ):
1094
+ x = x .squeeze ()
1095
+ out = self .conv (x )
1096
+ return out
1097
+
1098
+ x = Variable (torch .randn (2 , 1 , 3 , 3 , 3 ))
1099
+ model = SqueezeAndConv ()
1100
+ save_data_and_model ("squeeze_and_conv_dynamic_axes" , x , model )
1101
+ postprocess_model ("models/squeeze_and_conv_dynamic_axes.onnx" , [["batch_size" , 1 , "channels" , 'height' , 'width' ]])
1102
+
1103
+ x = Variable (torch .randn (2 ))
1104
+ model = GatherScalar ()
1105
+ save_data_and_model ("gather_scalar_dynamic_axes" , x , model )
1106
+ postprocess_model ("models/gather_scalar_dynamic_axes.onnx" , [['shape' ]])
1107
+
1108
+ x = Variable (torch .randn (2 , 2 , 2 , 2 ))
1109
+ print (x )
1110
+ model = Gather ()
1111
+ print (model (x ))
1112
+ print (model (x ).shape )
1113
+ save_data_and_model ("gather_dynamic_axes" , x , model )
1114
+ postprocess_model ("models/gather_dynamic_axes.onnx" , [["batch_size" , 2 , 'height' , 'width' ]])
1115
+
1116
+ input = Variable (torch .randn (1 , 2 , 4 , 4 ))
1117
+ model = Slice ()
1118
+ save_data_and_model ("slice_dynamic_axes" , input , model )
1119
+ save_data_and_model ("slice_opset_11_dynamic_axes" , input , model , version = 11 )
1120
+ postprocess_model ("models/slice_dynamic_axes.onnx" , [["batch_size" , 2 , 'height' , 'width' ]])
1121
+ postprocess_model ("models/slice_opset_11_dynamic_axes.onnx" , [["batch_size" , 2 , 'height' , 'width' ]])
1122
+
1123
+ x = Variable (torch .rand (1 , 2 , 2 , 2 ))
1124
+ model = ResizeConv (2 , 0 , 2 )
1125
+ save_data_and_model ("resize_opset11_torch1.6_dynamic_axes" , x , model , 11 )
1126
+ postprocess_model ("models/resize_opset11_torch1.6_dynamic_axes.onnx" , [["batch_size" , 2 , 'height' , 'width' ]])
1127
+
1128
+ maxpooling_sigmoid = nn .Sequential (
1129
+ nn .MaxPool2d (kernel_size = 4 , stride = 2 , padding = (1 , 2 ), dilation = 1 ),
1130
+ nn .Sigmoid ()
1131
+ )
1132
+ input = Variable (torch .randn (2 , 3 , 12 , 18 ))
1133
+ save_data_and_model ("maxpooling_sigmoid_dynamic_axes" , input , maxpooling_sigmoid )
1134
+ postprocess_model ("models/maxpooling_sigmoid_dynamic_axes.onnx" , [[2 , 3 , 'height' , 'width' ]])
1135
+
1136
+ ave_pool = nn .AvgPool2d (kernel_size = 3 , stride = 2 , padding = 1 )
1137
+ input = Variable (torch .randn (1 , 3 , 7 , 5 ))
1138
+ save_data_and_model ("average_pooling_dynamic_axes" , input , ave_pool )
1139
+ postprocess_model ("models/average_pooling_dynamic_axes.onnx" , [[1 , 3 , 'height' , 'width' ]])
0 commit comments