Skip to content

Commit aa6fe48

Browse files
authored
[CIR][Dialect] Extend UnaryFPToFPBuiltinOp to vector of FP type (#1132)
1 parent 8176d88 commit aa6fe48

File tree

3 files changed

+97
-3
lines changed

3 files changed

+97
-3
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

+2-2
Original file line numberDiff line numberDiff line change
@@ -4342,8 +4342,8 @@ def LLrintOp : UnaryFPToIntBuiltinOp<"llrint", "LlrintOp">;
43424342

43434343
class UnaryFPToFPBuiltinOp<string mnemonic, string llvmOpName>
43444344
: CIR_Op<mnemonic, [Pure, SameOperandsAndResultType]> {
4345-
let arguments = (ins CIR_AnyFloat:$src);
4346-
let results = (outs CIR_AnyFloat:$result);
4345+
let arguments = (ins CIR_AnyFloatOrVecOfFloat:$src);
4346+
let results = (outs CIR_AnyFloatOrVecOfFloat:$result);
43474347
let summary = "libc builtin equivalent ignoring "
43484348
"floating point exceptions and errno";
43494349
let assemblyFormat = "$src `:` type($src) attr-dict";

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

+10
Original file line numberDiff line numberDiff line change
@@ -553,10 +553,20 @@ def SignedIntegerVector : Type<
553553
]>, "!cir.vector of !cir.int"> {
554554
}
555555

556+
// Vector of Float type
557+
def FPVector : Type<
558+
And<[
559+
CPred<"::mlir::isa<::cir::VectorType>($_self)">,
560+
CPred<"::mlir::isa<::cir::SingleType, ::cir::DoubleType>("
561+
"::mlir::cast<::cir::VectorType>($_self).getEltType())">,
562+
]>, "!cir.vector of !cir.fp"> {
563+
}
564+
556565
// Constraints
557566
def CIR_AnyIntOrVecOfInt: AnyTypeOf<[CIR_IntType, IntegerVector]>;
558567
def CIR_AnySignedIntOrVecOfSignedInt: AnyTypeOf<
559568
[PrimitiveSInt, SignedIntegerVector]>;
569+
def CIR_AnyFloatOrVecOfFloat: AnyTypeOf<[CIR_AnyFloat, FPVector]>;
560570

561571
// Pointer to Arrays
562572
def ArrayPtr : Type<

clang/test/CIR/Lowering/builtin-floating-point.cir

+85-1
Original file line numberDiff line numberDiff line change
@@ -2,49 +2,133 @@
22
// RUN: FileCheck --input-file=%t.ll %s
33

44
module {
5-
cir.func @test(%arg0 : !cir.float) {
5+
cir.func @test(%arg0 : !cir.float, %arg1 : !cir.vector<!cir.double x 2>, %arg2 : !cir.vector<!cir.float x 4>) {
66
%1 = cir.cos %arg0 : !cir.float
77
// CHECK: llvm.intr.cos(%arg0) : (f32) -> f32
8+
9+
%101 = cir.cos %arg1 : !cir.vector<!cir.double x 2>
10+
// CHECK: llvm.intr.cos(%arg1) : (vector<2xf64>) -> vector<2xf64>
811

12+
%201 = cir.cos %arg2 : !cir.vector<!cir.float x 4>
13+
// CHECK: llvm.intr.cos(%arg2) : (vector<4xf32>) -> vector<4xf32>
14+
915
%2 = cir.ceil %arg0 : !cir.float
1016
// CHECK: llvm.intr.ceil(%arg0) : (f32) -> f32
1117

18+
%102 = cir.ceil %arg1 : !cir.vector<!cir.double x 2>
19+
// CHECK: llvm.intr.ceil(%arg1) : (vector<2xf64>) -> vector<2xf64>
20+
21+
%202 = cir.ceil %arg2 : !cir.vector<!cir.float x 4>
22+
// CHECK: llvm.intr.ceil(%arg2) : (vector<4xf32>) -> vector<4xf32>
23+
1224
%3 = cir.exp %arg0 : !cir.float
1325
// CHECK: llvm.intr.exp(%arg0) : (f32) -> f32
1426

27+
%103 = cir.exp %arg1 : !cir.vector<!cir.double x 2>
28+
// CHECK: llvm.intr.exp(%arg1) : (vector<2xf64>) -> vector<2xf64>
29+
30+
%203 = cir.exp %arg2 : !cir.vector<!cir.float x 4>
31+
// CHECK: llvm.intr.exp(%arg2) : (vector<4xf32>) -> vector<4xf32>
32+
1533
%4 = cir.exp2 %arg0 : !cir.float
1634
// CHECK: llvm.intr.exp2(%arg0) : (f32) -> f32
1735

36+
%104 = cir.exp2 %arg1 : !cir.vector<!cir.double x 2>
37+
// CHECK: llvm.intr.exp2(%arg1) : (vector<2xf64>) -> vector<2xf64>
38+
39+
%204 = cir.exp2 %arg2 : !cir.vector<!cir.float x 4>
40+
// CHECK: llvm.intr.exp2(%arg2) : (vector<4xf32>) -> vector<4xf32>
41+
1842
%5 = cir.fabs %arg0 : !cir.float
1943
// CHECK: llvm.intr.fabs(%arg0) : (f32) -> f32
2044

45+
%105 = cir.fabs %arg1 : !cir.vector<!cir.double x 2>
46+
// CHECK: llvm.intr.fabs(%arg1) : (vector<2xf64>) -> vector<2xf64>
47+
48+
%205 = cir.fabs %arg2 : !cir.vector<!cir.float x 4>
49+
// CHECK: llvm.intr.fabs(%arg2) : (vector<4xf32>) -> vector<4xf32>
50+
2151
%6 = cir.floor %arg0 : !cir.float
2252
// CHECK: llvm.intr.floor(%arg0) : (f32) -> f32
2353

54+
%106 = cir.floor %arg1 : !cir.vector<!cir.double x 2>
55+
// CHECK: llvm.intr.floor(%arg1) : (vector<2xf64>) -> vector<2xf64>
56+
57+
%206 = cir.floor %arg2 : !cir.vector<!cir.float x 4>
58+
// CHECK: llvm.intr.floor(%arg2) : (vector<4xf32>) -> vector<4xf32>
59+
2460
%7 = cir.log %arg0 : !cir.float
2561
// CHECK: llvm.intr.log(%arg0) : (f32) -> f32
2662

63+
%107 = cir.log %arg1 : !cir.vector<!cir.double x 2>
64+
// CHECK: llvm.intr.log(%arg1) : (vector<2xf64>) -> vector<2xf64>
65+
66+
%207 = cir.log %arg2 : !cir.vector<!cir.float x 4>
67+
// CHECK: llvm.intr.log(%arg2) : (vector<4xf32>) -> vector<4xf32>
68+
2769
%8 = cir.log10 %arg0 : !cir.float
2870
// CHECK: llvm.intr.log10(%arg0) : (f32) -> f32
2971

72+
%108 = cir.log10 %arg1 : !cir.vector<!cir.double x 2>
73+
// CHECK: llvm.intr.log10(%arg1) : (vector<2xf64>) -> vector<2xf64>
74+
75+
%208 = cir.log10 %arg2 : !cir.vector<!cir.float x 4>
76+
// CHECK: llvm.intr.log10(%arg2) : (vector<4xf32>) -> vector<4xf32>
77+
3078
%9 = cir.log2 %arg0 : !cir.float
3179
// CHECK: llvm.intr.log2(%arg0) : (f32) -> f32
3280

81+
%109 = cir.log2 %arg1 : !cir.vector<!cir.double x 2>
82+
// CHECK: llvm.intr.log2(%arg1) : (vector<2xf64>) -> vector<2xf64>
83+
84+
%209 = cir.log2 %arg2 : !cir.vector<!cir.float x 4>
85+
// CHECK: llvm.intr.log2(%arg2) : (vector<4xf32>) -> vector<4xf32>
86+
3387
%10 = cir.nearbyint %arg0 : !cir.float
3488
// CHECK: llvm.intr.nearbyint(%arg0) : (f32) -> f32
3589

90+
%110 = cir.nearbyint %arg1 : !cir.vector<!cir.double x 2>
91+
// CHECK: llvm.intr.nearbyint(%arg1) : (vector<2xf64>) -> vector<2xf64>
92+
93+
%210 = cir.nearbyint %arg2 : !cir.vector<!cir.float x 4>
94+
// CHECK: llvm.intr.nearbyint(%arg2) : (vector<4xf32>) -> vector<4xf32>
95+
3696
%11 = cir.rint %arg0 : !cir.float
3797
// CHECK: llvm.intr.rint(%arg0) : (f32) -> f32
3898

99+
%111 = cir.rint %arg1 : !cir.vector<!cir.double x 2>
100+
// CHECK: llvm.intr.rint(%arg1) : (vector<2xf64>) -> vector<2xf64>
101+
102+
%211 = cir.rint %arg2 : !cir.vector<!cir.float x 4>
103+
// CHECK: llvm.intr.rint(%arg2) : (vector<4xf32>) -> vector<4xf32>
104+
39105
%12 = cir.round %arg0 : !cir.float
40106
// CHECK: llvm.intr.round(%arg0) : (f32) -> f32
41107

108+
%112 = cir.round %arg1 : !cir.vector<!cir.double x 2>
109+
// CHECK: llvm.intr.round(%arg1) : (vector<2xf64>) -> vector<2xf64>
110+
111+
%212 = cir.round %arg2 : !cir.vector<!cir.float x 4>
112+
// CHECK: llvm.intr.round(%arg2) : (vector<4xf32>) -> vector<4xf32>
113+
42114
%13 = cir.sin %arg0 : !cir.float
43115
// CHECK: llvm.intr.sin(%arg0) : (f32) -> f32
44116

117+
%113 = cir.sin %arg1 : !cir.vector<!cir.double x 2>
118+
// CHECK: llvm.intr.sin(%arg1) : (vector<2xf64>) -> vector<2xf64>
119+
120+
%213 = cir.sin %arg2 : !cir.vector<!cir.float x 4>
121+
// CHECK: llvm.intr.sin(%arg2) : (vector<4xf32>) -> vector<4xf32>
122+
45123
%14 = cir.sqrt %arg0 : !cir.float
46124
// CHECK: llvm.intr.sqrt(%arg0) : (f32) -> f32
47125

126+
%114 = cir.sqrt %arg1 : !cir.vector<!cir.double x 2>
127+
// CHECK: llvm.intr.sqrt(%arg1) : (vector<2xf64>) -> vector<2xf64>
128+
129+
%214 = cir.sqrt %arg2 : !cir.vector<!cir.float x 4>
130+
// CHECK: llvm.intr.sqrt(%arg2) : (vector<4xf32>) -> vector<4xf32>
131+
48132
cir.return
49133
}
50134
}

0 commit comments

Comments
 (0)