Commit 82456ee

[MLIR][TORCH] add E2E support for aten.new_full (#2425)
* implement aten.new_full
* remove extraneous tests
1 parent 23b7224 commit 82456ee

8 files changed, +221 −0 lines changed

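For context, `aten.new_full` (the `Tensor.new_full` method in PyTorch) builds a tensor whose shape comes from the `size` argument and whose dtype and device are inherited from `self` unless overridden. A quick eager-mode illustration (not part of the commit):

import torch

x = torch.ones(2, 3, dtype=torch.float64)
y = x.new_full((3, 4), 5)
assert y.shape == (3, 4)             # shape comes from `size`, not from x
assert y.dtype == torch.float64      # dtype inherited from x when not given
z = x.new_full((3, 4), 5, dtype=torch.int32)
assert z.dtype == torch.int32        # explicit dtype overrides the inherited one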

e2e_testing/xfail_sets.py (+14)

@@ -561,6 +561,14 @@
     "FullModuleFloat3D_basic",
     "FullModuleInt2D_basic",
     "FullModuleInt3D_basic",
+    "NewFullModuleDefaultDtype_basic",
+    "NewFullModuleFalsePinMemory_basic",
+    "NewFullModuleFloat2D_basic",
+    "NewFullModuleFloat3DStatic_basic",
+    "NewFullModuleFloat3D_basic",
+    "NewFullModuleInt2DStatic_basic",
+    "NewFullModuleInt2D_basic",
+    "NewFullModuleInt3D_basic",
     "GatherStaticModule_basic",
     "GatherModule_basic",
     "Gather2DInputModdule_basic",
@@ -1149,6 +1157,12 @@
     "FullLikeModuleFloat3DStatic_basic",
     "FullModuleDefaultDtype_basic",
     "FullModuleFloat3D_basic",
+    "NewFullModuleDefaultDtype_basic",
+    "NewFullModuleFalsePinMemory_basic",
+    "NewFullModuleFloat2D_basic",
+    "NewFullModuleFloat3DStatic_basic",
+    "NewFullModuleFloat3D_basic",
+    "NewFullModuleInt2DStatic_basic",
     "MaskedFillScalarDefaultModule_basic",
     "NumToTensorFloatModule_basic",
     "LiftFreshCopyModule_basic",

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td (+29)

@@ -9853,6 +9853,35 @@ def Torch_AtenFullLikeOp : Torch_Op<"aten.full_like", [
   }];
 }

+def Torch_AtenNewFullOp : Torch_Op<"aten.new_full", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::new_full : (Tensor, int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$size,
+    AnyTorchScalarType:$fill_value,
+    AnyTorchOptionalIntType:$dtype,
+    AnyTorchOptionalIntType:$layout,
+    AnyTorchOptionalDeviceType:$device,
+    AnyTorchOptionalBoolType:$pin_memory
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenNewFullOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 7, 1);
+    }
+    void AtenNewFullOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 7, 1);
+    }
+  }];
+}
+
 def Torch_AtenBaddbmmOp : Torch_Op<"aten.baddbmm", [
     AllowsTypeRefinement,
     HasValueSemantics,
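The generated definition mirrors the registered JIT schema operand for operand: `self`, `size`, `fill_value`, plus the four trailing optionals, for seven operands and one result, which is what the `7, 1` passed to `parseDefaultTorchOp`/`printDefaultTorchOp` appears to encode. A minimal sketch of calling the registered op directly from eager PyTorch (illustration only, not part of the commit):

import torch

x = torch.ones(2, 3)
# The first three operands are positional; the optional dtype/layout/device/
# pin_memory operands default to None.
y = torch.ops.aten.new_full(x, [3, 4], 5.0,
                            dtype=None, layout=None, device=None,
                            pin_memory=None)
assert y.shape == (3, 4) and y.dtype == torch.float32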

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp (+15)

@@ -7226,6 +7226,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "  func.func @\"__torch_mlir_shape_fn.aten.full_like\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>, %arg6: !torch.optional<int>) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.new_full\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
+"    return %arg1 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.zeros_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -10542,6 +10545,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %2 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.new_full\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.number, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.aten.__is__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.int) {\n"
+"      torch.prim.If.yield %0#1 : !torch.int\n"
+"    } else {\n"
+"      %3 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int\n"
+"      torch.prim.If.yield %3 : !torch.int\n"
+"    }\n"
+"    return %2 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.new_zeros\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp (+28)

@@ -3166,6 +3166,33 @@ class DecomposeAtenFullLikeOp : public OpRewritePattern<AtenFullLikeOp> {
 };
 } // namespace

+namespace {
+// Decompose `aten.new_full` op into `aten.full` op.
+class DecomposeAtenNewFullOp : public OpRewritePattern<AtenNewFullOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenNewFullOp op,
+                                PatternRewriter &rewriter) const override {
+    Value dtype = op.getDtype();
+    if (dtype.getType().isa<Torch::NoneType>()) {
+      BaseTensorType tensorType = op.getSelf().getType().cast<BaseTensorType>();
+      if (!tensorType.hasDtype()) {
+        return rewriter.notifyMatchFailure(
+            op, "expected input tensor to have a dtype");
+      }
+      dtype =
+          getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
+    }
+    rewriter.replaceOpWithNewOp<AtenFullOp>(
+        op, op.getType(), op.getSize(), op.getFillValue(), dtype,
+        op.getLayout(), op.getDevice(), op.getPinMemory());
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Decompose `aten.indexPut` op into `valsem.aten.indexPutImpl` op.
 class DecomposeAtenIndexPutOp : public OpRewritePattern<AtenIndexPutOp> {
@@ -5177,6 +5204,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenLinearOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenMishOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenFullLikeOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenNewFullOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenExpandAsOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAten_ToCopyOp>(patterns);
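The pattern forwards `size`, `fill_value`, `layout`, `device`, and `pin_memory` to `aten.full` unchanged, and materializes `self`'s dtype only when no explicit dtype operand is given. The intended equivalence, sketched in eager PyTorch (illustration only, not part of the commit):

import torch

x = torch.ones(2, 3, dtype=torch.float64)

# dtype is None: the decomposition fills it in from self's dtype.
a = x.new_full((3, 4), 5)
b = torch.full((3, 4), 5, dtype=x.dtype, device=x.device)
assert torch.equal(a, b)

# dtype given explicitly: it is passed through unchanged.
c = x.new_full((3, 4), 5, dtype=torch.int32)
d = torch.full((3, 4), 5, dtype=torch.int32, device=x.device)
assert torch.equal(c, d)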

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp (+1)

@@ -437,6 +437,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenLinearOp>();
   target.addIllegalOp<AtenMishOp>();
   target.addIllegalOp<AtenFullLikeOp>();
+  target.addIllegalOp<AtenNewFullOp>();
   target.addIllegalOp<AtenIndexPutOp>();
   target.addIllegalOp<AtenExpandAsOp>();
   target.addIllegalOp<Aten_ToCopyOp>();

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py (+13)

@@ -650,6 +650,9 @@ def aten〇full〡shape(size: List[int], fill_value: float, dtype: Optional[int]
 def aten〇full_like〡shape(self: List[int], fill_value: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
     return self

+def aten〇new_full〡shape(self: List[int], size: List[int], fill_value: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
+    return size
+
 def aten〇zeros_like〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
     return upstream_shape_functions.unary(self)

@@ -3244,6 +3247,16 @@ def aten〇full_like〡dtype(self_rank_dtype: Tuple[int, int], fill_value: Union
     self_rank, self_dtype = self_rank_dtype
     return self_dtype if dtype is None else dtype

+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, size=(1,), fill_value=0.0) +
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, size=(1,), fill_value=0) +
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, size=(1,), fill_value=0.0, dtype=torch.float16) +
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, size=(1,), fill_value=0.0, dtype=torch.int32) +
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, size=(1,), fill_value=0.0, dtype=torch.complex64))
+def aten〇new_full〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], fill_value: Union[int, float, complex], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype if dtype is None else dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1]) +
                       _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.float16) +
                       _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.int32) +
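The `@check_dtype_function` cases cover integer and floating-point fill values, with and without an explicit `dtype`; the rule itself is simply "inherit `self`'s dtype unless `dtype` is provided". The same combinations checked directly in eager PyTorch (a sketch for illustration, not part of the commit):

import torch

t = torch.ones(1)   # float32
assert t.new_full((1,), 0.0).dtype == t.dtype
assert t.new_full((1,), 0).dtype == t.dtype
assert t.new_full((1,), 0.0, dtype=torch.float16).dtype == torch.float16
assert t.new_full((1,), 0.0, dtype=torch.int32).dtype == torch.int32
assert t.new_full((1,), 0.0, dtype=torch.complex64).dtype == torch.complex64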

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py (+1)

@@ -599,6 +599,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::numpy_T : (Tensor) -> (Tensor)")
     emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)")
+    emit("aten::new_full : (Tensor, int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)")
     emit_with_mutating_variants("aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
     emit("aten::fft_fft : (Tensor, int?, int, str?) -> (Tensor)")
     emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)")

python/torch_mlir_e2e_test/test_suite/constant_alloc.py (+120)

@@ -1093,6 +1093,126 @@ def forward(self, a):
 def FullLikeModuleFalsePinMemory_basic(module, tu: TestUtils):
     module.forward(tu.randint(10, 4, high=100))

+# ==============================================================================
+
+
+class NewFullModuleDefaultDtype(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.new_full(a, (3,4), 5)
+
+
+@register_test_case(module_factory=lambda: NewFullModuleDefaultDtype())
+def NewFullModuleDefaultDtype_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3))
+
+
+class NewFullModuleInt2D(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.new_full(a, (3,4), 10.5)
+
+
+@register_test_case(module_factory=lambda: NewFullModuleInt2D())
+def NewFullModuleInt2D_basic(module, tu: TestUtils):
+    module.forward(tu.randint(4, 5, high=10))
+
+
+class NewFullModuleInt3D(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.int32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.new_full(a, (3,4), 5.0, dtype=torch.int64)
+
+
+@register_test_case(module_factory=lambda: NewFullModuleInt3D())
+def NewFullModuleInt3D_basic(module, tu: TestUtils):
+    module.forward(tu.randint(10, 4, 5, high=100).to(torch.int32))
+
+
+class NewFullModuleFloat3D(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float64, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.new_full(a, (3,4), 15, dtype=torch.float32)
+
+
+@register_test_case(module_factory=lambda: NewFullModuleFloat3D())
+def NewFullModuleFloat3D_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5).to(torch.float64))
+
+
+class NewFullModuleFloat3DStatic(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, 4, 5], torch.float64, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.new_full(a, (3,4), 15.3, dtype=torch.float32)
+
+
+@register_test_case(module_factory=lambda: NewFullModuleFloat3DStatic())
+def NewFullModuleFloat3DStatic_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5).to(torch.float64))
+
+
+class NewFullModuleFalsePinMemory(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.new_full(a,
+                                       (3,4),
+                                       5,
+                                       dtype=torch.int64,
+                                       pin_memory=False)
+
+
+@register_test_case(module_factory=lambda: NewFullModuleFalsePinMemory())
+def NewFullModuleFalsePinMemory_basic(module, tu: TestUtils):
+    module.forward(tu.randint(10, 4, high=100))
+

 # ==============================================================================

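Each module above is registered with `@register_test_case`, and the e2e harness compares the compiled result against a native eager-mode run of the same module. A rough standalone sanity check of what the first module is expected to produce eagerly (illustration only; it calls the op directly rather than instantiating the test class):

import torch

# Mirrors NewFullModuleDefaultDtype.forward on a float32 input.
a = torch.rand(2, 3)
out = torch.ops.aten.new_full(a, (3, 4), 5)
assert out.shape == (3, 4)
assert out.dtype == torch.float32    # inherited from the float32 input
assert bool((out == 5).all())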
