Skip to content

Commit 42a6d1d

Browse files
Lancernlanza
authored andcommitted
[CIR] Add support for complex cast operations (#758)
This PR adds support for complex cast operations. It adds the following new cast kind variants to the `cir.cast` operation: - `float_to_complex`, - `int_to_complex`, - `float_complex_to_real`, - `int_complex_to_real`, - `float_complex_to_bool`, - `int_complex_to_bool`, - `float_complex`, - `float_complex_to_int_complex`, - `int_complex`, and - `int_complex_to_float_complex`. CIRGen and LLVM IR support for these new cast variants are also included.
1 parent 3cea5b7 commit 42a6d1d

File tree

8 files changed

+575
-66
lines changed

8 files changed

+575
-66
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

+26-1
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,22 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
5959
return create<mlir::cir::ConstantOp>(loc, attr.getType(), attr);
6060
}
6161

62+
// Creates constant null value for integral type ty.
63+
mlir::cir::ConstantOp getNullValue(mlir::Type ty, mlir::Location loc) {
64+
return create<mlir::cir::ConstantOp>(loc, ty, getZeroInitAttr(ty));
65+
}
66+
67+
mlir::cir::ConstantOp getBool(bool state, mlir::Location loc) {
68+
return create<mlir::cir::ConstantOp>(loc, getBoolTy(),
69+
getCIRBoolAttr(state));
70+
}
71+
mlir::cir::ConstantOp getFalse(mlir::Location loc) {
72+
return getBool(false, loc);
73+
}
74+
mlir::cir::ConstantOp getTrue(mlir::Location loc) {
75+
return getBool(true, loc);
76+
}
77+
6278
mlir::cir::BoolType getBoolTy() {
6379
return ::mlir::cir::BoolType::get(getContext());
6480
}
@@ -110,12 +126,16 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
110126
return mlir::cir::FPAttr::getZero(fltType);
111127
if (auto fltType = mlir::dyn_cast<mlir::cir::DoubleType>(ty))
112128
return mlir::cir::FPAttr::getZero(fltType);
129+
if (auto fltType = mlir::dyn_cast<mlir::cir::FP16Type>(ty))
130+
return mlir::cir::FPAttr::getZero(fltType);
131+
if (auto fltType = mlir::dyn_cast<mlir::cir::BF16Type>(ty))
132+
return mlir::cir::FPAttr::getZero(fltType);
113133
if (auto complexType = mlir::dyn_cast<mlir::cir::ComplexType>(ty))
114134
return getZeroAttr(complexType);
115135
if (auto arrTy = mlir::dyn_cast<mlir::cir::ArrayType>(ty))
116136
return getZeroAttr(arrTy);
117137
if (auto ptrTy = mlir::dyn_cast<mlir::cir::PointerType>(ty))
118-
return getConstPtrAttr(ptrTy, 0);
138+
return getConstNullPtrAttr(ptrTy);
119139
if (auto structTy = mlir::dyn_cast<mlir::cir::StructType>(ty))
120140
return getZeroAttr(structTy);
121141
if (mlir::isa<mlir::cir::BoolType>(ty)) {
@@ -548,6 +568,11 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
548568
getContext(), mlir::cast<mlir::cir::PointerType>(t), val);
549569
}
550570

571+
mlir::TypedAttr getConstNullPtrAttr(mlir::Type t) {
572+
assert(mlir::isa<mlir::cir::PointerType>(t) && "expected cir.ptr");
573+
return getConstPtrAttr(t, 0);
574+
}
575+
551576
// Creates constant nullptr for pointer type ty.
552577
mlir::cir::ConstantOp getNullPtr(mlir::Type ty, mlir::Location loc) {
553578
assert(!MissingFeatures::targetCodeGenInfoGetNullPointer());

clang/include/clang/CIR/Dialect/IR/CIROps.td

+27-1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,18 @@ def CK_BooleanToIntegral : I32EnumAttrCase<"bool_to_int", 11>;
7171
def CK_IntegralToFloat : I32EnumAttrCase<"int_to_float", 12>;
7272
def CK_BooleanToFloat : I32EnumAttrCase<"bool_to_float", 13>;
7373
def CK_AddressSpaceConversion : I32EnumAttrCase<"address_space", 14>;
74+
def CK_FloatToComplex : I32EnumAttrCase<"float_to_complex", 15>;
75+
def CK_IntegralToComplex : I32EnumAttrCase<"int_to_complex", 16>;
76+
def CK_FloatComplexToReal : I32EnumAttrCase<"float_complex_to_real", 17>;
77+
def CK_IntegralComplexToReal : I32EnumAttrCase<"int_complex_to_real", 18>;
78+
def CK_FloatComplexToBoolean : I32EnumAttrCase<"float_complex_to_bool", 19>;
79+
def CK_IntegralComplexToBoolean : I32EnumAttrCase<"int_complex_to_bool", 20>;
80+
def CK_FloatComplexCast : I32EnumAttrCase<"float_complex", 21>;
81+
def CK_FloatComplexToIntegralComplex
82+
: I32EnumAttrCase<"float_complex_to_int_complex", 22>;
83+
def CK_IntegralComplexCast : I32EnumAttrCase<"int_complex", 23>;
84+
def CK_IntegralComplexToFloatComplex
85+
: I32EnumAttrCase<"int_complex_to_float_complex", 24>;
7486

7587
def CastKind : I32EnumAttr<
7688
"CastKind",
@@ -79,7 +91,11 @@ def CastKind : I32EnumAttr<
7991
CK_BitCast, CK_FloatingCast, CK_PtrToBoolean, CK_FloatToIntegral,
8092
CK_IntegralToPointer, CK_PointerToIntegral, CK_FloatToBoolean,
8193
CK_BooleanToIntegral, CK_IntegralToFloat, CK_BooleanToFloat,
82-
CK_AddressSpaceConversion]> {
94+
CK_AddressSpaceConversion, CK_FloatToComplex, CK_IntegralToComplex,
95+
CK_FloatComplexToReal, CK_IntegralComplexToReal, CK_FloatComplexToBoolean,
96+
CK_IntegralComplexToBoolean, CK_FloatComplexCast,
97+
CK_FloatComplexToIntegralComplex, CK_IntegralComplexCast,
98+
CK_IntegralComplexToFloatComplex]> {
8399
let cppNamespace = "::mlir::cir";
84100
}
85101

@@ -104,6 +120,16 @@ def CastOp : CIR_Op<"cast",
104120
- `bool_to_int`
105121
- `bool_to_float`
106122
- `address_space`
123+
- `float_to_complex`
124+
- `int_to_complex`
125+
- `float_complex_to_real`
126+
- `int_complex_to_real`
127+
- `float_complex_to_bool`
128+
- `int_complex_to_bool`
129+
- `float_complex`
130+
- `float_complex_to_int_complex`
131+
- `int_complex`
132+
- `int_complex_to_float_complex`
107133

108134
This is effectively a subset of the rules from
109135
`llvm-project/clang/include/clang/AST/OperationKinds.def`; but note that some

clang/lib/CIR/CodeGen/CIRGenBuilder.h

-46
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,6 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
136136
return mlir::cir::GlobalViewAttr::get(type, symbol, indices);
137137
}
138138

139-
mlir::TypedAttr getConstNullPtrAttr(mlir::Type t) {
140-
assert(mlir::isa<mlir::cir::PointerType>(t) && "expected cir.ptr");
141-
return getConstPtrAttr(t, 0);
142-
}
143-
144139
mlir::Attribute getString(llvm::StringRef str, mlir::Type eltTy,
145140
unsigned size = 0) {
146141
unsigned finalSize = size ? size : str.size();
@@ -246,31 +241,6 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
246241
return mlir::cir::DataMemberAttr::get(getContext(), ty, std::nullopt);
247242
}
248243

249-
mlir::TypedAttr getZeroInitAttr(mlir::Type ty) {
250-
if (mlir::isa<mlir::cir::IntType>(ty))
251-
return mlir::cir::IntAttr::get(ty, 0);
252-
if (auto fltType = mlir::dyn_cast<mlir::cir::SingleType>(ty))
253-
return mlir::cir::FPAttr::getZero(fltType);
254-
if (auto fltType = mlir::dyn_cast<mlir::cir::DoubleType>(ty))
255-
return mlir::cir::FPAttr::getZero(fltType);
256-
if (auto fltType = mlir::dyn_cast<mlir::cir::FP16Type>(ty))
257-
return mlir::cir::FPAttr::getZero(fltType);
258-
if (auto fltType = mlir::dyn_cast<mlir::cir::BF16Type>(ty))
259-
return mlir::cir::FPAttr::getZero(fltType);
260-
if (auto complexType = mlir::dyn_cast<mlir::cir::ComplexType>(ty))
261-
return getZeroAttr(complexType);
262-
if (auto arrTy = mlir::dyn_cast<mlir::cir::ArrayType>(ty))
263-
return getZeroAttr(arrTy);
264-
if (auto ptrTy = mlir::dyn_cast<mlir::cir::PointerType>(ty))
265-
return getConstNullPtrAttr(ptrTy);
266-
if (auto structTy = mlir::dyn_cast<mlir::cir::StructType>(ty))
267-
return getZeroAttr(structTy);
268-
if (mlir::isa<mlir::cir::BoolType>(ty)) {
269-
return getCIRBoolAttr(false);
270-
}
271-
llvm_unreachable("Zero initializer for given type is NYI");
272-
}
273-
274244
// TODO(cir): Once we have CIR float types, replace this by something like a
275245
// NullableValueInterface to allow for type-independent queries.
276246
bool isNullValue(mlir::Attribute attr) const {
@@ -554,28 +524,12 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
554524
mlir::cir::IntAttr::get(t, C));
555525
}
556526

557-
mlir::cir::ConstantOp getBool(bool state, mlir::Location loc) {
558-
return create<mlir::cir::ConstantOp>(loc, getBoolTy(),
559-
getCIRBoolAttr(state));
560-
}
561-
mlir::cir::ConstantOp getFalse(mlir::Location loc) {
562-
return getBool(false, loc);
563-
}
564-
mlir::cir::ConstantOp getTrue(mlir::Location loc) {
565-
return getBool(true, loc);
566-
}
567-
568527
/// Create constant nullptr for pointer-to-data-member type ty.
569528
mlir::cir::ConstantOp getNullDataMemberPtr(mlir::cir::DataMemberType ty,
570529
mlir::Location loc) {
571530
return create<mlir::cir::ConstantOp>(loc, ty, getNullDataMemberAttr(ty));
572531
}
573532

574-
// Creates constant null value for integral type ty.
575-
mlir::cir::ConstantOp getNullValue(mlir::Type ty, mlir::Location loc) {
576-
return create<mlir::cir::ConstantOp>(loc, ty, getZeroInitAttr(ty));
577-
}
578-
579533
mlir::cir::ConstantOp getZero(mlir::Location loc, mlir::Type ty) {
580534
// TODO: dispatch creation for primitive types.
581535
assert((mlir::isa<mlir::cir::StructType>(ty) ||

clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp

+38-9
Original file line numberDiff line numberDiff line change
@@ -372,20 +372,43 @@ mlir::Value ComplexExprEmitter::buildComplexToComplexCast(mlir::Value Val,
372372
QualType SrcType,
373373
QualType DestType,
374374
SourceLocation Loc) {
375-
// Get the src/dest element type.
376-
SrcType = SrcType->castAs<ComplexType>()->getElementType();
377-
DestType = DestType->castAs<ComplexType>()->getElementType();
378375
if (SrcType == DestType)
379376
return Val;
380377

381-
llvm_unreachable("complex cast is NYI");
378+
// Get the src/dest element type.
379+
QualType SrcElemTy = SrcType->castAs<ComplexType>()->getElementType();
380+
QualType DestElemTy = DestType->castAs<ComplexType>()->getElementType();
381+
382+
mlir::cir::CastKind CastOpKind;
383+
if (SrcElemTy->isFloatingType() && DestElemTy->isFloatingType())
384+
CastOpKind = mlir::cir::CastKind::float_complex;
385+
else if (SrcElemTy->isFloatingType() && DestElemTy->isIntegerType())
386+
CastOpKind = mlir::cir::CastKind::float_complex_to_int_complex;
387+
else if (SrcElemTy->isIntegerType() && DestElemTy->isFloatingType())
388+
CastOpKind = mlir::cir::CastKind::int_complex_to_float_complex;
389+
else if (SrcElemTy->isIntegerType() && DestElemTy->isIntegerType())
390+
CastOpKind = mlir::cir::CastKind::int_complex;
391+
else
392+
llvm_unreachable("unexpected src type or dest type");
393+
394+
return Builder.createCast(CGF.getLoc(Loc), CastOpKind, Val,
395+
CGF.ConvertType(DestType));
382396
}
383397

384398
mlir::Value ComplexExprEmitter::buildScalarToComplexCast(mlir::Value Val,
385399
QualType SrcType,
386400
QualType DestType,
387401
SourceLocation Loc) {
388-
llvm_unreachable("complex cast is NYI");
402+
mlir::cir::CastKind CastOpKind;
403+
if (SrcType->isFloatingType())
404+
CastOpKind = mlir::cir::CastKind::float_to_complex;
405+
else if (SrcType->isIntegerType())
406+
CastOpKind = mlir::cir::CastKind::int_to_complex;
407+
else
408+
llvm_unreachable("unexpected src type");
409+
410+
return Builder.createCast(CGF.getLoc(Loc), CastOpKind, Val,
411+
CGF.ConvertType(DestType));
389412
}
390413

391414
mlir::Value ComplexExprEmitter::buildCast(CastKind CK, Expr *Op,
@@ -467,14 +490,20 @@ mlir::Value ComplexExprEmitter::buildCast(CastKind CK, Expr *Op,
467490
llvm_unreachable("invalid cast kind for complex value");
468491

469492
case CK_FloatingRealToComplex:
470-
case CK_IntegralRealToComplex:
471-
llvm_unreachable("NYI");
493+
case CK_IntegralRealToComplex: {
494+
assert(!MissingFeatures::CGFPOptionsRAII());
495+
return buildScalarToComplexCast(CGF.buildScalarExpr(Op), Op->getType(),
496+
DestTy, Op->getExprLoc());
497+
}
472498

473499
case CK_FloatingComplexCast:
474500
case CK_FloatingComplexToIntegralComplex:
475501
case CK_IntegralComplexCast:
476-
case CK_IntegralComplexToFloatingComplex:
477-
llvm_unreachable("NYI");
502+
case CK_IntegralComplexToFloatingComplex: {
503+
assert(!MissingFeatures::CGFPOptionsRAII());
504+
return buildComplexToComplexCast(Visit(Op), Op->getType(), DestTy,
505+
Op->getExprLoc());
506+
}
478507
}
479508

480509
llvm_unreachable("unknown cast resulting in complex value");

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

+31-5
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
113113
return CGF.buildCheckedLValue(E, TCK);
114114
}
115115

116+
mlir::Value buildComplexToScalarConversion(mlir::Location Loc, mlir::Value V,
117+
CastKind Kind, QualType DestTy);
118+
116119
/// Emit a value that corresponds to null for the given type.
117120
mlir::Value buildNullValue(QualType Ty, mlir::Location loc);
118121

@@ -1797,13 +1800,13 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
17971800
case CK_MemberPointerToBoolean:
17981801
llvm_unreachable("NYI");
17991802
case CK_FloatingComplexToReal:
1800-
llvm_unreachable("NYI");
18011803
case CK_IntegralComplexToReal:
1802-
llvm_unreachable("NYI");
18031804
case CK_FloatingComplexToBoolean:
1804-
llvm_unreachable("NYI");
1805-
case CK_IntegralComplexToBoolean:
1806-
llvm_unreachable("NYI");
1805+
case CK_IntegralComplexToBoolean: {
1806+
mlir::Value V = CGF.buildComplexExpr(E);
1807+
return buildComplexToScalarConversion(CGF.getLoc(CE->getExprLoc()), V, Kind,
1808+
DestTy);
1809+
}
18071810
case CK_ZeroToOCLOpaqueType:
18081811
llvm_unreachable("NYI");
18091812
case CK_IntToOCLSampler:
@@ -2161,6 +2164,29 @@ LValue ScalarExprEmitter::buildCompoundAssignLValue(
21612164
return LHSLV;
21622165
}
21632166

2167+
mlir::Value ScalarExprEmitter::buildComplexToScalarConversion(
2168+
mlir::Location Loc, mlir::Value V, CastKind Kind, QualType DestTy) {
2169+
mlir::cir::CastKind CastOpKind;
2170+
switch (Kind) {
2171+
case CK_FloatingComplexToReal:
2172+
CastOpKind = mlir::cir::CastKind::float_complex_to_real;
2173+
break;
2174+
case CK_IntegralComplexToReal:
2175+
CastOpKind = mlir::cir::CastKind::int_complex_to_real;
2176+
break;
2177+
case CK_FloatingComplexToBoolean:
2178+
CastOpKind = mlir::cir::CastKind::float_complex_to_bool;
2179+
break;
2180+
case CK_IntegralComplexToBoolean:
2181+
CastOpKind = mlir::cir::CastKind::int_complex_to_bool;
2182+
break;
2183+
default:
2184+
llvm_unreachable("invalid complex-to-scalar cast kind");
2185+
}
2186+
2187+
return Builder.createCast(Loc, CastOpKind, V, CGF.ConvertType(DestTy));
2188+
}
2189+
21642190
mlir::Value ScalarExprEmitter::buildNullValue(QualType Ty, mlir::Location loc) {
21652191
return CGF.buildFromMemory(CGF.CGM.buildNullConstant(Ty, loc), Ty);
21662192
}

0 commit comments

Comments
 (0)