Skip to content

Commit 8acb296

Browse files
committed
[CIR][Transforms] Add folders for complex operations
This patch adds folderss for `cir.complex.create`, `cir.complex.real`, and `cir.complex.imag`. This patch adds a new attribute `#cir.complex` that represents a constant complex value. Besides, the CIR dialect does not have a constant materializer yet; this patch adds it.
1 parent e136285 commit 8acb296

File tree

10 files changed

+373
-106
lines changed

10 files changed

+373
-106
lines changed

clang/include/clang/CIR/Dialect/IR/CIRAttrs.td

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,40 @@ def FPAttr : CIR_Attr<"FP", "fp", [TypedAttrInterface]> {
273273
}];
274274
}
275275

276+
//===----------------------------------------------------------------------===//
277+
// ComplexAttr
278+
//===----------------------------------------------------------------------===//
279+
280+
def ComplexAttr : CIR_Attr<"Complex", "complex", [TypedAttrInterface]> {
281+
let summary = "An attribute that contains a constant complex value";
282+
let description = [{
283+
The `#cir.complex` attribute contains a constant value of complex number
284+
type. The `real` parameter gives the real part of the complex number and the
285+
`imag` parameter gives the imaginary part of the complex number.
286+
287+
The `real` and `imag` parameter must be either an IntAttr or an FPAttr that
288+
contains values of the same CIR type.
289+
}];
290+
291+
let parameters = (ins
292+
AttributeSelfTypeParameter<"", "mlir::cir::ComplexType">:$type,
293+
"mlir::TypedAttr":$real, "mlir::TypedAttr":$imag);
294+
295+
let builders = [
296+
AttrBuilderWithInferredContext<(ins "mlir::cir::ComplexType":$type,
297+
"mlir::TypedAttr":$real,
298+
"mlir::TypedAttr":$imag), [{
299+
return $_get(type.getContext(), type, real, imag);
300+
}]>,
301+
];
302+
303+
let genVerifyDecl = 1;
304+
305+
let assemblyFormat = [{
306+
`<` qualified($real) `,` qualified($imag) `>`
307+
}];
308+
}
309+
276310
//===----------------------------------------------------------------------===//
277311
// ConstPointerAttr
278312
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIRDialect.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ def CIR_Dialect : Dialect {
2727
let useDefaultAttributePrinterParser = 0;
2828
let useDefaultTypePrinterParser = 0;
2929

30+
let hasConstantMaterializer = 1;
31+
3032
let extraClassDeclaration = [{
3133

3234
// Names of CIR parameter attributes.

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,6 +1209,7 @@ def ComplexCreateOp : CIR_Op<"complex.create", [Pure, SameTypeOperands]> {
12091209
}];
12101210

12111211
let hasVerifier = 1;
1212+
let hasFolder = 1;
12121213
}
12131214

12141215
//===----------------------------------------------------------------------===//
@@ -1237,6 +1238,7 @@ def ComplexRealOp : CIR_Op<"complex.real", [Pure]> {
12371238
}];
12381239

12391240
let hasVerifier = 1;
1241+
let hasFolder = 1;
12401242
}
12411243

12421244
def ComplexImagOp : CIR_Op<"complex.imag", [Pure]> {
@@ -1261,6 +1263,7 @@ def ComplexImagOp : CIR_Op<"complex.imag", [Pure]> {
12611263
}];
12621264

12631265
let hasVerifier = 1;
1266+
let hasFolder = 1;
12641267
}
12651268

12661269
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/IR/CIRAttrs.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,26 @@ LogicalResult cir::FPAttr::verify(function_ref<InFlightDiagnostic()> emitError,
362362
return success();
363363
}
364364

365+
//===----------------------------------------------------------------------===//
366+
// ComplexAttr definitions
367+
//===----------------------------------------------------------------------===//
368+
369+
LogicalResult ComplexAttr::verify(function_ref<InFlightDiagnostic()> emitError,
370+
mlir::cir::ComplexType type,
371+
mlir::TypedAttr real, mlir::TypedAttr imag) {
372+
auto elemTy = type.getElementTy();
373+
if (real.getType() != elemTy) {
374+
emitError() << "type of the real part does not match the complex type";
375+
return failure();
376+
}
377+
if (imag.getType() != elemTy) {
378+
emitError() << "type of the imaginary part does not match the complex type";
379+
return failure();
380+
}
381+
382+
return success();
383+
}
384+
365385
//===----------------------------------------------------------------------===//
366386
// CmpThreeWayInfoAttr definitions
367387
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,14 @@ void cir::CIRDialect::initialize() {
124124
addInterfaces<CIROpAsmDialectInterface>();
125125
}
126126

127+
Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder,
128+
mlir::Attribute value,
129+
mlir::Type type,
130+
mlir::Location loc) {
131+
return builder.create<mlir::cir::ConstantOp>(
132+
loc, type, mlir::cast<mlir::TypedAttr>(value));
133+
}
134+
127135
//===----------------------------------------------------------------------===//
128136
// Helpers
129137
//===----------------------------------------------------------------------===//
@@ -344,7 +352,8 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
344352
return success();
345353
}
346354

347-
if (mlir::isa<mlir::cir::IntAttr, mlir::cir::FPAttr>(attrType)) {
355+
if (mlir::isa<mlir::cir::IntAttr, mlir::cir::FPAttr, mlir::cir::ComplexAttr>(
356+
attrType)) {
348357
auto at = cast<TypedAttr>(attrType);
349358
if (at.getType() != opType) {
350359
return op->emitOpError("result type (")
@@ -578,6 +587,26 @@ LogicalResult ComplexCreateOp::verify() {
578587
return success();
579588
}
580589

590+
OpFoldResult ComplexCreateOp::fold(FoldAdaptor adaptor) {
591+
auto real = adaptor.getReal();
592+
auto imag = adaptor.getImag();
593+
594+
if (!real || !imag)
595+
return nullptr;
596+
597+
// When both of real and imag are constants, we can fold the operation into an
598+
// `cir.const #cir.complex` operation.
599+
600+
auto realAttr = mlir::cast<mlir::TypedAttr>(real);
601+
auto imagAttr = mlir::cast<mlir::TypedAttr>(imag);
602+
assert(realAttr.getType() == imagAttr.getType() &&
603+
"real part and imag part should be of the same type");
604+
605+
auto complexTy =
606+
mlir::cir::ComplexType::get(getContext(), realAttr.getType());
607+
return mlir::cir::ComplexAttr::get(complexTy, realAttr, imagAttr);
608+
}
609+
581610
//===----------------------------------------------------------------------===//
582611
// ComplexRealOp and ComplexImagOp
583612
//===----------------------------------------------------------------------===//
@@ -590,6 +619,14 @@ LogicalResult ComplexRealOp::verify() {
590619
return success();
591620
}
592621

622+
OpFoldResult ComplexRealOp::fold(FoldAdaptor adaptor) {
623+
auto input =
624+
mlir::cast_if_present<mlir::cir::ComplexAttr>(adaptor.getOperand());
625+
if (input)
626+
return input.getReal();
627+
return nullptr;
628+
}
629+
593630
LogicalResult ComplexImagOp::verify() {
594631
if (getType() != getOperand().getType().getElementTy()) {
595632
emitOpError() << "cir.complex.imag result type does not match operand type";
@@ -598,6 +635,14 @@ LogicalResult ComplexImagOp::verify() {
598635
return success();
599636
}
600637

638+
OpFoldResult ComplexImagOp::fold(FoldAdaptor adaptor) {
639+
auto input =
640+
mlir::cast_if_present<mlir::cir::ComplexAttr>(adaptor.getOperand());
641+
if (input)
642+
return input.getImag();
643+
return nullptr;
644+
}
645+
601646
//===----------------------------------------------------------------------===//
602647
// ComplexRealPtrOp and ComplexImagPtrOp
603648
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/Transforms/MergeCleanups.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@ void MergeCleanupsPass::runOnOperation() {
146146
getOperation()->walk([&](Operation *op) {
147147
// CastOp here is to perform a manual `fold` in
148148
// applyOpPatternsAndFold
149-
if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp, CastOp, TryOp>(op))
149+
if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp, CastOp, TryOp, ComplexCreateOp,
150+
ComplexRealOp, ComplexImagOp>(op))
150151
ops.push_back(op);
151152
});
152153

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,6 +1156,30 @@ class CIRConstantLowering
11561156
attr = rewriter.getFloatAttr(
11571157
typeConverter->convertType(op.getType()),
11581158
mlir::cast<mlir::cir::FPAttr>(op.getValue()).getValue());
1159+
} else if (auto complexTy =
1160+
mlir::dyn_cast<mlir::cir::ComplexType>(op.getType())) {
1161+
auto complexAttr = mlir::cast<mlir::cir::ComplexAttr>(op.getValue());
1162+
auto complexElemTy = complexTy.getElementTy();
1163+
auto complexElemLLVMTy = typeConverter->convertType(complexElemTy);
1164+
1165+
mlir::Attribute components[2];
1166+
if (mlir::isa<mlir::cir::IntType>(complexElemTy)) {
1167+
components[0] = rewriter.getIntegerAttr(
1168+
complexElemLLVMTy,
1169+
mlir::cast<mlir::cir::IntAttr>(complexAttr.getReal()).getValue());
1170+
components[1] = rewriter.getIntegerAttr(
1171+
complexElemLLVMTy,
1172+
mlir::cast<mlir::cir::IntAttr>(complexAttr.getImag()).getValue());
1173+
} else {
1174+
components[0] = rewriter.getFloatAttr(
1175+
complexElemLLVMTy,
1176+
mlir::cast<mlir::cir::FPAttr>(complexAttr.getReal()).getValue());
1177+
components[1] = rewriter.getFloatAttr(
1178+
complexElemLLVMTy,
1179+
mlir::cast<mlir::cir::FPAttr>(complexAttr.getImag()).getValue());
1180+
}
1181+
1182+
attr = rewriter.getArrayAttr(components);
11591183
} else if (mlir::isa<mlir::cir::PointerType>(op.getType())) {
11601184
// Optimize with dedicated LLVM op for null pointers.
11611185
if (mlir::isa<mlir::cir::ConstPtrAttr>(op.getValue())) {

0 commit comments

Comments
 (0)