Skip to content

Commit 79b9cf9

Browse files
authored
Add lowering for aten.to.device (#1107)
1 parent b8d51a7 commit 79b9cf9

File tree

9 files changed

+82
-1
lines changed

9 files changed

+82
-1
lines changed

e2e_testing/torchscript/xfail_sets.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
"TModuleRank1_basic",
5454
"TModuleRank0_basic",
5555
"ElementwiseToDtypeIdentityModule_basic",
56+
"AtenToDeviceModule_basic",
5657
"View1DFoldModule_basic",
5758
"UnsafeView1DFoldModule_basic",
5859
"SqueezeDimModule_static",

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5880,6 +5880,33 @@ def Torch_AtenToPrimDeviceOp : Torch_Op<"aten.to.prim_Device", [
58805880
}];
58815881
}
58825882

5883+
def Torch_AtenToDeviceOp : Torch_Op<"aten.to.device", [
5884+
AllowsTypeRefinement,
5885+
ReadOnly
5886+
]> {
5887+
let summary = "Generated op for `aten::to.device : (Tensor, Device, int, bool, bool, int?) -> (Tensor)`";
5888+
let arguments = (ins
5889+
AnyTorchTensorType:$self,
5890+
Torch_DeviceType:$device,
5891+
Torch_IntType:$dtype,
5892+
Torch_BoolType:$non_blocking,
5893+
Torch_BoolType:$copy,
5894+
AnyTorchOptionalIntType:$memory_format
5895+
);
5896+
let results = (outs
5897+
AnyTorchTensorType:$result
5898+
);
5899+
let hasCustomAssemblyFormat = 1;
5900+
let extraClassDefinition = [{
5901+
ParseResult AtenToDeviceOp::parse(OpAsmParser &parser, OperationState &result) {
5902+
return parseDefaultTorchOp(parser, result, 6, 1);
5903+
}
5904+
void AtenToDeviceOp::print(OpAsmPrinter &printer) {
5905+
printDefaultTorchOp(printer, *this, 6, 1);
5906+
}
5907+
}];
5908+
}
5909+
58835910
def Torch_AtenTypeAsOp : Torch_Op<"aten.type_as", [
58845911
AllowsTypeRefinement,
58855912
HasValueSemantics,

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1996,6 +1996,25 @@ class DecomposeAtenToDtypeLayoutOp
19961996
};
19971997
} // namespace
19981998

1999+
namespace {
2000+
// Decompose `aten.to.device` op into `aten.to.dtype` op.
2001+
class DecomposeAtenToDeviceOp : public OpRewritePattern<AtenToDeviceOp> {
2002+
public:
2003+
using OpRewritePattern::OpRewritePattern;
2004+
LogicalResult matchAndRewrite(AtenToDeviceOp op,
2005+
PatternRewriter &rewriter) const override {
2006+
2007+
// Device information isn't relevant to torch-mlir, so we can drop that info
2008+
// here.
2009+
rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), op.self(),
2010+
op.dtype(), op.non_blocking(),
2011+
op.copy(), op.memory_format());
2012+
2013+
return success();
2014+
}
2015+
};
2016+
} // namespace
2017+
19992018
namespace {
20002019
// Decompose `aten.adaptive_avg_pool2d` op into `aten.avg_pool2d` op.
20012020
//
@@ -2586,6 +2605,8 @@ class DecomposeComplexOpsPass
25862605
patterns.add<DecomposeAtenPadOp>(context);
25872606
patterns.add<DecomposeAtenToDtypeLayoutOp>(context);
25882607
target.addIllegalOp<AtenToDtypeLayoutOp>();
2608+
patterns.add<DecomposeAtenToDeviceOp>(context);
2609+
target.addIllegalOp<AtenToDeviceOp>();
25892610
patterns.add<DecomposeAtenAdaptiveAvgPool2dOp>(context);
25902611
target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
25912612
patterns.add<DecomposeAtenClampMinOp>(context);

lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ static bool isViewLikeOp(Operation *op) {
3838
AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
3939
AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
4040
TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
41-
AtenNarrowOp>(op);
41+
AtenNarrowOp, AtenToDeviceOp>(op);
4242
}
4343

4444
namespace {

lib/Dialect/Torch/Transforms/RefineTypes.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,6 +1024,11 @@ void TypeAnalysis::visitOperation(Operation *op,
10241024
return;
10251025
}
10261026

1027+
if (auto toDtype = dyn_cast<AtenToDeviceOp>(op)) {
1028+
visitAtenToDtypeLikeOp<AtenToDeviceOp>(toDtype, operands);
1029+
return;
1030+
}
1031+
10271032
if (auto toOther = dyn_cast<AtenToOtherOp>(op)) {
10281033
visitTypeConversionOp<AtenToOtherOp>(toOther, operands);
10291034
return;

lib/Dialect/Torch/Transforms/ShapeLibrary.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5448,6 +5448,10 @@ module {
54485448
func.func @"__torch_mlir_shape_fn.aten.to.dtype_layout"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.bool, %arg6: !torch.bool, %arg7: !torch.optional<int>) -> !torch.list<int> {
54495449
return %arg0 : !torch.list<int>
54505450
}
5451+
func.func @"__torch_mlir_shape_fn.aten.to.device"(%arg0: !torch.list<int>, %arg1: !torch.Device, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool, %arg5: !torch.optional<int>) -> !torch.list<int> {
5452+
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
5453+
return %0 : !torch.list<int>
5454+
}
54515455
func.func @"__torch_mlir_shape_fn.aten.to.other"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {
54525456
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
54535457
return %0 : !torch.list<int>

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,9 @@ def aten〇to〇dtype(self: List[int], dtype: int, non_blocking: bool = False, c
427427
def aten〇to〇dtype_layout(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> List[int]:
428428
return self
429429

430+
def aten〇to〇device(self: List[int], device: device, dtype: int, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> List[int]:
431+
return upstream_shape_functions.unary(self)
432+
430433
def aten〇to〇other(self: List[int], other: List[int], non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> List[int]:
431434
return upstream_shape_functions.unary(self)
432435

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,7 @@ def emit_with_mutating_variants(key, **kwargs):
456456
emit("aten::to.dtype_layout : (Tensor, int?, int?, Device?, bool?, bool, bool, int?) -> (Tensor)", has_folder=True)
457457
emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)")
458458
emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
459+
emit("aten::to.device : (Tensor, Device, int, bool, bool, int?) -> (Tensor)")
459460
emit("aten::type_as : (Tensor, Tensor) -> (Tensor)")
460461
emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True)
461462
emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)")

python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2627,3 +2627,22 @@ def Aten_EmbeddingBagExample_basic(module, tu: TestUtils):
26272627
indices = torch.LongTensor([0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54])
26282628
offsets = torch.LongTensor([0, 3, 5, 7, 9, 10, 15])
26292629
module.forward(weight, indices, offsets)
2630+
2631+
# ==============================================================================
2632+
2633+
class AtenToDeviceModule(torch.nn.Module):
2634+
def __init__(self):
2635+
super().__init__()
2636+
2637+
@export
2638+
@annotate_args([
2639+
None,
2640+
([-1 , -1], torch.float32, True),
2641+
])
2642+
2643+
def forward(self, val):
2644+
return torch.ops.aten.to(val, device='cpu', dtype=torch.float, non_blocking=False)
2645+
2646+
@register_test_case(module_factory=lambda: AtenToDeviceModule())
2647+
def AtenToDeviceModule_basic(module, tu: TestUtils):
2648+
module.forward(torch.randn(2, 4))

0 commit comments

Comments
 (0)