
Commit ddd45d6

[MLIR][TORCH] Add E2E support for aten.new_zeros, aten.new_ones op
This commit adds lowering of the `aten.new_zeros` and `aten.new_ones` ops.

Signed-off-by: Vivek Khandelwal <[email protected]>
1 parent 1dba4fc commit ddd45d6
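
For context, `aten.new_zeros` and `aten.new_ones` allocate a fresh tensor of the requested size and, unless a `dtype` (or `device`) is passed explicitly, inherit it from the input tensor rather than from PyTorch's global defaults; that is the behavior the lowering below has to reproduce. A quick eager-mode illustration (not part of the diff):

import torch

a = torch.rand(2, 3)                                  # float32 input
z = torch.ops.aten.new_zeros(a, [3, 4])               # no dtype given
o = torch.ops.aten.new_ones(a, [3, 4], dtype=torch.int64)

assert z.shape == (3, 4) and z.dtype == a.dtype       # dtype inherited from `a`
assert o.shape == (3, 4) and o.dtype == torch.int64   # explicit dtype wins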

7 files changed: +335 -0 lines changed

e2e_testing/torchscript/constant_alloc.py

Lines changed: 206 additions & 0 deletions
@@ -514,3 +514,209 @@ def forward(self, tensor):
 def Fill_TensorFloat64WithInt64_basic(module, tu: TestUtils):
     module.forward(torch.randn(3, 2, 4).to(torch.float64))
 
+
+# ==============================================================================
+
+class NewZerosModuleDefaultDtype(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.new_zeros(a, [3, 4])
+
+@register_test_case(module_factory=lambda: NewZerosModuleDefaultDtype())
+def NewZerosModuleDefaultDtype_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3))
+
+
+class NewZerosModuleInt2D(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.new_zeros(a, [3, 4], dtype=torch.int64)
+
+@register_test_case(module_factory=lambda: NewZerosModuleInt2D())
+def NewZerosModuleInt2D_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 4))
+
+
+class NewZerosModuleInt3D(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.new_zeros(a, [3, 4, 5], dtype=torch.int64)
+
+@register_test_case(module_factory=lambda: NewZerosModuleInt3D())
+def NewZerosModuleInt3D_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3))
+
+
+class NewZerosModuleFloat2D(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.int64, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.new_zeros(a, [3, 4], dtype=torch.float32)
+
+@register_test_case(module_factory=lambda: NewZerosModuleFloat2D())
+def NewZerosModuleFloat2D_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (2, 3, 4)))
+
+
+class NewZerosModuleFloat3D(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.new_zeros(a, [3, 4, 5], dtype=torch.float32)
+
+@register_test_case(module_factory=lambda: NewZerosModuleFloat3D())
+def NewZerosModuleFloat3D_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (2, 3)))
+
+
+class NewZerosModuleFalsePinMemory(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.new_zeros(a, [3, 4], dtype=torch.float32, pin_memory=False)
+
+@register_test_case(module_factory=lambda: NewZerosModuleFalsePinMemory())
+def NewZerosModuleFalsePinMemory_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (2, 3)))
+
+# ==============================================================================
+
+class NewOnesModuleDefaultDtype(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.new_ones(a, [3, 4])
+
+@register_test_case(module_factory=lambda: NewOnesModuleDefaultDtype())
+def NewOnesModuleDefaultDtype_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3))
+
+
+class NewOnesModuleInt2D(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.new_ones(a, [3, 4], dtype=torch.int64)
+
+@register_test_case(module_factory=lambda: NewOnesModuleInt2D())
+def NewOnesModuleInt2D_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 4))
+
+
+class NewOnesModuleInt3D(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.new_ones(a, [3, 4, 5], dtype=torch.int64)
+
+@register_test_case(module_factory=lambda: NewOnesModuleInt3D())
+def NewOnesModuleInt3D_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3))
+
+
+class NewOnesModuleFloat2D(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.int64, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.new_ones(a, [3, 4], dtype=torch.float32)
+
+@register_test_case(module_factory=lambda: NewOnesModuleFloat2D())
+def NewOnesModuleFloat2D_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (2, 3, 4)))
+
+
+class NewOnesModuleFloat3D(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.new_ones(a, [3, 4, 5], dtype=torch.float32)
+
+@register_test_case(module_factory=lambda: NewOnesModuleFloat3D())
+def NewOnesModuleFloat3D_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (2, 3)))
+
+
+class NewOnesModuleFalsePinMemory(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.new_ones(a, [3, 4], dtype=torch.float32, pin_memory=False)
+
+@register_test_case(module_factory=lambda: NewOnesModuleFalsePinMemory())
+def NewOnesModuleFalsePinMemory_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (2, 3)))
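
A note on the test names above: the Int/Float and 2D/3D suffixes describe the allocated result (its dtype and rank), not the input tensor — NewZerosModuleFloat2D, for example, feeds a 3-D int64 tensor and allocates a 2-D float32 result. An eager-mode sketch of that case, purely for illustration:

import torch

a = torch.randint(10, (2, 3, 4))                                 # 3-D int64 input
out = torch.ops.aten.new_zeros(a, [3, 4], dtype=torch.float32)   # as in NewZerosModuleFloat2D
assert out.shape == (3, 4) and out.dtype == torch.float32        # 2-D float32 result
assert torch.all(out == 0)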

e2e_testing/torchscript/xfail_sets.py

Lines changed: 12 additions & 0 deletions
@@ -117,4 +117,16 @@
     "OnesModuleInt_basic",
     "OnesModuleFloat_basic",
     "OnesModuleFalsePinMemory_basic",
+    "NewZerosModuleDefaultDtype_basic",
+    "NewZerosModuleInt2D_basic",
+    "NewZerosModuleInt3D_basic",
+    "NewZerosModuleFloat2D_basic",
+    "NewZerosModuleFloat3D_basic",
+    "NewZerosModuleFalsePinMemory_basic",
+    "NewOnesModuleDefaultDtype_basic",
+    "NewOnesModuleInt2D_basic",
+    "NewOnesModuleInt3D_basic",
+    "NewOnesModuleFloat2D_basic",
+    "NewOnesModuleFloat3D_basic",
+    "NewOnesModuleFalsePinMemory_basic",
 }

include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td

Lines changed: 38 additions & 0 deletions
@@ -2178,6 +2178,25 @@ def Torch_AtenOnesOp : Torch_Op<"aten.ones", [
   let assemblyFormat = "$size `,` $dtype `,` $layout `,` $device `,` $pin_memory attr-dict `:` qualified(type($size)) `,` qualified(type($dtype)) `,` qualified(type($layout)) `,` qualified(type($device)) `,` qualified(type($pin_memory)) `->` qualified(type($result))";
 }
 
+def Torch_AtenNewOnesOp : Torch_Op<"aten.new_ones", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    TorchIntListType:$size,
+    TorchOptionalIntType:$dtype,
+    TorchOptionalIntType:$layout,
+    TorchOptionalDeviceType:$device,
+    TorchOptionalBoolType:$pin_memory
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $size `,` $dtype `,` $layout `,` $device `,` $pin_memory attr-dict `:` qualified(type($self)) `,` qualified(type($size)) `,` qualified(type($dtype)) `,` qualified(type($layout)) `,` qualified(type($device)) `,` qualified(type($pin_memory)) `->` qualified(type($result))";
+}
+
 def Torch_AtenZerosOp : Torch_Op<"aten.zeros", [
     AllowsTypeRefinement,
     HasValueSemantics
@@ -2196,6 +2215,25 @@ def Torch_AtenZerosOp : Torch_Op<"aten.zeros", [
   let assemblyFormat = "$size `,` $dtype `,` $layout `,` $device `,` $pin_memory attr-dict `:` qualified(type($size)) `,` qualified(type($dtype)) `,` qualified(type($layout)) `,` qualified(type($device)) `,` qualified(type($pin_memory)) `->` qualified(type($result))";
 }
 
+def Torch_AtenNewZerosOp : Torch_Op<"aten.new_zeros", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    TorchIntListType:$size,
+    TorchOptionalIntType:$dtype,
+    TorchOptionalIntType:$layout,
+    TorchOptionalDeviceType:$device,
+    TorchOptionalBoolType:$pin_memory
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $size `,` $dtype `,` $layout `,` $device `,` $pin_memory attr-dict `:` qualified(type($self)) `,` qualified(type($size)) `,` qualified(type($dtype)) `,` qualified(type($layout)) `,` qualified(type($device)) `,` qualified(type($pin_memory)) `->` qualified(type($result))";
+}
+
 def Torch_AtenTensorOp : Torch_Op<"aten.tensor", [
     AllowsTypeRefinement,
     HasValueSemantics

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 21 additions & 0 deletions
@@ -1218,6 +1218,21 @@ class DecomposeAten_UnsafeViewOp : public OpRewritePattern<Aten_UnsafeViewOp> {
 };
 } // namespace
 
+namespace {
+// Decompose constant tensor like ops.
+template <typename OpTy, typename NewOpTy>
+class DecomposeConstantTensorNewLikeOp : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<NewOpTy>(op, op.getType(), op.size(),
+                                         op.dtype(), op.layout(), op.device(),
+                                         op.pin_memory());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -1303,6 +1318,12 @@ class DecomposeComplexOpsPass
     target.addIllegalOp<AtenHardsigmoidOp>();
     patterns.add<DecomposeAtenHardswishOp>(context);
     target.addIllegalOp<AtenHardswishOp>();
+    patterns.add<DecomposeConstantTensorNewLikeOp<AtenNewZerosOp, AtenZerosOp>>(
+        context);
+    target.addIllegalOp<AtenNewZerosOp>();
+    patterns.add<DecomposeConstantTensorNewLikeOp<AtenNewOnesOp, AtenOnesOp>>(
+        context);
+    target.addIllegalOp<AtenNewOnesOp>();
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
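
For readers not fluent in the rewrite API: DecomposeConstantTensorNewLikeOp drops the `self` operand and forwards `size`, `dtype`, `layout`, `device`, and `pin_memory` unchanged to the plain allocation op (`aten.zeros` or `aten.ones`), keeping the already-computed result type. A rough eager-PyTorch analogue of the `new_zeros` case (an illustrative sketch, not code from this commit; the dtype-from-`self` inheritance is handled by the RefineTypes change below, not by this rewrite):

import torch

def decompose_new_zeros(self_tensor, size, dtype=None, layout=None,
                        device=None, pin_memory=None):
    # `self_tensor` is dropped; the remaining operands are forwarded as-is,
    # mirroring rewriter.replaceOpWithNewOp<AtenZerosOp>(...).
    return torch.zeros(size, dtype=dtype, layout=layout,
                       device=device, pin_memory=pin_memory)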

lib/Dialect/Torch/Transforms/RefineTypes.cpp

Lines changed: 18 additions & 0 deletions
@@ -388,6 +388,10 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
     } else if (auto emptyLike = dyn_cast<AtenEmptyLikeOp>(op)) {
       return visitConstantTensorAllocLikeOp<AtenEmptyLikeOp>(emptyLike,
                                                              operands);
+    } else if (auto newZeros = dyn_cast<AtenNewZerosOp>(op)) {
+      return visitConstantTensorNewLikeOp<AtenNewZerosOp>(newZeros, operands);
+    } else if (auto newOnes = dyn_cast<AtenNewOnesOp>(op)) {
+      return visitConstantTensorNewLikeOp<AtenNewOnesOp>(newOnes, operands);
     } else if (auto toDtype = dyn_cast<AtenToDtypeOp>(op)) {
       return visitAtenToDtypeOp(toDtype, operands);
     } else if (auto toOther = dyn_cast<AtenToOtherOp>(op)) {
@@ -582,6 +586,9 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
   template <typename OpTy>
   ChangeResult visitConstantTensorAllocLikeOp(
       OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  template <typename OpTy>
+  ChangeResult visitConstantTensorNewLikeOp(
+      OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands);
   ChangeResult
   visitAtenToDtypeOp(AtenToDtypeOp op,
                      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
@@ -1535,6 +1542,17 @@ ChangeResult TypeAnalyzer::visitConstantTensorAllocLikeOp(
   return getLatticeElement(op.getResult()).join(knowledge);
 }
 
+template <typename OpTy>
+ChangeResult TypeAnalyzer::visitConstantTensorNewLikeOp(
+    OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto input = operands[0]->getValue();
+  auto knowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+  fillInSizesGivenSizesList(knowledge, op.size());
+  fillInDTypeGivenDTypeIntAndInputDType(knowledge, op.dtype(), input.dtype);
+  return getLatticeElement(op.getResult()).join(knowledge);
+}
+
 // Convert input tensor type to the given `dtype`.
 ChangeResult TypeAnalyzer::visitAtenToDtypeOp(
     AtenToDtypeOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
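
In words, visitConstantTensorNewLikeOp fills in the result sizes from the `size` operand and picks the result dtype from the explicit `dtype` when one is given, falling back to the input tensor's dtype otherwise. A plain-Python restatement of that rule (illustrative only; the function name here is made up):

from typing import Optional, Sequence
import torch

def refine_new_like_result(input_dtype: torch.dtype,
                           size: Sequence[int],
                           dtype: Optional[torch.dtype] = None):
    result_sizes = list(size)                                    # fillInSizesGivenSizesList
    result_dtype = dtype if dtype is not None else input_dtype   # fillInDTypeGivenDTypeIntAndInputDType
    return result_sizes, result_dtype

# Matches NewZerosModuleInt2D: float32 input, explicit int64 dtype, [3, 4] result.
assert refine_new_like_result(torch.float32, [3, 4], torch.int64) == ([3, 4], torch.int64)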

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 2 additions & 0 deletions
@@ -572,7 +572,9 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
     emit("aten::Bool.Tensor : (Tensor) -> (bool)")
     emit("aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)")
+    emit("aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)")
+    emit("aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)")
     emit("aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)")
     emit("aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)")
