Skip to content

Commit 09b8f2d

Browse files
committed
[CIR][CIRGen] Add complex type and its CIRGen support
This patch adds !cir.complex type to model the _Complex type in C. It also contains support for its CIRGen. In detail, this patch adds the following CIR types, ops, and attributes: - The `!cir.complex` type is added to model the _Complex type in C. This type is parameterized with the type of the components of the complex number, which must be either an integer type or a floating-point type. - The `#cir.complex` attribute is added to represent a literal value of _Complex type. - The `cir.complex.extract` op is added to extract the real and imaginary part of a value of `!cir.complex` type. CIRGen support for the new complex type is also added. Note the implementation diverges from the original clang CodeGen, where expressions of complex types are handled differently from scalars and aggregates. Instead, this patch treats expressions of complex types as scalars, as such expressions can be simply lowered to a CIR value of `!cir.complex` type.
1 parent 8b7417c commit 09b8f2d

22 files changed

+520
-105
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

+29
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,35 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
4141
public:
4242
CIRBaseBuilderTy(mlir::MLIRContext &C) : mlir::OpBuilder(&C) {}
4343

44+
mlir::cir::BoolType getBoolTy() {
45+
return ::mlir::cir::BoolType::get(getContext());
46+
}
47+
48+
mlir::cir::BoolAttr getCIRBoolAttr(bool state) {
49+
return mlir::cir::BoolAttr::get(getContext(), getBoolTy(), state);
50+
}
51+
52+
mlir::TypedAttr getZeroInitAttr(mlir::Type ty) {
53+
if (ty.isa<mlir::cir::IntType>())
54+
return mlir::cir::IntAttr::get(ty, 0);
55+
if (auto fltType = ty.dyn_cast<mlir::cir::SingleType>())
56+
return mlir::cir::FPAttr::getZero(fltType);
57+
if (auto fltType = ty.dyn_cast<mlir::cir::DoubleType>())
58+
return mlir::cir::FPAttr::getZero(fltType);
59+
if (auto complexTy = ty.dyn_cast<mlir::cir::ComplexType>())
60+
return mlir::cir::ComplexAttr::getZero(complexTy);
61+
if (auto arrTy = ty.dyn_cast<mlir::cir::ArrayType>())
62+
return getZeroAttr(arrTy);
63+
if (auto ptrTy = ty.dyn_cast<mlir::cir::PointerType>())
64+
return getConstPtrAttr(ptrTy, 0);
65+
if (auto structTy = ty.dyn_cast<mlir::cir::StructType>())
66+
return getZeroAttr(structTy);
67+
if (ty.isa<mlir::cir::BoolType>()) {
68+
return getCIRBoolAttr(false);
69+
}
70+
llvm_unreachable("Zero initializer for given type is NYI");
71+
}
72+
4473
mlir::Value getConstAPSInt(mlir::Location loc, const llvm::APSInt &val) {
4574
auto ty = mlir::cir::IntType::get(getContext(), val.getBitWidth(),
4675
val.isSigned());

clang/include/clang/CIR/Dialect/IR/CIRAttrs.td

+35
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,41 @@ def FPAttr : CIR_Attr<"FP", "fp", [TypedAttrInterface]> {
242242
}];
243243
}
244244

245+
//===----------------------------------------------------------------------===//
246+
// ComplexAttr
247+
//===----------------------------------------------------------------------===//
248+
249+
def ComplexAttr : CIR_Attr<"Complex", "complex", [TypedAttrInterface]> {
250+
let summary = "An attribute containing a complex number value";
251+
let description = [{
252+
A `#cir.complex` attribute is a literal attribute that represents a complex
253+
number value of the specified complex type.
254+
255+
The `real` parameter gives the real part of the complex number, and the
256+
`imag` parameter gives the imaginary part of the complex number.
257+
}];
258+
259+
let parameters = (ins AttributeSelfTypeParameter<"">:$type, "TypedAttr":$real,
260+
"TypedAttr":$imag);
261+
262+
let builders = [
263+
AttrBuilderWithInferredContext<(ins "Type":$type, "TypedAttr":$real,
264+
"TypedAttr":$imag), [{
265+
return $_get(type.getContext(), type, real, imag);
266+
}]>,
267+
];
268+
269+
let extraClassDeclaration = [{
270+
static ComplexAttr getZero(Type type);
271+
}];
272+
273+
let genVerifyDecl = 1;
274+
275+
let assemblyFormat = [{
276+
`<` $real `,` $imag `>`
277+
}];
278+
}
279+
245280
//===----------------------------------------------------------------------===//
246281
// ConstPointerAttr
247282
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIROps.td

+61-1
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,18 @@ def CK_FloatToBoolean : I32EnumAttrCase<"float_to_bool", 10>;
5757
def CK_BooleanToIntegral : I32EnumAttrCase<"bool_to_int", 11>;
5858
def CK_IntegralToFloat : I32EnumAttrCase<"int_to_float", 12>;
5959
def CK_BooleanToFloat : I32EnumAttrCase<"bool_to_float", 13>;
60+
def CK_IntegralToComplex : I32EnumAttrCase<"int_to_complex", 14>;
61+
def CK_FloatToComplex : I32EnumAttrCase<"float_to_complex", 15>;
62+
def CK_ComplexCast : I32EnumAttrCase<"complex", 16>;
6063

6164
def CastKind : I32EnumAttr<
6265
"CastKind",
6366
"cast kind",
6467
[CK_IntegralToBoolean, CK_ArrayToPointerDecay, CK_IntegralCast,
6568
CK_BitCast, CK_FloatingCast, CK_PtrToBoolean, CK_FloatToIntegral,
6669
CK_IntegralToPointer, CK_PointerToIntegral, CK_FloatToBoolean,
67-
CK_BooleanToIntegral, CK_IntegralToFloat, CK_BooleanToFloat]> {
70+
CK_BooleanToIntegral, CK_IntegralToFloat, CK_BooleanToFloat,
71+
CK_IntegralToComplex, CK_FloatToComplex, CK_ComplexCast]> {
6872
let cppNamespace = "::mlir::cir";
6973
}
7074

@@ -86,6 +90,8 @@ def CastOp : CIR_Op<"cast", [Pure]> {
8690
- `ptr_to_bool`
8791
- `bool_to_int`
8892
- `bool_to_float`
93+
- `int_to_complex`
94+
- `float_to_complex`
8995

9096
This is effectively a subset of the rules from
9197
`llvm-project/clang/include/clang/AST/OperationKinds.def`; but note that some
@@ -986,6 +992,60 @@ def CmpOp : CIR_Op<"cmp", [Pure, SameTypeOperands]> {
986992
let hasVerifier = 0;
987993
}
988994

995+
//===----------------------------------------------------------------------===//
996+
// ComplexRealOp and ComplexImagOp
997+
//===----------------------------------------------------------------------===//
998+
999+
def ComplexRealOp : CIR_Op<"complex.real", [Pure]> {
1000+
let summary = "Extract the real part of a complex value";
1001+
let description = [{
1002+
`cir.complex.real` operation takes an operand of complex type and returns
1003+
the real part of it.
1004+
1005+
Example:
1006+
1007+
```mlir
1008+
!complex = !cir.complex<!cir.float>
1009+
%0 = cir.const(#cir.complex<#cir.fp<1.0>, #cir.fp<2.0>> : !complex) : !complex
1010+
%1 = cir.complex.real(%0 : !complex) : !cir.float
1011+
```
1012+
}];
1013+
1014+
let results = (outs CIR_AnyType:$result);
1015+
let arguments = (ins CIR_ComplexType:$operand);
1016+
1017+
let assemblyFormat = [{
1018+
`(` $operand `:` qualified(type($operand)) `)` `:` type($result) attr-dict
1019+
}];
1020+
1021+
let hasVerifier = 1;
1022+
}
1023+
1024+
def ComplexImagOp : CIR_Op<"complex.imag", [Pure]> {
1025+
let summary = "Extract the imaginary part of a complex value";
1026+
let description = [{
1027+
`cir.complex.imag` operation takes an operand of complex type and returns
1028+
the imaginary part of it.
1029+
1030+
Example:
1031+
1032+
```mlir
1033+
!complex = !cir.complex<!cir.float>
1034+
%0 = cir.const(#cir.complex<#cir.fp<1.0>, #cir.fp<2.0>> : !complex) : !complex
1035+
%1 = cir.complex.imag(%0 : !complex) : !cir.float
1036+
```
1037+
}];
1038+
1039+
let results = (outs CIR_AnyType:$result);
1040+
let arguments = (ins CIR_ComplexType:$operand);
1041+
1042+
let assemblyFormat = [{
1043+
`(` $operand `:` qualified(type($operand)) `)` `:` type($result) attr-dict
1044+
}];
1045+
1046+
let hasVerifier = 1;
1047+
}
1048+
9891049
//===----------------------------------------------------------------------===//
9901050
// BitsOp
9911051
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

+27-1
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,32 @@ def CIR_Double : CIR_FloatType<"Double", "double"> {
170170

171171
def CIR_AnyFloat: AnyTypeOf<[CIR_Single, CIR_Double]>;
172172

173+
//===----------------------------------------------------------------------===//
174+
// ComplexType
175+
//===----------------------------------------------------------------------===//
176+
177+
def CIR_ComplexType : CIR_Type<"Complex", "complex",
178+
[DeclareTypeInterfaceMethods<DataLayoutTypeInterface>]> {
179+
180+
let summary = "CIR complex type";
181+
let description = [{
182+
CIR type that represents a C/C++ complex number. `cir.complex` models the
183+
C type `_Complex`.
184+
185+
The parameter `elementTy` gives the type of the real and imaginary part of
186+
the complex number. `elementTy` must be either a CIR integer type or a CIR
187+
floating-point type.
188+
}];
189+
190+
let parameters = (ins "mlir::Type":$elementTy);
191+
192+
let assemblyFormat = [{
193+
`<` $elementTy `>`
194+
}];
195+
196+
let genVerifyDecl = 1;
197+
}
198+
173199
//===----------------------------------------------------------------------===//
174200
// PointerType
175201
//===----------------------------------------------------------------------===//
@@ -444,7 +470,7 @@ def CIR_StructType : Type<CPred<"$_self.isa<::mlir::cir::StructType>()">,
444470
def CIR_AnyType : AnyTypeOf<[
445471
CIR_IntType, CIR_PointerType, CIR_DataMemberType, CIR_BoolType, CIR_ArrayType,
446472
CIR_VectorType, CIR_FuncType, CIR_VoidType, CIR_StructType, CIR_ExceptionInfo,
447-
CIR_AnyFloat,
473+
CIR_AnyFloat, CIR_ComplexType
448474
]>;
449475

450476
#endif // MLIR_CIR_DIALECT_CIR_TYPES

clang/lib/CIR/CodeGen/CIRGenBuilder.h

+18-26
Original file line numberDiff line numberDiff line change
@@ -140,10 +140,6 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
140140
return mlir::cir::ZeroAttr::get(getContext(), t);
141141
}
142142

143-
mlir::cir::BoolAttr getCIRBoolAttr(bool state) {
144-
return mlir::cir::BoolAttr::get(getContext(), getBoolTy(), state);
145-
}
146-
147143
mlir::TypedAttr getConstNullPtrAttr(mlir::Type t) {
148144
assert(t.isa<mlir::cir::PointerType>() && "expected cir.ptr");
149145
return mlir::cir::ConstPtrAttr::get(getContext(), t, 0);
@@ -243,25 +239,6 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
243239
return mlir::cir::DataMemberAttr::get(getContext(), ty, std::nullopt);
244240
}
245241

246-
mlir::TypedAttr getZeroInitAttr(mlir::Type ty) {
247-
if (ty.isa<mlir::cir::IntType>())
248-
return mlir::cir::IntAttr::get(ty, 0);
249-
if (auto fltType = ty.dyn_cast<mlir::cir::SingleType>())
250-
return mlir::cir::FPAttr::getZero(fltType);
251-
if (auto fltType = ty.dyn_cast<mlir::cir::DoubleType>())
252-
return mlir::cir::FPAttr::getZero(fltType);
253-
if (auto arrTy = ty.dyn_cast<mlir::cir::ArrayType>())
254-
return getZeroAttr(arrTy);
255-
if (auto ptrTy = ty.dyn_cast<mlir::cir::PointerType>())
256-
return getConstPtrAttr(ptrTy, 0);
257-
if (auto structTy = ty.dyn_cast<mlir::cir::StructType>())
258-
return getZeroAttr(structTy);
259-
if (ty.isa<mlir::cir::BoolType>()) {
260-
return getCIRBoolAttr(false);
261-
}
262-
llvm_unreachable("Zero initializer for given type is NYI");
263-
}
264-
265242
// TODO(cir): Once we have CIR float types, replace this by something like a
266243
// NullableValueInterface to allow for type-independent queries.
267244
bool isNullValue(mlir::Attribute attr) const {
@@ -398,9 +375,6 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
398375
llvm_unreachable("Unknown float format!");
399376
}
400377

401-
mlir::cir::BoolType getBoolTy() {
402-
return ::mlir::cir::BoolType::get(getContext());
403-
}
404378
mlir::Type getVirtualFnPtrType(bool isVarArg = false) {
405379
// FIXME: replay LLVM codegen for now, perhaps add a vtable ptr special
406380
// type so it's a bit more clear and C++ idiomatic.
@@ -888,6 +862,24 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
888862
mlir::Value createPtrIsNull(mlir::Value ptr) {
889863
return createNot(createPtrToBoolCast(ptr));
890864
}
865+
866+
mlir::Value createComplexReal(mlir::Location loc, mlir::Value operand) {
867+
auto operandComplexTy = operand.getType().cast<mlir::cir::ComplexType>();
868+
auto resultTy = operandComplexTy.getElementTy();
869+
return create<mlir::cir::ComplexRealOp>(loc, resultTy, operand);
870+
}
871+
872+
mlir::Value createComplexImag(mlir::Location loc, mlir::Value operand) {
873+
auto operandComplexTy = operand.getType().cast<mlir::cir::ComplexType>();
874+
auto resultTy = operandComplexTy.getElementTy();
875+
return create<mlir::cir::ComplexImagOp>(loc, resultTy, operand);
876+
}
877+
878+
mlir::Value createComplexIsZero(mlir::Location loc, mlir::Value operand) {
879+
auto zero = getNullValue(operand.getType(), loc);
880+
return create<mlir::cir::CmpOp>(loc, getBoolTy(), mlir::cir::CmpOpKind::eq,
881+
operand, zero);
882+
}
891883
};
892884

893885
} // namespace cir

clang/lib/CIR/CodeGen/CIRGenCXX.cpp

-2
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,6 @@ static void buildDeclInit(CIRGenFunction &CGF, const VarDecl *D,
200200
case TEK_Scalar:
201201
CGF.buildScalarInit(Init, CGF.getLoc(D->getLocation()), lv, false);
202202
return;
203-
case TEK_Complex:
204-
llvm_unreachable("complext evaluation NYI");
205203
}
206204
}
207205

clang/lib/CIR/CodeGen/CIRGenClass.cpp

-3
Original file line numberDiff line numberDiff line change
@@ -804,9 +804,6 @@ void CIRGenFunction::buildInitializerForField(FieldDecl *Field, LValue LHS,
804804
llvm_unreachable("NYI");
805805
}
806806
break;
807-
case TEK_Complex:
808-
llvm_unreachable("NYI");
809-
break;
810807
case TEK_Aggregate: {
811808
AggValueSlot Slot = AggValueSlot::forLValue(
812809
LHS, AggValueSlot::IsDestructed, AggValueSlot::DoesNotNeedGCBarriers,

clang/lib/CIR/CodeGen/CIRGenDecl.cpp

+1-4
Original file line numberDiff line numberDiff line change
@@ -700,12 +700,9 @@ void CIRGenFunction::buildExprAsInit(const Expr *init, const ValueDecl *D,
700700
return;
701701
}
702702
switch (CIRGenFunction::getEvaluationKind(type)) {
703-
case TEK_Scalar:
703+
case TEK_Scalar: {
704704
buildScalarInit(init, getLoc(D->getSourceRange()), lvalue);
705705
return;
706-
case TEK_Complex: {
707-
assert(0 && "not implemented");
708-
return;
709706
}
710707
case TEK_Aggregate:
711708
assert(!type->isAtomicType() && "NYI");

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

-10
Original file line numberDiff line numberDiff line change
@@ -974,8 +974,6 @@ LValue CIRGenFunction::buildBinaryOperatorLValue(const BinaryOperator *E) {
974974
return LV;
975975
}
976976

977-
case TEK_Complex:
978-
assert(0 && "not implemented");
979977
case TEK_Aggregate:
980978
assert(0 && "not implemented");
981979
}
@@ -1065,8 +1063,6 @@ RValue CIRGenFunction::buildAnyExpr(const Expr *E, AggValueSlot aggSlot,
10651063
switch (CIRGenFunction::getEvaluationKind(E->getType())) {
10661064
case TEK_Scalar:
10671065
return RValue::get(buildScalarExpr(E));
1068-
case TEK_Complex:
1069-
assert(0 && "not implemented");
10701066
case TEK_Aggregate: {
10711067
if (!ignoreResult && aggSlot.isIgnored())
10721068
aggSlot = CreateAggTemp(E->getType(), getLoc(E->getSourceRange()),
@@ -1854,10 +1850,6 @@ void CIRGenFunction::buildAnyExprToMem(const Expr *E, Address Location,
18541850
Qualifiers Quals, bool IsInit) {
18551851
// FIXME: This function should take an LValue as an argument.
18561852
switch (getEvaluationKind(E->getType())) {
1857-
case TEK_Complex:
1858-
assert(0 && "NYI");
1859-
return;
1860-
18611853
case TEK_Aggregate: {
18621854
buildAggExpr(E, AggValueSlot::forAddr(Location, Quals,
18631855
AggValueSlot::IsDestructed_t(IsInit),
@@ -2306,8 +2298,6 @@ RValue CIRGenFunction::convertTempToRValue(Address addr, clang::QualType type,
23062298
clang::SourceLocation loc) {
23072299
LValue lvalue = makeAddrLValue(addr, type, AlignmentSource::Decl);
23082300
switch (getEvaluationKind(type)) {
2309-
case TEK_Complex:
2310-
llvm_unreachable("NYI");
23112301
case TEK_Aggregate:
23122302
llvm_unreachable("NYI");
23132303
case TEK_Scalar:

clang/lib/CIR/CodeGen/CIRGenExprAgg.cpp

-3
Original file line numberDiff line numberDiff line change
@@ -782,9 +782,6 @@ void AggExprEmitter::buildInitializationToLValue(Expr *E, LValue LV) {
782782
}
783783

784784
switch (CGF.getEvaluationKind(type)) {
785-
case TEK_Complex:
786-
llvm_unreachable("NYI");
787-
return;
788785
case TEK_Aggregate:
789786
CGF.buildAggExpr(
790787
E, AggValueSlot::forLValue(LV, AggValueSlot::IsDestructed,

clang/lib/CIR/CodeGen/CIRGenExprCXX.cpp

-3
Original file line numberDiff line numberDiff line change
@@ -592,9 +592,6 @@ static void StoreAnyExprIntoOneUnit(CIRGenFunction &CGF, const Expr *Init,
592592
CGF.buildScalarInit(Init, CGF.getLoc(Init->getSourceRange()),
593593
CGF.makeAddrLValue(NewPtr, AllocType), false);
594594
return;
595-
case TEK_Complex:
596-
llvm_unreachable("NYI");
597-
return;
598595
case TEK_Aggregate: {
599596
AggValueSlot Slot = AggValueSlot::forAddr(
600597
NewPtr, AllocType.getQualifiers(), AggValueSlot::IsDestructed,

0 commit comments

Comments
 (0)