@@ -1081,4 +1081,29 @@ def forward(self, x):
1081
1081
x = Variable (torch .randn (2 , 1 , 3 , 3 , 3 ))
1082
1082
model = SqueezeAndConv ()
1083
1083
save_data_and_model ("squeeze_and_conv_dynamic_axes" , x , model )
1084
- postprocess_model ("models/squeeze_and_conv_dynamic_axes.onnx" , [["batch_size" , 1 , 3 , 'height' , 'width' ]])
1084
+ postprocess_model ("models/squeeze_and_conv_dynamic_axes.onnx" , [["batch_size" , 1 , "channels" , 'height' , 'width' ]])
1085
+
1086
+ x = Variable (torch .randn (2 ))
1087
+ model = GatherScalar ()
1088
+ save_data_and_model ("gather_scalar_dynamic_axes" , x , model )
1089
+ postprocess_model ("models/gather_scalar_dynamic_axes.onnx" , [['shape' ]])
1090
+
1091
+ x = Variable (torch .randn (2 , 2 , 2 , 2 ))
1092
+ print (x )
1093
+ model = Gather ()
1094
+ print (model (x ))
1095
+ print (model (x ).shape )
1096
+ save_data_and_model ("gather_dynamic_axes" , x , model )
1097
+ postprocess_model ("models/gather_dynamic_axes.onnx" , [["batch_size" , 2 , 'height' , 'width' ]])
1098
+
1099
+ input = Variable (torch .randn (1 , 2 , 4 , 4 ))
1100
+ model = Slice ()
1101
+ save_data_and_model ("slice_dynamic_axes" , input , model )
1102
+ save_data_and_model ("slice_opset_11_dynamic_axes" , input , model , version = 11 )
1103
+ postprocess_model ("models/slice_dynamic_axes.onnx" , [["batch_size" , 2 , 'height' , 'width' ]])
1104
+ postprocess_model ("models/slice_opset_11_dynamic_axes.onnx" , [["batch_size" , 2 , 'height' , 'width' ]])
1105
+
1106
+ x = Variable (torch .rand (1 , 2 , 2 , 2 ))
1107
+ model = ResizeConv (2 , 0 , 2 )
1108
+ save_data_and_model ("resize_opset11_torch1.6_dynamic_axes" , x , model , 11 )
1109
+ postprocess_model ("models/resize_opset11_torch1.6_dynamic_axes.onnx" , [["batch_size" , 2 , 'height' , 'width' ]])
0 commit comments