@@ -986,3 +986,48 @@ def forward(self, x):
986
986
x = Variable (torch .randn (1 , 3 , 2 , 2 ))
987
987
model = Scale ()
988
988
save_data_and_model ("scale" , x , model )
989
+
990
+ def postprocess_model (model_path ):
991
+ onnx_model = onnx .load (model_path )
992
+
993
+ def update_inputs_dims (model , input_dims ):
994
+ """
995
+ This function updates the sizes of dimensions of the model's inputs to the values
996
+ provided in input_dims. if the dim value provided is negative, a unique dim_param
997
+ will be set for that dimension.
998
+ """
999
+ def update_dim (tensor , dim , i , j , dim_param_prefix ):
1000
+ dim_proto = tensor .type .tensor_type .shape .dim [j ]
1001
+ if isinstance (dim , int ):
1002
+ if dim >= 0 :
1003
+ dim_proto .dim_value = dim
1004
+ else :
1005
+ dim_proto .dim_param = dim_param_prefix + str (i ) + '_' + str (j )
1006
+ elif isinstance (dim , str ):
1007
+ dim_proto .dim_param = dim
1008
+ else :
1009
+ raise ValueError ('Only int or str is accepted as dimension value, incorrect type: {}' .format (type (dim )))
1010
+
1011
+ for i , input_dim_arr in enumerate (input_dims ):
1012
+ for j , dim in enumerate (input_dim_arr ):
1013
+ update_dim (model .graph .input [i ], dim , i , j , 'in_' )
1014
+
1015
+ onnx .checker .check_model (model )
1016
+ return model
1017
+
1018
+ onnx_model = update_inputs_dims (onnx_model , [[3 , 'height' , 'width' ]])
1019
+ onnx .save (onnx_model , model_path )
1020
+
1021
+ class ReshapeAndConv (nn .Module ):
1022
+ def __init__ (self ):
1023
+ super (ReshapeAndConv , self ).__init__ ()
1024
+ self .conv = nn .Conv2d (3 , 3 , kernel_size = 1 , stride = 1 , padding = 0 )
1025
+ def forward (self , x ):
1026
+ x = x .unsqueeze (axis = 0 )
1027
+ out = self .conv (x )
1028
+ return out
1029
+
1030
+ x = Variable (torch .randn (3 , 10 , 10 ))
1031
+ model = ReshapeAndConv ()
1032
+ save_data_and_model ("reshape_and_conv_parameter_dims" , x , model )
1033
+ postprocess_model ("models/reshape_and_conv_parameter_dims.onnx" )
0 commit comments