@@ -887,3 +887,48 @@ def forward(self, x, kernel, bias):
887
887
x = Variable (torch .randn (1 , 2 , 2 ))
888
888
model = Expand (shape = [2 , - 1 , - 1 , - 1 ])
889
889
save_data_and_model ("expand_neg_batch" , x , model )
890
+
891
+ def postprocess_model (model_path ):
892
+ onnx_model = onnx .load (model_path )
893
+
894
+ def update_inputs_dims (model , input_dims ):
895
+ """
896
+ This function updates the sizes of dimensions of the model's inputs to the values
897
+ provided in input_dims. if the dim value provided is negative, a unique dim_param
898
+ will be set for that dimension.
899
+ """
900
+ def update_dim (tensor , dim , i , j , dim_param_prefix ):
901
+ dim_proto = tensor .type .tensor_type .shape .dim [j ]
902
+ if isinstance (dim , int ):
903
+ if dim >= 0 :
904
+ dim_proto .dim_value = dim
905
+ else :
906
+ dim_proto .dim_param = dim_param_prefix + str (i ) + '_' + str (j )
907
+ elif isinstance (dim , str ):
908
+ dim_proto .dim_param = dim
909
+ else :
910
+ raise ValueError ('Only int or str is accepted as dimension value, incorrect type: {}' .format (type (dim )))
911
+
912
+ for i , input_dim_arr in enumerate (input_dims ):
913
+ for j , dim in enumerate (input_dim_arr ):
914
+ update_dim (model .graph .input [i ], dim , i , j , 'in_' )
915
+
916
+ onnx .checker .check_model (model )
917
+ return model
918
+
919
+ onnx_model = update_inputs_dims (onnx_model , [[3 , 'height' , 'width' ]])
920
+ onnx .save (onnx_model , model_path )
921
+
922
+ class ReshapeAndConv (nn .Module ):
923
+ def __init__ (self ):
924
+ super (ReshapeAndConv , self ).__init__ ()
925
+ self .conv = nn .Conv2d (3 , 3 , kernel_size = 1 , stride = 1 , padding = 0 )
926
+ def forward (self , x ):
927
+ x = x .unsqueeze (axis = 0 )
928
+ out = self .conv (x )
929
+ return out
930
+
931
+ x = Variable (torch .randn (3 , 10 , 10 ))
932
+ model = ReshapeAndConv ()
933
+ save_data_and_model ("reshape_and_conv_parameter_dims" , x , model )
934
+ postprocess_model ("models/reshape_and_conv_parameter_dims.onnx" )
0 commit comments