@@ -1093,6 +1093,126 @@ def forward(self, a):
1093
1093
def FullLikeModuleFalsePinMemory_basic (module , tu : TestUtils ):
1094
1094
module .forward (tu .randint (10 , 4 , high = 100 ))
1095
1095
1096
+ # ==============================================================================
1097
+
1098
+
1099
+ class NewFullModuleDefaultDtype (torch .nn .Module ):
1100
+
1101
+ def __init__ (self ):
1102
+ super ().__init__ ()
1103
+
1104
+ @export
1105
+ @annotate_args ([
1106
+ None ,
1107
+ ([- 1 , - 1 ], torch .float32 , True ),
1108
+ ])
1109
+ def forward (self , a ):
1110
+ return torch .ops .aten .new_full (a , (3 ,4 ), 5 )
1111
+
1112
+
1113
+ @register_test_case (module_factory = lambda : NewFullModuleDefaultDtype ())
1114
+ def NewFullModuleDefaultDtype_basic (module , tu : TestUtils ):
1115
+ module .forward (tu .rand (2 , 3 ))
1116
+
1117
+
1118
+ class NewFullModuleInt2D (torch .nn .Module ):
1119
+
1120
+ def __init__ (self ):
1121
+ super ().__init__ ()
1122
+
1123
+ @export
1124
+ @annotate_args ([
1125
+ None ,
1126
+ ([- 1 , - 1 ], torch .int64 , True ),
1127
+ ])
1128
+ def forward (self , a ):
1129
+ return torch .ops .aten .new_full (a , (3 ,4 ), 10.5 )
1130
+
1131
+
1132
+ @register_test_case (module_factory = lambda : NewFullModuleInt2D ())
1133
+ def NewFullModuleInt2D_basic (module , tu : TestUtils ):
1134
+ module .forward (tu .randint (4 , 5 , high = 10 ))
1135
+
1136
+
1137
+ class NewFullModuleInt3D (torch .nn .Module ):
1138
+
1139
+ def __init__ (self ):
1140
+ super ().__init__ ()
1141
+
1142
+ @export
1143
+ @annotate_args ([
1144
+ None ,
1145
+ ([- 1 , - 1 , - 1 ], torch .int32 , True ),
1146
+ ])
1147
+ def forward (self , a ):
1148
+ return torch .ops .aten .new_full (a , (3 ,4 ), 5.0 , dtype = torch .int64 )
1149
+
1150
+
1151
+ @register_test_case (module_factory = lambda : NewFullModuleInt3D ())
1152
+ def NewFullModuleInt3D_basic (module , tu : TestUtils ):
1153
+ module .forward (tu .randint (10 , 4 , 5 , high = 100 ).to (torch .int32 ))
1154
+
1155
+
1156
+ class NewFullModuleFloat3D (torch .nn .Module ):
1157
+
1158
+ def __init__ (self ):
1159
+ super ().__init__ ()
1160
+
1161
+ @export
1162
+ @annotate_args ([
1163
+ None ,
1164
+ ([- 1 , - 1 , - 1 ], torch .float64 , True ),
1165
+ ])
1166
+ def forward (self , a ):
1167
+ return torch .ops .aten .new_full (a , (3 ,4 ), 15 , dtype = torch .float32 )
1168
+
1169
+
1170
+ @register_test_case (module_factory = lambda : NewFullModuleFloat3D ())
1171
+ def NewFullModuleFloat3D_basic (module , tu : TestUtils ):
1172
+ module .forward (tu .rand (3 , 4 , 5 ).to (torch .float64 ))
1173
+
1174
+
1175
+ class NewFullModuleFloat3DStatic (torch .nn .Module ):
1176
+
1177
+ def __init__ (self ):
1178
+ super ().__init__ ()
1179
+
1180
+ @export
1181
+ @annotate_args ([
1182
+ None ,
1183
+ ([3 , 4 , 5 ], torch .float64 , True ),
1184
+ ])
1185
+ def forward (self , a ):
1186
+ return torch .ops .aten .new_full (a , (3 ,4 ), 15.3 , dtype = torch .float32 )
1187
+
1188
+
1189
+ @register_test_case (module_factory = lambda : NewFullModuleFloat3DStatic ())
1190
+ def NewFullModuleFloat3DStatic_basic (module , tu : TestUtils ):
1191
+ module .forward (tu .rand (3 , 4 , 5 ).to (torch .float64 ))
1192
+
1193
+
1194
+ class NewFullModuleFalsePinMemory (torch .nn .Module ):
1195
+
1196
+ def __init__ (self ):
1197
+ super ().__init__ ()
1198
+
1199
+ @export
1200
+ @annotate_args ([
1201
+ None ,
1202
+ ([- 1 , - 1 ], torch .int64 , True ),
1203
+ ])
1204
+ def forward (self , a ):
1205
+ return torch .ops .aten .new_full (a ,
1206
+ (3 ,4 ),
1207
+ 5 ,
1208
+ dtype = torch .int64 ,
1209
+ pin_memory = False )
1210
+
1211
+
1212
+ @register_test_case (module_factory = lambda : NewFullModuleFalsePinMemory ())
1213
+ def NewFullModuleFalsePinMemory_basic (module , tu : TestUtils ):
1214
+ module .forward (tu .randint (10 , 4 , high = 100 ))
1215
+
1096
1216
1097
1217
# ==============================================================================
1098
1218
0 commit comments