Skip to content

Commit 055750f

Browse files
committed
[CIR] Add initial support for complex types
This patch adds an initial support for the C complex type, i.e. `_Complex`. It introduces the following new types, attributes, and operations: - `!cir.complex`, which represents the C complex number type; - `cir.complex.create`, which creates a complex number from its real and imaginary parts; - `cir.complex.real_ptr`, which derives a pointer to the real part of a complex number given a pointer to the complex number; - `cir.complex.imag_ptr`, which derives a pointer to the imaginary part of a complex number given a pointer to the complex number. CIRGen for some basic complex number operations is also included in this patch.
1 parent 2dd4609 commit 055750f

17 files changed

+1043
-39
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

+29
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,35 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
8585
return getPointerTo(::mlir::cir::VoidType::get(getContext()), addressSpace);
8686
}
8787

88+
mlir::cir::BoolAttr getCIRBoolAttr(bool state) {
89+
return mlir::cir::BoolAttr::get(getContext(), getBoolTy(), state);
90+
}
91+
92+
mlir::TypedAttr getZeroAttr(mlir::Type t) {
93+
return mlir::cir::ZeroAttr::get(getContext(), t);
94+
}
95+
96+
mlir::TypedAttr getZeroInitAttr(mlir::Type ty) {
97+
if (ty.isa<mlir::cir::IntType>())
98+
return mlir::cir::IntAttr::get(ty, 0);
99+
if (auto fltType = ty.dyn_cast<mlir::cir::SingleType>())
100+
return mlir::cir::FPAttr::getZero(fltType);
101+
if (auto fltType = ty.dyn_cast<mlir::cir::DoubleType>())
102+
return mlir::cir::FPAttr::getZero(fltType);
103+
if (auto complexType = ty.dyn_cast<mlir::cir::ComplexType>())
104+
return getZeroAttr(complexType);
105+
if (auto arrTy = ty.dyn_cast<mlir::cir::ArrayType>())
106+
return getZeroAttr(arrTy);
107+
if (auto ptrTy = ty.dyn_cast<mlir::cir::PointerType>())
108+
return getConstPtrAttr(ptrTy, 0);
109+
if (auto structTy = ty.dyn_cast<mlir::cir::StructType>())
110+
return getZeroAttr(structTy);
111+
if (ty.isa<mlir::cir::BoolType>()) {
112+
return getCIRBoolAttr(false);
113+
}
114+
llvm_unreachable("Zero initializer for given type is NYI");
115+
}
116+
88117
mlir::Value createLoad(mlir::Location loc, mlir::Value ptr,
89118
bool isVolatile = false, uint64_t alignment = 0) {
90119
mlir::IntegerAttr intAttr;

clang/include/clang/CIR/Dialect/IR/CIROps.td

+84
Original file line numberDiff line numberDiff line change
@@ -1174,6 +1174,90 @@ def BinOpOverflowOp : CIR_Op<"binop.overflow", [Pure, SameTypeOperands]> {
11741174
];
11751175
}
11761176

1177+
//===----------------------------------------------------------------------===//
1178+
// ComplexCreateOp
1179+
//===----------------------------------------------------------------------===//
1180+
1181+
def ComplexCreateOp : CIR_Op<"complex.create", [Pure, SameTypeOperands]> {
1182+
let summary = "Create a complex value from its real and imaginary parts";
1183+
let description = [{
1184+
`cir.complex.create` operation takes two operands that represent the real
1185+
and imaginary part of a complex number, and yields the complex number.
1186+
1187+
Example:
1188+
1189+
```mlir
1190+
%0 = cir.const #cir.fp<1.000000e+00> : !cir.double
1191+
%1 = cir.const #cir.fp<2.000000e+00> : !cir.double
1192+
%2 = cir.complex.create %0, %1 : !cir.complex<!cir.double>
1193+
```
1194+
}];
1195+
1196+
let results = (outs CIR_ComplexType:$result);
1197+
let arguments = (ins CIR_AnyIntOrFloat:$real, CIR_AnyIntOrFloat:$imag);
1198+
1199+
let assemblyFormat = [{
1200+
$real `,` $imag
1201+
`:` qualified(type($real)) `->` qualified(type($result)) attr-dict
1202+
}];
1203+
1204+
let hasVerifier = 1;
1205+
}
1206+
1207+
//===----------------------------------------------------------------------===//
1208+
// ComplexRealPtrOp and ComplexImagPtrOp
1209+
//===----------------------------------------------------------------------===//
1210+
1211+
def ComplexRealPtrOp : CIR_Op<"complex.real_ptr", [Pure]> {
1212+
let summary = "Extract the real part of a complex value";
1213+
let description = [{
1214+
`cir.complex.real_ptr` operation takes a pointer operand that points to a
1215+
complex value of type `!cir.complex` and yields a pointer to the real part
1216+
of the operand.
1217+
1218+
Example:
1219+
1220+
```mlir
1221+
%1 = cir.complex.real_ptr %0 : !cir.ptr<!cir.complex<!cir.double>> -> !cir.ptr<!cir.double>
1222+
```
1223+
}];
1224+
1225+
let results = (outs PrimitiveIntOrFPPtr:$result);
1226+
let arguments = (ins ComplexPtr:$operand);
1227+
1228+
let assemblyFormat = [{
1229+
$operand `:`
1230+
qualified(type($operand)) `->` qualified(type($result)) attr-dict
1231+
}];
1232+
1233+
let hasVerifier = 1;
1234+
}
1235+
1236+
def ComplexImagPtrOp : CIR_Op<"complex.imag_ptr", [Pure]> {
1237+
let summary = "Extract the imaginary part of a complex value";
1238+
let description = [{
1239+
`cir.complex.imag_ptr` operation takes a pointer operand that points to a
1240+
complex value of type `!cir.complex` and yields a pointer to the imaginary
1241+
part of the operand.
1242+
1243+
Example:
1244+
1245+
```mlir
1246+
%1 = cir.complex.imag_ptr %0 : !cir.ptr<!cir.complex<!cir.double>> -> !cir.ptr<!cir.double>
1247+
```
1248+
}];
1249+
1250+
let results = (outs PrimitiveIntOrFPPtr:$result);
1251+
let arguments = (ins ComplexPtr:$operand);
1252+
1253+
let assemblyFormat = [{
1254+
$operand `:`
1255+
qualified(type($operand)) `->` qualified(type($result)) attr-dict
1256+
}];
1257+
1258+
let hasVerifier = 1;
1259+
}
1260+
11771261
//===----------------------------------------------------------------------===//
11781262
// BitsOp
11791263
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

+35-1
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,32 @@ def CIR_LongDouble : CIR_FloatType<"LongDouble", "long_double"> {
196196
def CIR_AnyFloat: AnyTypeOf<[CIR_Single, CIR_Double, CIR_FP80, CIR_LongDouble]>;
197197
def CIR_AnyIntOrFloat: AnyTypeOf<[CIR_AnyFloat, CIR_IntType]>;
198198

199+
//===----------------------------------------------------------------------===//
200+
// ComplexType
201+
//===----------------------------------------------------------------------===//
202+
203+
def CIR_ComplexType : CIR_Type<"Complex", "complex",
204+
[DeclareTypeInterfaceMethods<DataLayoutTypeInterface>]> {
205+
206+
let summary = "CIR complex type";
207+
let description = [{
208+
CIR type that represents a C complex number. `cir.complex` models the C type
209+
`T _Complex`.
210+
211+
The parameter `elementTy` gives the type of the real and imaginary part of
212+
the complex number. `elementTy` must be either a CIR integer type or a CIR
213+
floating-point type.
214+
}];
215+
216+
let parameters = (ins "mlir::Type":$elementTy);
217+
218+
let assemblyFormat = [{
219+
`<` $elementTy `>`
220+
}];
221+
222+
let genVerifyDecl = 1;
223+
}
224+
199225
//===----------------------------------------------------------------------===//
200226
// PointerType
201227
//===----------------------------------------------------------------------===//
@@ -431,6 +457,14 @@ def PrimitiveIntOrFPPtr : Type<
431457
]>, "{int,void}*"> {
432458
}
433459

460+
def ComplexPtr : Type<
461+
And<[
462+
CPred<"$_self.isa<::mlir::cir::PointerType>()">,
463+
CPred<"$_self.cast<::mlir::cir::PointerType>()"
464+
".getPointee().isa<::mlir::cir::ComplexType>()">,
465+
]>, "!cir.complex*"> {
466+
}
467+
434468
// Pointer to struct
435469
def StructPtr : Type<
436470
And<[
@@ -505,7 +539,7 @@ def CIR_StructType : Type<CPred<"$_self.isa<::mlir::cir::StructType>()">,
505539
def CIR_AnyType : AnyTypeOf<[
506540
CIR_IntType, CIR_PointerType, CIR_DataMemberType, CIR_BoolType, CIR_ArrayType,
507541
CIR_VectorType, CIR_FuncType, CIR_VoidType, CIR_StructType, CIR_ExceptionInfo,
508-
CIR_AnyFloat, CIR_FP16, CIR_BFloat16
542+
CIR_AnyFloat, CIR_FP16, CIR_BFloat16, CIR_ComplexType
509543
]>;
510544

511545
#endif // MLIR_CIR_DIALECT_CIR_TYPES

clang/lib/CIR/CodeGen/CIRGenBuilder.h

+48-11
Original file line numberDiff line numberDiff line change
@@ -136,14 +136,6 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
136136
return mlir::cir::GlobalViewAttr::get(type, symbol, indices);
137137
}
138138

139-
mlir::TypedAttr getZeroAttr(mlir::Type t) {
140-
return mlir::cir::ZeroAttr::get(getContext(), t);
141-
}
142-
143-
mlir::cir::BoolAttr getCIRBoolAttr(bool state) {
144-
return mlir::cir::BoolAttr::get(getContext(), getBoolTy(), state);
145-
}
146-
147139
mlir::TypedAttr getConstNullPtrAttr(mlir::Type t) {
148140
assert(t.isa<mlir::cir::PointerType>() && "expected cir.ptr");
149141
return getConstPtrAttr(t, 0);
@@ -265,6 +257,8 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
265257
return mlir::cir::FPAttr::getZero(fltType);
266258
if (auto fltType = ty.dyn_cast<mlir::cir::BF16Type>())
267259
return mlir::cir::FPAttr::getZero(fltType);
260+
if (auto complexType = ty.dyn_cast<mlir::cir::ComplexType>())
261+
return getZeroAttr(complexType);
268262
if (auto arrTy = ty.dyn_cast<mlir::cir::ArrayType>())
269263
return getZeroAttr(arrTy);
270264
if (auto ptrTy = ty.dyn_cast<mlir::cir::PointerType>())
@@ -763,6 +757,46 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
763757
return create<mlir::cir::GetMemberOp>(loc, result, base, name, index);
764758
}
765759

760+
mlir::Value createComplexCreate(mlir::Location loc, mlir::Value real,
761+
mlir::Value imag) {
762+
auto resultComplexTy =
763+
mlir::cir::ComplexType::get(getContext(), real.getType());
764+
return create<mlir::cir::ComplexCreateOp>(loc, resultComplexTy, real, imag);
765+
}
766+
767+
/// Create a cir.complex.real_ptr operation that derives a pointer to the real
768+
/// part of the complex value pointed to by the specified pointer value.
769+
mlir::Value createRealPtr(mlir::Location loc, mlir::Value value) {
770+
auto srcComplexElemTy = value.getType()
771+
.cast<mlir::cir::PointerType>()
772+
.getPointee()
773+
.cast<mlir::cir::ComplexType>()
774+
.getElementTy();
775+
return create<mlir::cir::ComplexRealPtrOp>(
776+
loc, getPointerTo(srcComplexElemTy), value);
777+
}
778+
779+
Address createRealPtr(mlir::Location loc, Address addr) {
780+
return Address{createRealPtr(loc, addr.getPointer()), addr.getAlignment()};
781+
}
782+
783+
/// Create a cir.complex.imag_ptr operation that derives a pointer to the
784+
/// imaginary part of the complex value pointed to by the specified pointer
785+
/// value.
786+
mlir::Value createImagPtr(mlir::Location loc, mlir::Value value) {
787+
auto srcComplexElemTy = value.getType()
788+
.cast<mlir::cir::PointerType>()
789+
.getPointee()
790+
.cast<mlir::cir::ComplexType>()
791+
.getElementTy();
792+
return create<mlir::cir::ComplexImagPtrOp>(
793+
loc, getPointerTo(srcComplexElemTy), value);
794+
}
795+
796+
Address createImagPtr(mlir::Location loc, Address addr) {
797+
return Address{createImagPtr(loc, addr.getPointer()), addr.getAlignment()};
798+
}
799+
766800
/// Cast the element type of the given address to a different type,
767801
/// preserving information like the alignment.
768802
cir::Address createElementBitCast(mlir::Location loc, cir::Address addr,
@@ -775,14 +809,17 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
775809
addr.getAlignment());
776810
}
777811

778-
mlir::Value createLoad(mlir::Location loc, Address addr) {
812+
mlir::Value createLoad(mlir::Location loc, Address addr,
813+
bool isVolatile = false) {
779814
auto ptrTy = addr.getPointer().getType().dyn_cast<mlir::cir::PointerType>();
780815
if (addr.getElementType() != ptrTy.getPointee())
781816
addr = addr.withPointer(
782817
createPtrBitcast(addr.getPointer(), addr.getElementType()));
783818

784-
return create<mlir::cir::LoadOp>(loc, addr.getElementType(),
785-
addr.getPointer());
819+
return create<mlir::cir::LoadOp>(
820+
loc, addr.getElementType(), addr.getPointer(), /*isDeref=*/false,
821+
/*is_volatile=*/isVolatile, /*alignment=*/mlir::IntegerAttr{},
822+
/*mem_order=*/mlir::cir::MemOrderAttr{});
786823
}
787824

788825
mlir::Value createAlignedLoad(mlir::Location loc, mlir::Type ty,

clang/lib/CIR/CodeGen/CIRGenDecl.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -733,7 +733,11 @@ void CIRGenFunction::buildExprAsInit(const Expr *init, const ValueDecl *D,
733733
buildScalarInit(init, getLoc(D->getSourceRange()), lvalue);
734734
return;
735735
case TEK_Complex: {
736-
assert(0 && "not implemented");
736+
mlir::Value complex = buildComplexExpr(init);
737+
if (capturedByInit)
738+
llvm_unreachable("NYI");
739+
buildStoreOfComplex(getLoc(init->getExprLoc()), complex, lvalue,
740+
/*init*/ true);
737741
return;
738742
}
739743
case TEK_Aggregate:

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

+26-3
Original file line numberDiff line numberDiff line change
@@ -1220,7 +1220,7 @@ LValue CIRGenFunction::buildBinaryOperatorLValue(const BinaryOperator *E) {
12201220
}
12211221

12221222
case TEK_Complex:
1223-
assert(0 && "not implemented");
1223+
return buildComplexAssignmentLValue(E);
12241224
case TEK_Aggregate:
12251225
assert(0 && "not implemented");
12261226
}
@@ -1260,6 +1260,7 @@ LValue CIRGenFunction::buildUnaryOpLValue(const UnaryOperator *E) {
12601260
if (E->getOpcode() == UO_Extension)
12611261
return buildLValue(E->getSubExpr());
12621262

1263+
QualType ExprTy = getContext().getCanonicalType(E->getSubExpr()->getType());
12631264
switch (E->getOpcode()) {
12641265
default:
12651266
llvm_unreachable("Unknown unary operator lvalue!");
@@ -1284,7 +1285,29 @@ LValue CIRGenFunction::buildUnaryOpLValue(const UnaryOperator *E) {
12841285
}
12851286
case UO_Real:
12861287
case UO_Imag: {
1287-
assert(0 && "not implemented");
1288+
LValue LV = buildLValue(E->getSubExpr());
1289+
assert(LV.isSimple() && "real/imag on non-ordinary l-value");
1290+
1291+
// __real is valid on scalars. This is a faster way of testing that.
1292+
// __imag can only produce an rvalue on scalars.
1293+
if (E->getOpcode() == UO_Real &&
1294+
!LV.getAddress().getElementType().isa<mlir::cir::ComplexType>()) {
1295+
assert(E->getSubExpr()->getType()->isArithmeticType());
1296+
return LV;
1297+
}
1298+
1299+
QualType T = ExprTy->castAs<clang::ComplexType>()->getElementType();
1300+
1301+
auto Loc = getLoc(E->getExprLoc());
1302+
Address Component =
1303+
(E->getOpcode() == UO_Real
1304+
? buildAddrOfRealComponent(Loc, LV.getAddress(), LV.getType())
1305+
: buildAddrOfImagComponent(Loc, LV.getAddress(), LV.getType()));
1306+
// TODO(cir): TBAA info.
1307+
assert(!MissingFeatures::tbaa());
1308+
LValue ElemLV = makeAddrLValue(Component, T, LV.getBaseInfo());
1309+
ElemLV.getQuals().addQualifiers(LV.getQuals());
1310+
return ElemLV;
12881311
}
12891312
case UO_PreInc:
12901313
case UO_PreDec: {
@@ -1311,7 +1334,7 @@ RValue CIRGenFunction::buildAnyExpr(const Expr *E, AggValueSlot aggSlot,
13111334
case TEK_Scalar:
13121335
return RValue::get(buildScalarExpr(E));
13131336
case TEK_Complex:
1314-
assert(0 && "not implemented");
1337+
return RValue::getComplex(buildComplexExpr(E));
13151338
case TEK_Aggregate: {
13161339
if (!ignoreResult && aggSlot.isIgnored())
13171340
aggSlot = CreateAggTemp(E->getType(), getLoc(E->getSourceRange()),

0 commit comments

Comments
 (0)