Skip to content

Commit 59d506f

Browse files
committed
[CIR][CIRGen] Add complex type and its CIRGen support
This patch adds !cir.complex type to model the _Complex type in C. It also contains support for its CIRGen. In detail, this patch adds the following CIR types, ops, and attributes: - The `!cir.complex` type is added to model the _Complex type in C. This type is parameterized with the type of the components of the complex number, which must be either an integer type or a floating-point type. - The `#cir.complex` attribute is added to represent a literal value of _Complex type. - The `cir.complex.create` op is added, which creates a complex value from its real and imaginary parts. - The `cir.complex.real` and `cir.complex.imag` ops are added to extract the real and imaginary part of a value of `!cir.complex` type. CIRGen support for the new complex type is also added.
1 parent f228348 commit 59d506f

21 files changed

+1746
-64
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

+33
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,39 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
4141
public:
4242
CIRBaseBuilderTy(mlir::MLIRContext &C) : mlir::OpBuilder(&C) {}
4343

44+
mlir::cir::BoolType getBoolTy() {
45+
return ::mlir::cir::BoolType::get(getContext());
46+
}
47+
48+
mlir::cir::BoolAttr getCIRBoolAttr(bool state) {
49+
return mlir::cir::BoolAttr::get(getContext(), getBoolTy(), state);
50+
}
51+
52+
mlir::TypedAttr getZeroAttr(mlir::Type t) {
53+
return mlir::cir::ZeroAttr::get(getContext(), t);
54+
}
55+
56+
mlir::TypedAttr getZeroInitAttr(mlir::Type ty) {
57+
if (ty.isa<mlir::cir::IntType>())
58+
return mlir::cir::IntAttr::get(ty, 0);
59+
if (auto fltType = ty.dyn_cast<mlir::cir::SingleType>())
60+
return mlir::cir::FPAttr::getZero(fltType);
61+
if (auto fltType = ty.dyn_cast<mlir::cir::DoubleType>())
62+
return mlir::cir::FPAttr::getZero(fltType);
63+
if (auto complexTy = ty.dyn_cast<mlir::cir::ComplexType>())
64+
return mlir::cir::ComplexAttr::getZero(complexTy);
65+
if (auto arrTy = ty.dyn_cast<mlir::cir::ArrayType>())
66+
return getZeroAttr(arrTy);
67+
if (auto ptrTy = ty.dyn_cast<mlir::cir::PointerType>())
68+
return getConstPtrAttr(ptrTy, 0);
69+
if (auto structTy = ty.dyn_cast<mlir::cir::StructType>())
70+
return getZeroAttr(structTy);
71+
if (ty.isa<mlir::cir::BoolType>()) {
72+
return getCIRBoolAttr(false);
73+
}
74+
llvm_unreachable("Zero initializer for given type is NYI");
75+
}
76+
4477
mlir::Value getConstAPSInt(mlir::Location loc, const llvm::APSInt &val) {
4578
auto ty = mlir::cir::IntType::get(getContext(), val.getBitWidth(),
4679
val.isSigned());

clang/include/clang/CIR/Dialect/IR/CIRAttrs.td

+35
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,41 @@ def FPAttr : CIR_Attr<"FP", "fp", [TypedAttrInterface]> {
242242
}];
243243
}
244244

245+
//===----------------------------------------------------------------------===//
246+
// ComplexAttr
247+
//===----------------------------------------------------------------------===//
248+
249+
def ComplexAttr : CIR_Attr<"Complex", "complex", [TypedAttrInterface]> {
250+
let summary = "An attribute containing a complex number value";
251+
let description = [{
252+
A `#cir.complex` attribute is a literal attribute that represents a complex
253+
number value of the specified complex type.
254+
255+
The `real` parameter gives the real part of the complex number, and the
256+
`imag` parameter gives the imaginary part of the complex number.
257+
}];
258+
259+
let parameters = (ins AttributeSelfTypeParameter<"">:$type, "TypedAttr":$real,
260+
"TypedAttr":$imag);
261+
262+
let builders = [
263+
AttrBuilderWithInferredContext<(ins "Type":$type, "TypedAttr":$real,
264+
"TypedAttr":$imag), [{
265+
return $_get(type.getContext(), type, real, imag);
266+
}]>,
267+
];
268+
269+
let extraClassDeclaration = [{
270+
static ComplexAttr getZero(Type type);
271+
}];
272+
273+
let genVerifyDecl = 1;
274+
275+
let assemblyFormat = [{
276+
`<` $real `,` $imag `>`
277+
}];
278+
}
279+
245280
//===----------------------------------------------------------------------===//
246281
// ConstPointerAttr
247282
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIROps.td

+101-1
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,18 @@ def CK_FloatToBoolean : I32EnumAttrCase<"float_to_bool", 10>;
6868
def CK_BooleanToIntegral : I32EnumAttrCase<"bool_to_int", 11>;
6969
def CK_IntegralToFloat : I32EnumAttrCase<"int_to_float", 12>;
7070
def CK_BooleanToFloat : I32EnumAttrCase<"bool_to_float", 13>;
71+
def CK_IntegralToComplex : I32EnumAttrCase<"int_to_complex", 14>;
72+
def CK_FloatToComplex : I32EnumAttrCase<"float_to_complex", 15>;
73+
def CK_ComplexCast : I32EnumAttrCase<"complex", 16>;
7174

7275
def CastKind : I32EnumAttr<
7376
"CastKind",
7477
"cast kind",
7578
[CK_IntegralToBoolean, CK_ArrayToPointerDecay, CK_IntegralCast,
7679
CK_BitCast, CK_FloatingCast, CK_PtrToBoolean, CK_FloatToIntegral,
7780
CK_IntegralToPointer, CK_PointerToIntegral, CK_FloatToBoolean,
78-
CK_BooleanToIntegral, CK_IntegralToFloat, CK_BooleanToFloat]> {
81+
CK_BooleanToIntegral, CK_IntegralToFloat, CK_BooleanToFloat,
82+
CK_IntegralToComplex, CK_FloatToComplex, CK_ComplexCast]> {
7983
let cppNamespace = "::mlir::cir";
8084
}
8185

@@ -97,6 +101,8 @@ def CastOp : CIR_Op<"cast", [Pure]> {
97101
- `ptr_to_bool`
98102
- `bool_to_int`
99103
- `bool_to_float`
104+
- `int_to_complex`
105+
- `float_to_complex`
100106

101107
This is effectively a subset of the rules from
102108
`llvm-project/clang/include/clang/AST/OperationKinds.def`; but note that some
@@ -823,6 +829,7 @@ def UnaryOpKind_Dec : I32EnumAttrCase<"Dec", 2, "dec">;
823829
def UnaryOpKind_Plus : I32EnumAttrCase<"Plus", 3, "plus">;
824830
def UnaryOpKind_Minus : I32EnumAttrCase<"Minus", 4, "minus">;
825831
def UnaryOpKind_Not : I32EnumAttrCase<"Not", 5, "not">;
832+
def UnaryOpKind_Conjugate : I32EnumAttrCase<"Conjugate", 6, "conjugate">;
826833

827834
def UnaryOpKind : I32EnumAttr<
828835
"UnaryOpKind",
@@ -832,6 +839,7 @@ def UnaryOpKind : I32EnumAttr<
832839
UnaryOpKind_Plus,
833840
UnaryOpKind_Minus,
834841
UnaryOpKind_Not,
842+
UnaryOpKind_Conjugate,
835843
]> {
836844
let cppNamespace = "::mlir::cir";
837845
}
@@ -995,6 +1003,98 @@ def CmpOp : CIR_Op<"cmp", [Pure, SameTypeOperands]> {
9951003
let hasVerifier = 0;
9961004
}
9971005

1006+
//===----------------------------------------------------------------------===//
1007+
// ComplexCreateOp
1008+
//===----------------------------------------------------------------------===//
1009+
1010+
def ComplexCreateOp : CIR_Op<"complex.create", [Pure, SameTypeOperands]> {
1011+
let summary = "Create a new complex value from its real and imaginary parts";
1012+
let description = [{
1013+
The `cir.complex.create` operation takes two operands of the same type and
1014+
returns a value of `!cir.complex` type. The real and imaginary part of the
1015+
returned complex value is specified by the operands.
1016+
1017+
The element type of the returned complex value is the same as the type of
1018+
the operand. The type of the two operands must be either an integer type or
1019+
a floating-point type.
1020+
1021+
Example:
1022+
1023+
```mlir
1024+
!u32i = !cir.int<u, 32>
1025+
!complex = !cir.complex<!u32i>
1026+
1027+
%0 = cir.const(#cir.int<1> : !u32i) : !u32i
1028+
%1 = cir.const(#cir.int<2> : !u32i) : !u32i
1029+
%2 = cir.complex.create(%0 : !u32i, %1) : !complex
1030+
```
1031+
}];
1032+
1033+
let results = (outs CIR_ComplexType:$result);
1034+
let arguments = (ins CIR_AnyType:$real, CIR_AnyType:$imag);
1035+
1036+
let assemblyFormat = [{
1037+
`(` $real `:` qualified(type($real)) `,` $imag `)`
1038+
`:` type($result) attr-dict
1039+
}];
1040+
1041+
let hasVerifier = 1;
1042+
}
1043+
1044+
//===----------------------------------------------------------------------===//
1045+
// ComplexRealOp and ComplexImagOp
1046+
//===----------------------------------------------------------------------===//
1047+
1048+
def ComplexRealOp : CIR_Op<"complex.real", [Pure]> {
1049+
let summary = "Extract the real part of a complex value";
1050+
let description = [{
1051+
`cir.complex.real` operation takes an operand of complex type and returns
1052+
the real part of it.
1053+
1054+
Example:
1055+
1056+
```mlir
1057+
!complex = !cir.complex<!cir.float>
1058+
%0 = cir.const(#cir.complex<#cir.fp<1.0>, #cir.fp<2.0>> : !complex) : !complex
1059+
%1 = cir.complex.real(%0 : !complex) : !cir.float
1060+
```
1061+
}];
1062+
1063+
let results = (outs CIR_AnyType:$result);
1064+
let arguments = (ins CIR_ComplexType:$operand);
1065+
1066+
let assemblyFormat = [{
1067+
`(` $operand `:` qualified(type($operand)) `)` `:` type($result) attr-dict
1068+
}];
1069+
1070+
let hasVerifier = 1;
1071+
}
1072+
1073+
def ComplexImagOp : CIR_Op<"complex.imag", [Pure]> {
1074+
let summary = "Extract the imaginary part of a complex value";
1075+
let description = [{
1076+
`cir.complex.imag` operation takes an operand of complex type and returns
1077+
the imaginary part of it.
1078+
1079+
Example:
1080+
1081+
```mlir
1082+
!complex = !cir.complex<!cir.float>
1083+
%0 = cir.const(#cir.complex<#cir.fp<1.0>, #cir.fp<2.0>> : !complex) : !complex
1084+
%1 = cir.complex.imag(%0 : !complex) : !cir.float
1085+
```
1086+
}];
1087+
1088+
let results = (outs CIR_AnyType:$result);
1089+
let arguments = (ins CIR_ComplexType:$operand);
1090+
1091+
let assemblyFormat = [{
1092+
`(` $operand `:` qualified(type($operand)) `)` `:` type($result) attr-dict
1093+
}];
1094+
1095+
let hasVerifier = 1;
1096+
}
1097+
9981098
//===----------------------------------------------------------------------===//
9991099
// BitsOp
10001100
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

+27-1
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,32 @@ def CIR_Double : CIR_FloatType<"Double", "double"> {
171171
def CIR_AnyFloat: AnyTypeOf<[CIR_Single, CIR_Double]>;
172172
def CIR_AnyIntOrFloat: AnyTypeOf<[CIR_AnyFloat, CIR_IntType]>;
173173

174+
//===----------------------------------------------------------------------===//
175+
// ComplexType
176+
//===----------------------------------------------------------------------===//
177+
178+
def CIR_ComplexType : CIR_Type<"Complex", "complex",
179+
[DeclareTypeInterfaceMethods<DataLayoutTypeInterface>]> {
180+
181+
let summary = "CIR complex type";
182+
let description = [{
183+
CIR type that represents a C/C++ complex number. `cir.complex` models the
184+
C type `_Complex`.
185+
186+
The parameter `elementTy` gives the type of the real and imaginary part of
187+
the complex number. `elementTy` must be either a CIR integer type or a CIR
188+
floating-point type.
189+
}];
190+
191+
let parameters = (ins "mlir::Type":$elementTy);
192+
193+
let assemblyFormat = [{
194+
`<` $elementTy `>`
195+
}];
196+
197+
let genVerifyDecl = 1;
198+
}
199+
174200
//===----------------------------------------------------------------------===//
175201
// PointerType
176202
//===----------------------------------------------------------------------===//
@@ -455,7 +481,7 @@ def CIR_StructType : Type<CPred<"$_self.isa<::mlir::cir::StructType>()">,
455481
def CIR_AnyType : AnyTypeOf<[
456482
CIR_IntType, CIR_PointerType, CIR_DataMemberType, CIR_BoolType, CIR_ArrayType,
457483
CIR_VectorType, CIR_FuncType, CIR_VoidType, CIR_StructType, CIR_ExceptionInfo,
458-
CIR_AnyFloat,
484+
CIR_AnyFloat, CIR_ComplexType
459485
]>;
460486

461487
#endif // MLIR_CIR_DIALECT_CIR_TYPES

clang/lib/CIR/CodeGen/CIRGenBuilder.h

+33-30
Original file line numberDiff line numberDiff line change
@@ -136,14 +136,6 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
136136
return mlir::cir::GlobalViewAttr::get(type, symbol, indices);
137137
}
138138

139-
mlir::TypedAttr getZeroAttr(mlir::Type t) {
140-
return mlir::cir::ZeroAttr::get(getContext(), t);
141-
}
142-
143-
mlir::cir::BoolAttr getCIRBoolAttr(bool state) {
144-
return mlir::cir::BoolAttr::get(getContext(), getBoolTy(), state);
145-
}
146-
147139
mlir::TypedAttr getConstNullPtrAttr(mlir::Type t) {
148140
assert(t.isa<mlir::cir::PointerType>() && "expected cir.ptr");
149141
return mlir::cir::ConstPtrAttr::get(getContext(), t, 0);
@@ -243,25 +235,6 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
243235
return mlir::cir::DataMemberAttr::get(getContext(), ty, std::nullopt);
244236
}
245237

246-
mlir::TypedAttr getZeroInitAttr(mlir::Type ty) {
247-
if (ty.isa<mlir::cir::IntType>())
248-
return mlir::cir::IntAttr::get(ty, 0);
249-
if (auto fltType = ty.dyn_cast<mlir::cir::SingleType>())
250-
return mlir::cir::FPAttr::getZero(fltType);
251-
if (auto fltType = ty.dyn_cast<mlir::cir::DoubleType>())
252-
return mlir::cir::FPAttr::getZero(fltType);
253-
if (auto arrTy = ty.dyn_cast<mlir::cir::ArrayType>())
254-
return getZeroAttr(arrTy);
255-
if (auto ptrTy = ty.dyn_cast<mlir::cir::PointerType>())
256-
return getConstPtrAttr(ptrTy, 0);
257-
if (auto structTy = ty.dyn_cast<mlir::cir::StructType>())
258-
return getZeroAttr(structTy);
259-
if (ty.isa<mlir::cir::BoolType>()) {
260-
return getCIRBoolAttr(false);
261-
}
262-
llvm_unreachable("Zero initializer for given type is NYI");
263-
}
264-
265238
// TODO(cir): Once we have CIR float types, replace this by something like a
266239
// NullableValueInterface to allow for type-independent queries.
267240
bool isNullValue(mlir::Attribute attr) const {
@@ -398,9 +371,6 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
398371
llvm_unreachable("Unknown float format!");
399372
}
400373

401-
mlir::cir::BoolType getBoolTy() {
402-
return ::mlir::cir::BoolType::get(getContext());
403-
}
404374
mlir::Type getVirtualFnPtrType(bool isVarArg = false) {
405375
// FIXME: replay LLVM codegen for now, perhaps add a vtable ptr special
406376
// type so it's a bit more clear and C++ idiomatic.
@@ -748,6 +718,12 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
748718
createElementBitCast(loc, addr, ptrTy.getPointee()).getPointer());
749719
}
750720

721+
mlir::Value createVolatileLoad(mlir::Location loc, Address addr) {
722+
return create<mlir::cir::LoadOp>(loc, addr.getElementType(),
723+
addr.getPointer(), /*isDeref=*/false,
724+
/*is_volatile=*/true);
725+
}
726+
751727
mlir::Value createAlignedLoad(mlir::Location loc, mlir::Type ty,
752728
mlir::Value ptr,
753729
[[maybe_unused]] llvm::MaybeAlign align,
@@ -890,6 +866,33 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
890866
mlir::Value createPtrIsNull(mlir::Value ptr) {
891867
return createNot(createPtrToBoolCast(ptr));
892868
}
869+
870+
mlir::Value createComplexCreate(mlir::Location loc, mlir::Value real,
871+
mlir::Value imag) {
872+
assert(real.getType() == imag.getType() &&
873+
"operands to cir.complex.create must have the same type");
874+
875+
auto complexTy = mlir::cir::ComplexType::get(getContext(), real.getType());
876+
return create<mlir::cir::ComplexCreateOp>(loc, complexTy, real, imag);
877+
}
878+
879+
mlir::Value createComplexReal(mlir::Location loc, mlir::Value operand) {
880+
auto operandComplexTy = operand.getType().cast<mlir::cir::ComplexType>();
881+
auto resultTy = operandComplexTy.getElementTy();
882+
return create<mlir::cir::ComplexRealOp>(loc, resultTy, operand);
883+
}
884+
885+
mlir::Value createComplexImag(mlir::Location loc, mlir::Value operand) {
886+
auto operandComplexTy = operand.getType().cast<mlir::cir::ComplexType>();
887+
auto resultTy = operandComplexTy.getElementTy();
888+
return create<mlir::cir::ComplexImagOp>(loc, resultTy, operand);
889+
}
890+
891+
mlir::Value createComplexIsZero(mlir::Location loc, mlir::Value operand) {
892+
auto zero = getNullValue(operand.getType(), loc);
893+
return create<mlir::cir::CmpOp>(loc, getBoolTy(), mlir::cir::CmpOpKind::eq,
894+
operand, zero);
895+
}
893896
};
894897

895898
} // namespace cir

clang/lib/CIR/CodeGen/CIRGenCXX.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,8 @@ static void buildDeclInit(CIRGenFunction &CGF, const VarDecl *D,
201201
CGF.buildScalarInit(Init, CGF.getLoc(D->getLocation()), lv, false);
202202
return;
203203
case TEK_Complex:
204-
llvm_unreachable("complext evaluation NYI");
204+
CGF.buildComplexInit(Init, CGF.getLoc(D->getLocation()), lv);
205+
return;
205206
}
206207
}
207208

clang/lib/CIR/CodeGen/CIRGenClass.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,7 @@ void CIRGenFunction::buildInitializerForField(FieldDecl *Field, LValue LHS,
805805
}
806806
break;
807807
case TEK_Complex:
808-
llvm_unreachable("NYI");
808+
buildComplexExprIntoLValue(Init, LHS, /*isInit*/ true);
809809
break;
810810
case TEK_Aggregate: {
811811
AggValueSlot Slot = AggValueSlot::forLValue(

0 commit comments

Comments
 (0)