Skip to content

Commit ec94476

Browse files
authored
[CIR][CIRGen] Add complex type and its CIRGen support (#513)
This PR adds `!cir.complex` type to model the `_Complex` type in C. It also contains support for its CIRGen. In detail, this patch adds the following CIR types, ops, and attributes: - The `!cir.complex` type is added to model the `_Complex` type in C. This type is parameterized with the type of the components of the complex number, which must be either an integer type or a floating-point type. - ~The `#cir.complex` attribute is added to represent a literal value of `_Complex` type. It is a struct-like attribute that provides the real and imaginary part of the literal `_Complex` value.~ - ~The `#cir.imag` attribute is added to represent a purely imaginary number.~ - The `cir.complex.create` op is added to create a complex value from its real and imaginary parts. - ~The `cir.complex.real` and `cir.complex.imag` op is added to extract the real and imaginary part of a value of `!cir.complex` type, respectively.~ - The `cir.complex.real_ptr` and `cir.complex.imag_ptr` op is added to derive a pointer to the real and imaginary part of a value of `!cir.complex` type, respectively. CIRGen support for some of the fundamental complex number operations is also included. ~Note the implementation diverges from the original clang CodeGen, where expressions of complex types are handled differently from scalars and aggregates. Instead, this patch treats expressions of complex types as scalars, as such expressions can be simply lowered to a CIR value of `!cir.complex` type.~ This PR addresses #445 .
1 parent f7c508c commit ec94476

16 files changed

+1032
-32
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

+29
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,35 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
8989
return getPointerTo(::mlir::cir::VoidType::get(getContext()), langAS);
9090
}
9191

92+
mlir::cir::BoolAttr getCIRBoolAttr(bool state) {
93+
return mlir::cir::BoolAttr::get(getContext(), getBoolTy(), state);
94+
}
95+
96+
mlir::TypedAttr getZeroAttr(mlir::Type t) {
97+
return mlir::cir::ZeroAttr::get(getContext(), t);
98+
}
99+
100+
mlir::TypedAttr getZeroInitAttr(mlir::Type ty) {
101+
if (mlir::isa<mlir::cir::IntType>(ty))
102+
return mlir::cir::IntAttr::get(ty, 0);
103+
if (auto fltType = mlir::dyn_cast<mlir::cir::SingleType>(ty))
104+
return mlir::cir::FPAttr::getZero(fltType);
105+
if (auto fltType = mlir::dyn_cast<mlir::cir::DoubleType>(ty))
106+
return mlir::cir::FPAttr::getZero(fltType);
107+
if (auto complexType = mlir::dyn_cast<mlir::cir::ComplexType>(ty))
108+
return getZeroAttr(complexType);
109+
if (auto arrTy = mlir::dyn_cast<mlir::cir::ArrayType>(ty))
110+
return getZeroAttr(arrTy);
111+
if (auto ptrTy = mlir::dyn_cast<mlir::cir::PointerType>(ty))
112+
return getConstPtrAttr(ptrTy, 0);
113+
if (auto structTy = mlir::dyn_cast<mlir::cir::StructType>(ty))
114+
return getZeroAttr(structTy);
115+
if (mlir::isa<mlir::cir::BoolType>(ty)) {
116+
return getCIRBoolAttr(false);
117+
}
118+
llvm_unreachable("Zero initializer for given type is NYI");
119+
}
120+
92121
mlir::Value createLoad(mlir::Location loc, mlir::Value ptr,
93122
bool isVolatile = false, uint64_t alignment = 0) {
94123
mlir::IntegerAttr intAttr;

clang/include/clang/CIR/Dialect/IR/CIROps.td

+84
Original file line numberDiff line numberDiff line change
@@ -1174,6 +1174,90 @@ def BinOpOverflowOp : CIR_Op<"binop.overflow", [Pure, SameTypeOperands]> {
11741174
];
11751175
}
11761176

1177+
//===----------------------------------------------------------------------===//
1178+
// ComplexCreateOp
1179+
//===----------------------------------------------------------------------===//
1180+
1181+
def ComplexCreateOp : CIR_Op<"complex.create", [Pure, SameTypeOperands]> {
1182+
let summary = "Create a complex value from its real and imaginary parts";
1183+
let description = [{
1184+
`cir.complex.create` operation takes two operands that represent the real
1185+
and imaginary part of a complex number, and yields the complex number.
1186+
1187+
Example:
1188+
1189+
```mlir
1190+
%0 = cir.const #cir.fp<1.000000e+00> : !cir.double
1191+
%1 = cir.const #cir.fp<2.000000e+00> : !cir.double
1192+
%2 = cir.complex.create %0, %1 : !cir.complex<!cir.double>
1193+
```
1194+
}];
1195+
1196+
let results = (outs CIR_ComplexType:$result);
1197+
let arguments = (ins CIR_AnyIntOrFloat:$real, CIR_AnyIntOrFloat:$imag);
1198+
1199+
let assemblyFormat = [{
1200+
$real `,` $imag
1201+
`:` qualified(type($real)) `->` qualified(type($result)) attr-dict
1202+
}];
1203+
1204+
let hasVerifier = 1;
1205+
}
1206+
1207+
//===----------------------------------------------------------------------===//
1208+
// ComplexRealPtrOp and ComplexImagPtrOp
1209+
//===----------------------------------------------------------------------===//
1210+
1211+
def ComplexRealPtrOp : CIR_Op<"complex.real_ptr", [Pure]> {
1212+
let summary = "Extract the real part of a complex value";
1213+
let description = [{
1214+
`cir.complex.real_ptr` operation takes a pointer operand that points to a
1215+
complex value of type `!cir.complex` and yields a pointer to the real part
1216+
of the operand.
1217+
1218+
Example:
1219+
1220+
```mlir
1221+
%1 = cir.complex.real_ptr %0 : !cir.ptr<!cir.complex<!cir.double>> -> !cir.ptr<!cir.double>
1222+
```
1223+
}];
1224+
1225+
let results = (outs PrimitiveIntOrFPPtr:$result);
1226+
let arguments = (ins ComplexPtr:$operand);
1227+
1228+
let assemblyFormat = [{
1229+
$operand `:`
1230+
qualified(type($operand)) `->` qualified(type($result)) attr-dict
1231+
}];
1232+
1233+
let hasVerifier = 1;
1234+
}
1235+
1236+
def ComplexImagPtrOp : CIR_Op<"complex.imag_ptr", [Pure]> {
1237+
let summary = "Extract the imaginary part of a complex value";
1238+
let description = [{
1239+
`cir.complex.imag_ptr` operation takes a pointer operand that points to a
1240+
complex value of type `!cir.complex` and yields a pointer to the imaginary
1241+
part of the operand.
1242+
1243+
Example:
1244+
1245+
```mlir
1246+
%1 = cir.complex.imag_ptr %0 : !cir.ptr<!cir.complex<!cir.double>> -> !cir.ptr<!cir.double>
1247+
```
1248+
}];
1249+
1250+
let results = (outs PrimitiveIntOrFPPtr:$result);
1251+
let arguments = (ins ComplexPtr:$operand);
1252+
1253+
let assemblyFormat = [{
1254+
$operand `:`
1255+
qualified(type($operand)) `->` qualified(type($result)) attr-dict
1256+
}];
1257+
1258+
let hasVerifier = 1;
1259+
}
1260+
11771261
//===----------------------------------------------------------------------===//
11781262
// BitsOp
11791263
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

+35-1
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,32 @@ def CIR_LongDouble : CIR_FloatType<"LongDouble", "long_double"> {
196196
def CIR_AnyFloat: AnyTypeOf<[CIR_Single, CIR_Double, CIR_FP80, CIR_LongDouble]>;
197197
def CIR_AnyIntOrFloat: AnyTypeOf<[CIR_AnyFloat, CIR_IntType]>;
198198

199+
//===----------------------------------------------------------------------===//
200+
// ComplexType
201+
//===----------------------------------------------------------------------===//
202+
203+
def CIR_ComplexType : CIR_Type<"Complex", "complex",
204+
[DeclareTypeInterfaceMethods<DataLayoutTypeInterface>]> {
205+
206+
let summary = "CIR complex type";
207+
let description = [{
208+
CIR type that represents a C complex number. `cir.complex` models the C type
209+
`T _Complex`.
210+
211+
The parameter `elementTy` gives the type of the real and imaginary part of
212+
the complex number. `elementTy` must be either a CIR integer type or a CIR
213+
floating-point type.
214+
}];
215+
216+
let parameters = (ins "mlir::Type":$elementTy);
217+
218+
let assemblyFormat = [{
219+
`<` $elementTy `>`
220+
}];
221+
222+
let genVerifyDecl = 1;
223+
}
224+
199225
//===----------------------------------------------------------------------===//
200226
// PointerType
201227
//===----------------------------------------------------------------------===//
@@ -441,6 +467,14 @@ def PrimitiveIntOrFPPtr : Type<
441467
]>, "{int,void}*"> {
442468
}
443469

470+
def ComplexPtr : Type<
471+
And<[
472+
CPred<"::mlir::isa<::mlir::cir::PointerType>($_self)">,
473+
CPred<"::mlir::isa<::mlir::cir::ComplexType>("
474+
"::mlir::cast<::mlir::cir::PointerType>($_self).getPointee())">,
475+
]>, "!cir.complex*"> {
476+
}
477+
444478
// Pointer to struct
445479
def StructPtr : Type<
446480
And<[
@@ -516,7 +550,7 @@ def CIR_StructType : Type<CPred<"::mlir::isa<::mlir::cir::StructType>($_self)">,
516550
def CIR_AnyType : AnyTypeOf<[
517551
CIR_IntType, CIR_PointerType, CIR_DataMemberType, CIR_BoolType, CIR_ArrayType,
518552
CIR_VectorType, CIR_FuncType, CIR_VoidType, CIR_StructType, CIR_ExceptionInfo,
519-
CIR_AnyFloat, CIR_FP16, CIR_BFloat16
553+
CIR_AnyFloat, CIR_FP16, CIR_BFloat16, CIR_ComplexType
520554
]>;
521555

522556
#endif // MLIR_CIR_DIALECT_CIR_TYPES

clang/lib/CIR/CodeGen/CIRGenBuilder.h

+44-11
Original file line numberDiff line numberDiff line change
@@ -136,14 +136,6 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
136136
return mlir::cir::GlobalViewAttr::get(type, symbol, indices);
137137
}
138138

139-
mlir::TypedAttr getZeroAttr(mlir::Type t) {
140-
return mlir::cir::ZeroAttr::get(getContext(), t);
141-
}
142-
143-
mlir::cir::BoolAttr getCIRBoolAttr(bool state) {
144-
return mlir::cir::BoolAttr::get(getContext(), getBoolTy(), state);
145-
}
146-
147139
mlir::TypedAttr getConstNullPtrAttr(mlir::Type t) {
148140
assert(mlir::isa<mlir::cir::PointerType>(t) && "expected cir.ptr");
149141
return getConstPtrAttr(t, 0);
@@ -265,6 +257,8 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
265257
return mlir::cir::FPAttr::getZero(fltType);
266258
if (auto fltType = mlir::dyn_cast<mlir::cir::BF16Type>(ty))
267259
return mlir::cir::FPAttr::getZero(fltType);
260+
if (auto complexType = mlir::dyn_cast<mlir::cir::ComplexType>(ty))
261+
return getZeroAttr(complexType);
268262
if (auto arrTy = mlir::dyn_cast<mlir::cir::ArrayType>(ty))
269263
return getZeroAttr(arrTy);
270264
if (auto ptrTy = mlir::dyn_cast<mlir::cir::PointerType>(ty))
@@ -764,6 +758,42 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
764758
return create<mlir::cir::GetMemberOp>(loc, result, base, name, index);
765759
}
766760

761+
mlir::Value createComplexCreate(mlir::Location loc, mlir::Value real,
762+
mlir::Value imag) {
763+
auto resultComplexTy =
764+
mlir::cir::ComplexType::get(getContext(), real.getType());
765+
return create<mlir::cir::ComplexCreateOp>(loc, resultComplexTy, real, imag);
766+
}
767+
768+
/// Create a cir.complex.real_ptr operation that derives a pointer to the real
769+
/// part of the complex value pointed to by the specified pointer value.
770+
mlir::Value createRealPtr(mlir::Location loc, mlir::Value value) {
771+
auto srcPtrTy = mlir::cast<mlir::cir::PointerType>(value.getType());
772+
auto srcComplexTy =
773+
mlir::cast<mlir::cir::ComplexType>(srcPtrTy.getPointee());
774+
return create<mlir::cir::ComplexRealPtrOp>(
775+
loc, getPointerTo(srcComplexTy.getElementTy()), value);
776+
}
777+
778+
Address createRealPtr(mlir::Location loc, Address addr) {
779+
return Address{createRealPtr(loc, addr.getPointer()), addr.getAlignment()};
780+
}
781+
782+
/// Create a cir.complex.imag_ptr operation that derives a pointer to the
783+
/// imaginary part of the complex value pointed to by the specified pointer
784+
/// value.
785+
mlir::Value createImagPtr(mlir::Location loc, mlir::Value value) {
786+
auto srcPtrTy = mlir::cast<mlir::cir::PointerType>(value.getType());
787+
auto srcComplexTy =
788+
mlir::cast<mlir::cir::ComplexType>(srcPtrTy.getPointee());
789+
return create<mlir::cir::ComplexImagPtrOp>(
790+
loc, getPointerTo(srcComplexTy.getElementTy()), value);
791+
}
792+
793+
Address createImagPtr(mlir::Location loc, Address addr) {
794+
return Address{createImagPtr(loc, addr.getPointer()), addr.getAlignment()};
795+
}
796+
767797
/// Cast the element type of the given address to a different type,
768798
/// preserving information like the alignment.
769799
cir::Address createElementBitCast(mlir::Location loc, cir::Address addr,
@@ -776,15 +806,18 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
776806
addr.getAlignment());
777807
}
778808

779-
mlir::Value createLoad(mlir::Location loc, Address addr) {
809+
mlir::Value createLoad(mlir::Location loc, Address addr,
810+
bool isVolatile = false) {
780811
auto ptrTy =
781812
mlir::dyn_cast<mlir::cir::PointerType>(addr.getPointer().getType());
782813
if (addr.getElementType() != ptrTy.getPointee())
783814
addr = addr.withPointer(
784815
createPtrBitcast(addr.getPointer(), addr.getElementType()));
785816

786-
return create<mlir::cir::LoadOp>(loc, addr.getElementType(),
787-
addr.getPointer());
817+
return create<mlir::cir::LoadOp>(
818+
loc, addr.getElementType(), addr.getPointer(), /*isDeref=*/false,
819+
/*is_volatile=*/isVolatile, /*alignment=*/mlir::IntegerAttr{},
820+
/*mem_order=*/mlir::cir::MemOrderAttr{});
788821
}
789822

790823
mlir::Value createAlignedLoad(mlir::Location loc, mlir::Type ty,

clang/lib/CIR/CodeGen/CIRGenDecl.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -734,7 +734,11 @@ void CIRGenFunction::buildExprAsInit(const Expr *init, const ValueDecl *D,
734734
buildScalarInit(init, getLoc(D->getSourceRange()), lvalue);
735735
return;
736736
case TEK_Complex: {
737-
assert(0 && "not implemented");
737+
mlir::Value complex = buildComplexExpr(init);
738+
if (capturedByInit)
739+
llvm_unreachable("NYI");
740+
buildStoreOfComplex(getLoc(init->getExprLoc()), complex, lvalue,
741+
/*init*/ true);
738742
return;
739743
}
740744
case TEK_Aggregate:

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

+26-3
Original file line numberDiff line numberDiff line change
@@ -1224,7 +1224,7 @@ LValue CIRGenFunction::buildBinaryOperatorLValue(const BinaryOperator *E) {
12241224
}
12251225

12261226
case TEK_Complex:
1227-
assert(0 && "not implemented");
1227+
return buildComplexAssignmentLValue(E);
12281228
case TEK_Aggregate:
12291229
assert(0 && "not implemented");
12301230
}
@@ -1264,6 +1264,7 @@ LValue CIRGenFunction::buildUnaryOpLValue(const UnaryOperator *E) {
12641264
if (E->getOpcode() == UO_Extension)
12651265
return buildLValue(E->getSubExpr());
12661266

1267+
QualType ExprTy = getContext().getCanonicalType(E->getSubExpr()->getType());
12671268
switch (E->getOpcode()) {
12681269
default:
12691270
llvm_unreachable("Unknown unary operator lvalue!");
@@ -1288,7 +1289,29 @@ LValue CIRGenFunction::buildUnaryOpLValue(const UnaryOperator *E) {
12881289
}
12891290
case UO_Real:
12901291
case UO_Imag: {
1291-
assert(0 && "not implemented");
1292+
LValue LV = buildLValue(E->getSubExpr());
1293+
assert(LV.isSimple() && "real/imag on non-ordinary l-value");
1294+
1295+
// __real is valid on scalars. This is a faster way of testing that.
1296+
// __imag can only produce an rvalue on scalars.
1297+
if (E->getOpcode() == UO_Real &&
1298+
!mlir::isa<mlir::cir::ComplexType>(LV.getAddress().getElementType())) {
1299+
assert(E->getSubExpr()->getType()->isArithmeticType());
1300+
return LV;
1301+
}
1302+
1303+
QualType T = ExprTy->castAs<clang::ComplexType>()->getElementType();
1304+
1305+
auto Loc = getLoc(E->getExprLoc());
1306+
Address Component =
1307+
(E->getOpcode() == UO_Real
1308+
? buildAddrOfRealComponent(Loc, LV.getAddress(), LV.getType())
1309+
: buildAddrOfImagComponent(Loc, LV.getAddress(), LV.getType()));
1310+
// TODO(cir): TBAA info.
1311+
assert(!MissingFeatures::tbaa());
1312+
LValue ElemLV = makeAddrLValue(Component, T, LV.getBaseInfo());
1313+
ElemLV.getQuals().addQualifiers(LV.getQuals());
1314+
return ElemLV;
12921315
}
12931316
case UO_PreInc:
12941317
case UO_PreDec: {
@@ -1315,7 +1338,7 @@ RValue CIRGenFunction::buildAnyExpr(const Expr *E, AggValueSlot aggSlot,
13151338
case TEK_Scalar:
13161339
return RValue::get(buildScalarExpr(E));
13171340
case TEK_Complex:
1318-
assert(0 && "not implemented");
1341+
return RValue::getComplex(buildComplexExpr(E));
13191342
case TEK_Aggregate: {
13201343
if (!ignoreResult && aggSlot.isIgnored())
13211344
aggSlot = CreateAggTemp(E->getType(), getLoc(E->getSourceRange()),

0 commit comments

Comments
 (0)