Skip to content

Commit 5a9db89

Browse files
Lancernlanza
authored andcommitted
[CIR][LLVMLowering] Add LLVM lowering for complex operations (#723)
This PR adds LLVM lowering for the following operations related to complex numbers: - `cir.complex.create`, - `cir.complex.real_ptr`, and - `cir.complex.imag_ptr`. The LLVM IR generated for `cir.complex.create` is a bit ugly since it includes the `insertvalue` instruction, which typically is not generated in upstream CodeGen. Later we may need further CIR canonicalization passes to try folding `cir.complex.create`.
1 parent c593881 commit 5a9db89

File tree

2 files changed

+143
-3
lines changed

2 files changed

+143
-3
lines changed

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1634,6 +1634,80 @@ class CIRGetGlobalOpLowering
16341634
}
16351635
};
16361636

1637+
class CIRComplexCreateOpLowering
1638+
: public mlir::OpConversionPattern<mlir::cir::ComplexCreateOp> {
1639+
public:
1640+
using OpConversionPattern<mlir::cir::ComplexCreateOp>::OpConversionPattern;
1641+
1642+
mlir::LogicalResult
1643+
matchAndRewrite(mlir::cir::ComplexCreateOp op, OpAdaptor adaptor,
1644+
mlir::ConversionPatternRewriter &rewriter) const override {
1645+
auto complexLLVMTy =
1646+
getTypeConverter()->convertType(op.getResult().getType());
1647+
auto initialComplex =
1648+
rewriter.create<mlir::LLVM::UndefOp>(op->getLoc(), complexLLVMTy);
1649+
1650+
int64_t position[1]{0};
1651+
auto realComplex = rewriter.create<mlir::LLVM::InsertValueOp>(
1652+
op->getLoc(), initialComplex, adaptor.getReal(), position);
1653+
1654+
position[0] = 1;
1655+
auto complex = rewriter.create<mlir::LLVM::InsertValueOp>(
1656+
op->getLoc(), realComplex, adaptor.getImag(), position);
1657+
1658+
rewriter.replaceOp(op, complex);
1659+
return mlir::success();
1660+
}
1661+
};
1662+
1663+
class CIRComplexRealPtrOPLowering
1664+
: public mlir::OpConversionPattern<mlir::cir::ComplexRealPtrOp> {
1665+
public:
1666+
using OpConversionPattern<mlir::cir::ComplexRealPtrOp>::OpConversionPattern;
1667+
1668+
mlir::LogicalResult
1669+
matchAndRewrite(mlir::cir::ComplexRealPtrOp op, OpAdaptor adaptor,
1670+
mlir::ConversionPatternRewriter &rewriter) const override {
1671+
auto operandTy =
1672+
mlir::cast<mlir::cir::PointerType>(op.getOperand().getType());
1673+
auto resultLLVMTy =
1674+
getTypeConverter()->convertType(op.getResult().getType());
1675+
auto elementLLVMTy =
1676+
getTypeConverter()->convertType(operandTy.getPointee());
1677+
1678+
mlir::LLVM::GEPArg gepIndices[2]{{0}, {0}};
1679+
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
1680+
op, resultLLVMTy, elementLLVMTy, adaptor.getOperand(), gepIndices,
1681+
/*inbounds=*/true);
1682+
1683+
return mlir::success();
1684+
}
1685+
};
1686+
1687+
class CIRComplexImagPtrOpLowering
1688+
: public mlir::OpConversionPattern<mlir::cir::ComplexImagPtrOp> {
1689+
public:
1690+
using OpConversionPattern<mlir::cir::ComplexImagPtrOp>::OpConversionPattern;
1691+
1692+
mlir::LogicalResult
1693+
matchAndRewrite(mlir::cir::ComplexImagPtrOp op, OpAdaptor adaptor,
1694+
mlir::ConversionPatternRewriter &rewriter) const override {
1695+
auto operandTy =
1696+
mlir::cast<mlir::cir::PointerType>(op.getOperand().getType());
1697+
auto resultLLVMTy =
1698+
getTypeConverter()->convertType(op.getResult().getType());
1699+
auto elementLLVMTy =
1700+
getTypeConverter()->convertType(operandTy.getPointee());
1701+
1702+
mlir::LLVM::GEPArg gepIndices[2]{{0}, {1}};
1703+
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
1704+
op, resultLLVMTy, elementLLVMTy, adaptor.getOperand(), gepIndices,
1705+
/*inbounds=*/true);
1706+
1707+
return mlir::success();
1708+
}
1709+
};
1710+
16371711
class CIRSwitchFlatOpLowering
16381712
: public mlir::OpConversionPattern<mlir::cir::SwitchFlatOp> {
16391713
public:
@@ -3366,9 +3440,10 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
33663440
CIRUnaryOpLowering, CIRBinOpLowering, CIRBinOpOverflowOpLowering,
33673441
CIRShiftOpLowering, CIRLoadLowering, CIRConstantLowering,
33683442
CIRStoreLowering, CIRAllocaLowering, CIRFuncLowering, CIRCastOpLowering,
3369-
CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRVAStartLowering,
3370-
CIRVAEndLowering, CIRVACopyLowering, CIRVAArgLowering, CIRBrOpLowering,
3371-
CIRGetMemberOpLowering, CIRGetRuntimeMemberOpLowering,
3443+
CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRComplexCreateOpLowering,
3444+
CIRComplexRealPtrOPLowering, CIRComplexImagPtrOpLowering,
3445+
CIRVAStartLowering, CIRVAEndLowering, CIRVACopyLowering, CIRVAArgLowering,
3446+
CIRBrOpLowering, CIRGetMemberOpLowering, CIRGetRuntimeMemberOpLowering,
33723447
CIRSwitchFlatOpLowering, CIRPtrDiffOpLowering, CIRCopyOpLowering,
33733448
CIRMemCpyOpLowering, CIRFAbsOpLowering, CIRExpectOpLowering,
33743449
CIRVTableAddrPointOpLowering, CIRVectorCreateLowering,
@@ -3445,6 +3520,14 @@ void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
34453520
converter.addConversion([&](mlir::cir::BF16Type type) -> mlir::Type {
34463521
return mlir::Float16Type::get(type.getContext());
34473522
});
3523+
converter.addConversion([&](mlir::cir::ComplexType type) -> mlir::Type {
3524+
// A complex type is lowered to an LLVM struct that contains the real and
3525+
// imaginary part as data fields.
3526+
mlir::Type elementTy = converter.convertType(type.getElementTy());
3527+
mlir::Type structFields[2] = {elementTy, elementTy};
3528+
return mlir::LLVM::LLVMStructType::getLiteral(type.getContext(),
3529+
structFields);
3530+
});
34483531
converter.addConversion([&](mlir::cir::FuncType type) -> mlir::Type {
34493532
auto result = converter.convertType(type.getReturnType());
34503533
llvm::SmallVector<mlir::Type> arguments;

clang/test/CIR/CodeGen/complex.c

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
// RUN: FileCheck --input-file=%t.cir --check-prefixes=C,CHECK %s
33
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -x c++ -fclangir -emit-cir -o %t.cir %s
44
// RUN: FileCheck --input-file=%t.cir --check-prefixes=CPP,CHECK %s
5+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm -o %t.ll %s
6+
// RUN: FileCheck --input-file=%t.ll --check-prefixes=LLVM %s
7+
// XFAIL: *
58

69
double _Complex c, c2;
710
int _Complex ci, ci2;
@@ -24,6 +27,10 @@ void list_init() {
2427
// CHECK-NEXT: %{{.+}} = cir.complex.create %[[#REAL]], %[[#IMAG]] : !s32i -> !cir.complex<!s32i>
2528
// CHECK: }
2629

30+
// LLVM: define void @list_init()
31+
// LLVM: store { double, double } { double 1.000000e+00, double 2.000000e+00 }, ptr %{{.+}}, align 8
32+
// LLVM: }
33+
2734
void list_init_2(double r, double i) {
2835
double _Complex c1 = {r, i};
2936
}
@@ -36,6 +43,12 @@ void list_init_2(double r, double i) {
3643
// CHECK-NEXT: cir.store %[[#C]], %{{.+}} : !cir.complex<!cir.double>, !cir.ptr<!cir.complex<!cir.double>>
3744
// CHECK: }
3845

46+
// LLVM: define void @list_init_2(double %{{.+}}, double %{{.+}})
47+
// LLVM: %[[#A:]] = insertvalue { double, double } undef, double %{{.+}}, 0
48+
// LLVM-NEXT: %[[#B:]] = insertvalue { double, double } %[[#A]], double %{{.+}}, 1
49+
// LLVM-NEXT: store { double, double } %[[#B]], ptr %5, align 8
50+
// LLVM: }
51+
3952
void imag_literal() {
4053
c = 3.0i;
4154
ci = 3i;
@@ -51,6 +64,11 @@ void imag_literal() {
5164
// CHECK-NEXT: %{{.+}} = cir.complex.create %[[#REAL]], %[[#IMAG]] : !s32i -> !cir.complex<!s32i>
5265
// CHECK: }
5366

67+
// LLVM: define void @imag_literal()
68+
// LLVM: store { double, double } { double 0.000000e+00, double 3.000000e+00 }, ptr @c, align 8
69+
// LLVM: store { i32, i32 } { i32 0, i32 3 }, ptr @ci, align 4
70+
// LLVM: }
71+
5472
void load_store() {
5573
c = c2;
5674
ci = ci2;
@@ -68,6 +86,13 @@ void load_store() {
6886
// CHECK-NEXT: cir.store %[[#CI2]], %[[#CI_PTR]] : !cir.complex<!s32i>, !cir.ptr<!cir.complex<!s32i>>
6987
// CHECK: }
7088

89+
// LLVM: define void @load_store()
90+
// LLVM: %[[#A:]] = load { double, double }, ptr @c2, align 8
91+
// LLVM-NEXT: store { double, double } %[[#A]], ptr @c, align 8
92+
// LLVM-NEXT: %[[#B:]] = load { i32, i32 }, ptr @ci2, align 4
93+
// LLVM-NEXT: store { i32, i32 } %[[#B]], ptr @ci, align 4
94+
// LLVM: }
95+
7196
void load_store_volatile() {
7297
vc = vc2;
7398
vci = vci2;
@@ -85,6 +110,13 @@ void load_store_volatile() {
85110
// CHECK-NEXT: cir.store volatile %[[#VCI2]], %[[#VCI_PTR]] : !cir.complex<!s32i>, !cir.ptr<!cir.complex<!s32i>>
86111
// CHECK: }
87112

113+
// LLVM: define void @load_store_volatile()
114+
// LLVM: %[[#A:]] = load volatile { double, double }, ptr @vc2, align 8
115+
// LLVM-NEXT: store volatile { double, double } %[[#A]], ptr @vc, align 8
116+
// LLVM-NEXT: %[[#B:]] = load volatile { i32, i32 }, ptr @vci2, align 4
117+
// LLVM-NEXT: store volatile { i32, i32 } %[[#B]], ptr @vci, align 4
118+
// LLVM: }
119+
88120
void real_ptr() {
89121
double *r1 = &__real__ c;
90122
int *r2 = &__real__ ci;
@@ -98,6 +130,11 @@ void real_ptr() {
98130
// CHECK-NEXT: %{{.+}} = cir.complex.real_ptr %[[#CI_PTR]] : !cir.ptr<!cir.complex<!s32i>> -> !cir.ptr<!s32i>
99131
// CHECK: }
100132

133+
// LLVM: define void @real_ptr()
134+
// LLVM: store ptr @c, ptr %{{.+}}, align 8
135+
// LLVM-NEXT: store ptr @ci, ptr %{{.+}}, align 8
136+
// LLVM: }
137+
101138
void real_ptr_local() {
102139
double _Complex c1 = {1.0, 2.0};
103140
double *r3 = &__real__ c1;
@@ -109,6 +146,11 @@ void real_ptr_local() {
109146
// CHECK: %{{.+}} = cir.complex.real_ptr %[[#C]] : !cir.ptr<!cir.complex<!cir.double>> -> !cir.ptr<!cir.double>
110147
// CHECK: }
111148

149+
// LLVM: define void @real_ptr_local()
150+
// LLVM: store { double, double } { double 1.000000e+00, double 2.000000e+00 }, ptr %{{.+}}, align 8
151+
// LLVM-NEXT: %{{.+}} = getelementptr inbounds { double, double }, ptr %{{.+}}, i32 0, i32 0
152+
// LLVM: }
153+
112154
void extract_real() {
113155
double r1 = __real__ c;
114156
int r2 = __real__ ci;
@@ -124,6 +166,11 @@ void extract_real() {
124166
// CHECK-NEXT: %{{.+}} = cir.load %[[#REAL_PTR]] : !cir.ptr<!s32i>, !s32i
125167
// CHECK: }
126168

169+
// LLVM: define void @extract_real()
170+
// LLVM: %{{.+}} = load double, ptr @c, align 8
171+
// LLVM: %{{.+}} = load i32, ptr @ci, align 4
172+
// LLVM: }
173+
127174
void imag_ptr() {
128175
double *i1 = &__imag__ c;
129176
int *i2 = &__imag__ ci;
@@ -137,6 +184,11 @@ void imag_ptr() {
137184
// CHECK-NEXT: %{{.+}} = cir.complex.imag_ptr %[[#CI_PTR]] : !cir.ptr<!cir.complex<!s32i>> -> !cir.ptr<!s32i>
138185
// CHECK: }
139186

187+
// LLVM: define void @imag_ptr()
188+
// LLVM: store ptr getelementptr inbounds ({ double, double }, ptr @c, i32 0, i32 1), ptr %{{.+}}, align 8
189+
// LLVM: store ptr getelementptr inbounds ({ i32, i32 }, ptr @ci, i32 0, i32 1), ptr %{{.+}}, align 8
190+
// LLVM: }
191+
140192
void extract_imag() {
141193
double i1 = __imag__ c;
142194
int i2 = __imag__ ci;
@@ -151,3 +203,8 @@ void extract_imag() {
151203
// CHECK-NEXT: %[[#IMAG_PTR:]] = cir.complex.imag_ptr %[[#CI_PTR]] : !cir.ptr<!cir.complex<!s32i>> -> !cir.ptr<!s32i>
152204
// CHECK-NEXT: %{{.+}} = cir.load %[[#IMAG_PTR]] : !cir.ptr<!s32i>, !s32i
153205
// CHECK: }
206+
207+
// LLVM: define void @extract_imag()
208+
// LLVM: %{{.+}} = load double, ptr getelementptr inbounds ({ double, double }, ptr @c, i32 0, i32 1), align 8
209+
// LLVM: %{{.+}} = load i32, ptr getelementptr inbounds ({ i32, i32 }, ptr @ci, i32 0, i32 1), align 4
210+
// LLVM: }

0 commit comments

Comments
 (0)