Skip to content

Commit 99fb4c8

Browse files
authored
Add folder for ToF64Op and FromF64Op (#1257)
1 parent ba17a4d commit 99fb4c8

File tree

3 files changed

+66
-0
lines changed

3 files changed

+66
-0
lines changed

include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def TorchConversion_ToF64Op : TorchConversion_Op<"to_f64", [
154154
let assemblyFormat = [{
155155
$operand attr-dict
156156
}];
157+
let hasFolder = 1;
157158
}
158159

159160
def TorchConversion_FromF64Op : TorchConversion_Op<"from_f64", [
@@ -172,6 +173,7 @@ def TorchConversion_FromF64Op : TorchConversion_Op<"from_f64", [
172173
let assemblyFormat = [{
173174
$operand attr-dict
174175
}];
176+
let hasFolder = 1;
175177
}
176178

177179
def TorchConversion_I64ToGeneratorOp : TorchConversion_Op<"i64_to_generator", [

lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,5 +97,31 @@ OpFoldResult ToI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
9797
}
9898
}
9999

100+
//===----------------------------------------------------------------------===//
101+
// ToF64Op
102+
//===----------------------------------------------------------------------===//
103+
104+
OpFoldResult ToF64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
105+
auto attr = operands[0].dyn_cast_or_null<mlir::FloatAttr>();
106+
if (attr) {
107+
return attr;
108+
} else {
109+
return nullptr;
110+
}
111+
}
112+
113+
//===----------------------------------------------------------------------===//
114+
// FromF64Op
115+
//===----------------------------------------------------------------------===//
116+
117+
OpFoldResult FromF64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
118+
auto attr = operands[0].dyn_cast_or_null<mlir::FloatAttr>();
119+
if (attr) {
120+
return attr;
121+
} else {
122+
return nullptr;
123+
}
124+
}
125+
100126
#define GET_OP_CLASSES
101127
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.cpp.inc"

test/Dialect/TorchConversion/canonicalize.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,41 @@ func.func @torch_c.to_i64$from_i64() -> !torch.int {
3737
%1 = torch_c.from_i64 %0
3838
return %1 : !torch.int
3939
}
40+
41+
// CHECK-LABEL: func.func @torch_c.from_f64() -> !torch.float {
42+
// CHECK: %[[FLOAT5:.*]] = torch.constant.float 5.000000e+00
43+
// CHECK: return %[[FLOAT5]] : !torch.float
44+
func.func @torch_c.from_f64() -> !torch.float {
45+
%c5_f64 = arith.constant 5.000000e+00 : f64
46+
%0 = torch_c.from_f64 %c5_f64
47+
return %0 : !torch.float
48+
}
49+
50+
// CHECK-LABEL: func.func @torch_c.to_f64() -> f64 {
51+
// CHECK: %[[C5_f64:.*]] = arith.constant 5.000000e+00 : f64
52+
// CHECK: return %[[C5_f64]] : f64
53+
func.func @torch_c.to_f64() -> f64 {
54+
%float5 = torch.constant.float 5.000000e+00
55+
%0 = torch_c.to_f64 %float5
56+
return %0 : f64
57+
}
58+
59+
// CHECK-LABEL: func.func @torch_c.from_f64$to_f64() -> f64 {
60+
// CHECK: %[[C5_f64:.*]] = arith.constant 5.000000e+00 : f64
61+
// CHECK: return %[[C5_f64]] : f64
62+
func.func @torch_c.from_f64$to_f64() -> f64 {
63+
%c5_f64 = arith.constant 5.000000e+00 : f64
64+
%0 = torch_c.from_f64 %c5_f64
65+
%1 = torch_c.to_f64 %0
66+
return %1 : f64
67+
}
68+
69+
// CHECK-LABEL: func.func @torch_c.to_f64$from_f64() -> !torch.float {
70+
// CHECK: %[[FLOAT5:.*]] = torch.constant.float 5.000000e+00
71+
// CHECK: return %[[FLOAT5]] : !torch.float
72+
func.func @torch_c.to_f64$from_f64() -> !torch.float {
73+
%float5 = torch.constant.float 5.000000e+00
74+
%0 = torch_c.to_f64 %float5
75+
%1 = torch_c.from_f64 %0
76+
return %1 : !torch.float
77+
}

0 commit comments

Comments
 (0)