Skip to content

Commit 4ebe6f1

Browse files
Lancernlanza
authored andcommitted
[CIR][Transforms] Add folders for complex operations (#775)
This PR adds folders for `cir.complex.create`, `cir.complex.real`, and `cir.complex.imag`. This PR adds a new attribute `#cir.complex` that represents a constant complex value. Besides, the CIR dialect does not have a constant materializer yet; this PR adds it. Address #726 .
1 parent e7698a1 commit 4ebe6f1

File tree

10 files changed

+373
-106
lines changed

10 files changed

+373
-106
lines changed

clang/include/clang/CIR/Dialect/IR/CIRAttrs.td

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,40 @@ def FPAttr : CIR_Attr<"FP", "fp", [TypedAttrInterface]> {
274274
}];
275275
}
276276

277+
//===----------------------------------------------------------------------===//
278+
// ComplexAttr
279+
//===----------------------------------------------------------------------===//
280+
281+
def ComplexAttr : CIR_Attr<"Complex", "complex", [TypedAttrInterface]> {
282+
let summary = "An attribute that contains a constant complex value";
283+
let description = [{
284+
The `#cir.complex` attribute contains a constant value of complex number
285+
type. The `real` parameter gives the real part of the complex number and the
286+
`imag` parameter gives the imaginary part of the complex number.
287+
288+
The `real` and `imag` parameter must be either an IntAttr or an FPAttr that
289+
contains values of the same CIR type.
290+
}];
291+
292+
let parameters = (ins
293+
AttributeSelfTypeParameter<"", "mlir::cir::ComplexType">:$type,
294+
"mlir::TypedAttr":$real, "mlir::TypedAttr":$imag);
295+
296+
let builders = [
297+
AttrBuilderWithInferredContext<(ins "mlir::cir::ComplexType":$type,
298+
"mlir::TypedAttr":$real,
299+
"mlir::TypedAttr":$imag), [{
300+
return $_get(type.getContext(), type, real, imag);
301+
}]>,
302+
];
303+
304+
let genVerifyDecl = 1;
305+
306+
let assemblyFormat = [{
307+
`<` qualified($real) `,` qualified($imag) `>`
308+
}];
309+
}
310+
277311
//===----------------------------------------------------------------------===//
278312
// ConstPointerAttr
279313
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIRDialect.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ def CIR_Dialect : Dialect {
2727
let useDefaultAttributePrinterParser = 0;
2828
let useDefaultTypePrinterParser = 0;
2929

30+
let hasConstantMaterializer = 1;
31+
3032
let extraClassDeclaration = [{
3133

3234
// Names of CIR parameter attributes.

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1243,6 +1243,7 @@ def ComplexCreateOp : CIR_Op<"complex.create", [Pure, SameTypeOperands]> {
12431243
}];
12441244

12451245
let hasVerifier = 1;
1246+
let hasFolder = 1;
12461247
}
12471248

12481249
//===----------------------------------------------------------------------===//
@@ -1271,6 +1272,7 @@ def ComplexRealOp : CIR_Op<"complex.real", [Pure]> {
12711272
}];
12721273

12731274
let hasVerifier = 1;
1275+
let hasFolder = 1;
12741276
}
12751277

12761278
def ComplexImagOp : CIR_Op<"complex.imag", [Pure]> {
@@ -1295,6 +1297,7 @@ def ComplexImagOp : CIR_Op<"complex.imag", [Pure]> {
12951297
}];
12961298

12971299
let hasVerifier = 1;
1300+
let hasFolder = 1;
12981301
}
12991302

13001303
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/IR/CIRAttrs.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,26 @@ LogicalResult cir::FPAttr::verify(function_ref<InFlightDiagnostic()> emitError,
365365
return success();
366366
}
367367

368+
//===----------------------------------------------------------------------===//
369+
// ComplexAttr definitions
370+
//===----------------------------------------------------------------------===//
371+
372+
LogicalResult ComplexAttr::verify(function_ref<InFlightDiagnostic()> emitError,
373+
mlir::cir::ComplexType type,
374+
mlir::TypedAttr real, mlir::TypedAttr imag) {
375+
auto elemTy = type.getElementTy();
376+
if (real.getType() != elemTy) {
377+
emitError() << "type of the real part does not match the complex type";
378+
return failure();
379+
}
380+
if (imag.getType() != elemTy) {
381+
emitError() << "type of the imaginary part does not match the complex type";
382+
return failure();
383+
}
384+
385+
return success();
386+
}
387+
368388
//===----------------------------------------------------------------------===//
369389
// CmpThreeWayInfoAttr definitions
370390
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,14 @@ void cir::CIRDialect::initialize() {
124124
addInterfaces<CIROpAsmDialectInterface>();
125125
}
126126

127+
Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder,
128+
mlir::Attribute value,
129+
mlir::Type type,
130+
mlir::Location loc) {
131+
return builder.create<mlir::cir::ConstantOp>(
132+
loc, type, mlir::cast<mlir::TypedAttr>(value));
133+
}
134+
127135
//===----------------------------------------------------------------------===//
128136
// Helpers
129137
//===----------------------------------------------------------------------===//
@@ -344,7 +352,8 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
344352
return success();
345353
}
346354

347-
if (mlir::isa<mlir::cir::IntAttr, mlir::cir::FPAttr>(attrType)) {
355+
if (mlir::isa<mlir::cir::IntAttr, mlir::cir::FPAttr, mlir::cir::ComplexAttr>(
356+
attrType)) {
348357
auto at = cast<TypedAttr>(attrType);
349358
if (at.getType() != opType) {
350359
return op->emitOpError("result type (")
@@ -748,6 +757,26 @@ LogicalResult ComplexCreateOp::verify() {
748757
return success();
749758
}
750759

760+
OpFoldResult ComplexCreateOp::fold(FoldAdaptor adaptor) {
761+
auto real = adaptor.getReal();
762+
auto imag = adaptor.getImag();
763+
764+
if (!real || !imag)
765+
return nullptr;
766+
767+
// When both of real and imag are constants, we can fold the operation into an
768+
// `cir.const #cir.complex` operation.
769+
770+
auto realAttr = mlir::cast<mlir::TypedAttr>(real);
771+
auto imagAttr = mlir::cast<mlir::TypedAttr>(imag);
772+
assert(realAttr.getType() == imagAttr.getType() &&
773+
"real part and imag part should be of the same type");
774+
775+
auto complexTy =
776+
mlir::cir::ComplexType::get(getContext(), realAttr.getType());
777+
return mlir::cir::ComplexAttr::get(complexTy, realAttr, imagAttr);
778+
}
779+
751780
//===----------------------------------------------------------------------===//
752781
// ComplexRealOp and ComplexImagOp
753782
//===----------------------------------------------------------------------===//
@@ -760,6 +789,14 @@ LogicalResult ComplexRealOp::verify() {
760789
return success();
761790
}
762791

792+
OpFoldResult ComplexRealOp::fold(FoldAdaptor adaptor) {
793+
auto input =
794+
mlir::cast_if_present<mlir::cir::ComplexAttr>(adaptor.getOperand());
795+
if (input)
796+
return input.getReal();
797+
return nullptr;
798+
}
799+
763800
LogicalResult ComplexImagOp::verify() {
764801
if (getType() != getOperand().getType().getElementTy()) {
765802
emitOpError() << "cir.complex.imag result type does not match operand type";
@@ -768,6 +805,14 @@ LogicalResult ComplexImagOp::verify() {
768805
return success();
769806
}
770807

808+
OpFoldResult ComplexImagOp::fold(FoldAdaptor adaptor) {
809+
auto input =
810+
mlir::cast_if_present<mlir::cir::ComplexAttr>(adaptor.getOperand());
811+
if (input)
812+
return input.getImag();
813+
return nullptr;
814+
}
815+
771816
//===----------------------------------------------------------------------===//
772817
// ComplexRealPtrOp and ComplexImagPtrOp
773818
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@ void CIRSimplifyPass::runOnOperation() {
146146
getOperation()->walk([&](Operation *op) {
147147
// CastOp here is to perform a manual `fold` in
148148
// applyOpPatternsAndFold
149-
if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp, CastOp, TryOp, UnaryOp>(op))
149+
if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp, CastOp, TryOp, UnaryOp,
150+
ComplexCreateOp, ComplexRealOp, ComplexImagOp>(op))
150151
ops.push_back(op);
151152
});
152153

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,6 +1180,30 @@ class CIRConstantLowering
11801180
attr = rewriter.getFloatAttr(
11811181
typeConverter->convertType(op.getType()),
11821182
mlir::cast<mlir::cir::FPAttr>(op.getValue()).getValue());
1183+
} else if (auto complexTy =
1184+
mlir::dyn_cast<mlir::cir::ComplexType>(op.getType())) {
1185+
auto complexAttr = mlir::cast<mlir::cir::ComplexAttr>(op.getValue());
1186+
auto complexElemTy = complexTy.getElementTy();
1187+
auto complexElemLLVMTy = typeConverter->convertType(complexElemTy);
1188+
1189+
mlir::Attribute components[2];
1190+
if (mlir::isa<mlir::cir::IntType>(complexElemTy)) {
1191+
components[0] = rewriter.getIntegerAttr(
1192+
complexElemLLVMTy,
1193+
mlir::cast<mlir::cir::IntAttr>(complexAttr.getReal()).getValue());
1194+
components[1] = rewriter.getIntegerAttr(
1195+
complexElemLLVMTy,
1196+
mlir::cast<mlir::cir::IntAttr>(complexAttr.getImag()).getValue());
1197+
} else {
1198+
components[0] = rewriter.getFloatAttr(
1199+
complexElemLLVMTy,
1200+
mlir::cast<mlir::cir::FPAttr>(complexAttr.getReal()).getValue());
1201+
components[1] = rewriter.getFloatAttr(
1202+
complexElemLLVMTy,
1203+
mlir::cast<mlir::cir::FPAttr>(complexAttr.getImag()).getValue());
1204+
}
1205+
1206+
attr = rewriter.getArrayAttr(components);
11831207
} else if (mlir::isa<mlir::cir::PointerType>(op.getType())) {
11841208
// Optimize with dedicated LLVM op for null pointers.
11851209
if (mlir::isa<mlir::cir::ConstPtrAttr>(op.getValue())) {

0 commit comments

Comments
 (0)