Skip to content

Commit 6abac65

Browse files
committed
[CIR][Transforms] Add folders for complex operations
This patch adds folderss for `cir.complex.create`, `cir.complex.real`, and `cir.complex.imag`. This patch adds a new attribute `#cir.complex` that represents a constant complex value. Besides, the CIR dialect does not have a constant materializer yet; this patch adds it.
1 parent e136285 commit 6abac65

File tree

6 files changed

+149
-1
lines changed

6 files changed

+149
-1
lines changed

clang/include/clang/CIR/Dialect/IR/CIRAttrs.td

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,40 @@ def FPAttr : CIR_Attr<"FP", "fp", [TypedAttrInterface]> {
273273
}];
274274
}
275275

276+
//===----------------------------------------------------------------------===//
277+
// ComplexAttr
278+
//===----------------------------------------------------------------------===//
279+
280+
def ComplexAttr : CIR_Attr<"Complex", "complex", [TypedAttrInterface]> {
281+
let summary = "An attribute that contains a constant complex value";
282+
let description = [{
283+
The `#cir.complex` attribute contains a constant value of complex number
284+
type. The `real` parameter gives the real part of the complex number and the
285+
`imag` parameter gives the imaginary part of the complex number.
286+
287+
The `real` and `imag` parameter must be either an IntAttr or an FPAttr that
288+
contains values of the same CIR type.
289+
}];
290+
291+
let parameters = (ins
292+
AttributeSelfTypeParameter<"", "mlir::cir::ComplexType">:$type,
293+
"mlir::TypedAttr":$real, "mlir::TypedAttr":$imag);
294+
295+
let builders = [
296+
AttrBuilderWithInferredContext<(ins "mlir::cir::ComplexType":$type,
297+
"mlir::TypedAttr":$real,
298+
"mlir::TypedAttr":$imag), [{
299+
return $_get(type.getContext(), type, real, imag);
300+
}]>,
301+
];
302+
303+
let genVerifyDecl = 1;
304+
305+
let assemblyFormat = [{
306+
`<` qualified($real) `,` qualified($imag) `>`
307+
}];
308+
}
309+
276310
//===----------------------------------------------------------------------===//
277311
// ConstPointerAttr
278312
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIRDialect.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ def CIR_Dialect : Dialect {
2727
let useDefaultAttributePrinterParser = 0;
2828
let useDefaultTypePrinterParser = 0;
2929

30+
let hasConstantMaterializer = 1;
31+
3032
let extraClassDeclaration = [{
3133

3234
// Names of CIR parameter attributes.

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,6 +1209,7 @@ def ComplexCreateOp : CIR_Op<"complex.create", [Pure, SameTypeOperands]> {
12091209
}];
12101210

12111211
let hasVerifier = 1;
1212+
let hasFolder = 1;
12121213
}
12131214

12141215
//===----------------------------------------------------------------------===//
@@ -1237,6 +1238,7 @@ def ComplexRealOp : CIR_Op<"complex.real", [Pure]> {
12371238
}];
12381239

12391240
let hasVerifier = 1;
1241+
let hasFolder = 1;
12401242
}
12411243

12421244
def ComplexImagOp : CIR_Op<"complex.imag", [Pure]> {
@@ -1261,6 +1263,7 @@ def ComplexImagOp : CIR_Op<"complex.imag", [Pure]> {
12611263
}];
12621264

12631265
let hasVerifier = 1;
1266+
let hasFolder = 1;
12641267
}
12651268

12661269
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/IR/CIRAttrs.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,26 @@ LogicalResult cir::FPAttr::verify(function_ref<InFlightDiagnostic()> emitError,
362362
return success();
363363
}
364364

365+
//===----------------------------------------------------------------------===//
366+
// ComplexAttr definitions
367+
//===----------------------------------------------------------------------===//
368+
369+
LogicalResult ComplexAttr::verify(function_ref<InFlightDiagnostic()> emitError,
370+
mlir::cir::ComplexType type,
371+
mlir::TypedAttr real, mlir::TypedAttr imag) {
372+
auto elemTy = type.getElementTy();
373+
if (real.getType() != elemTy) {
374+
emitError() << "type of the real part does not match the complex type";
375+
return failure();
376+
}
377+
if (imag.getType() != elemTy) {
378+
emitError() << "type of the imaginary part does not match the complex type";
379+
return failure();
380+
}
381+
382+
return success();
383+
}
384+
365385
//===----------------------------------------------------------------------===//
366386
// CmpThreeWayInfoAttr definitions
367387
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,14 @@ void cir::CIRDialect::initialize() {
124124
addInterfaces<CIROpAsmDialectInterface>();
125125
}
126126

127+
Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder,
128+
mlir::Attribute value,
129+
mlir::Type type,
130+
mlir::Location loc) {
131+
return builder.create<mlir::cir::ConstantOp>(
132+
loc, type, mlir::cast<mlir::TypedAttr>(value));
133+
}
134+
127135
//===----------------------------------------------------------------------===//
128136
// Helpers
129137
//===----------------------------------------------------------------------===//
@@ -344,7 +352,8 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
344352
return success();
345353
}
346354

347-
if (mlir::isa<mlir::cir::IntAttr, mlir::cir::FPAttr>(attrType)) {
355+
if (mlir::isa<mlir::cir::IntAttr, mlir::cir::FPAttr, mlir::cir::ComplexAttr>(
356+
attrType)) {
348357
auto at = cast<TypedAttr>(attrType);
349358
if (at.getType() != opType) {
350359
return op->emitOpError("result type (")
@@ -578,6 +587,26 @@ LogicalResult ComplexCreateOp::verify() {
578587
return success();
579588
}
580589

590+
OpFoldResult ComplexCreateOp::fold(FoldAdaptor adaptor) {
591+
auto real = adaptor.getReal();
592+
auto imag = adaptor.getImag();
593+
594+
if (!real || !imag)
595+
return nullptr;
596+
597+
// When both of real and imag are constants, we can fold the operation into an
598+
// `cir.const #cir.complex` operation.
599+
600+
auto realAttr = mlir::cast<mlir::TypedAttr>(real);
601+
auto imagAttr = mlir::cast<mlir::TypedAttr>(imag);
602+
assert(realAttr.getType() == imagAttr.getType() &&
603+
"real part and imag part should be of the same type");
604+
605+
auto complexTy =
606+
mlir::cir::ComplexType::get(getContext(), realAttr.getType());
607+
return mlir::cir::ComplexAttr::get(complexTy, realAttr, imagAttr);
608+
}
609+
581610
//===----------------------------------------------------------------------===//
582611
// ComplexRealOp and ComplexImagOp
583612
//===----------------------------------------------------------------------===//
@@ -590,6 +619,14 @@ LogicalResult ComplexRealOp::verify() {
590619
return success();
591620
}
592621

622+
OpFoldResult ComplexRealOp::fold(FoldAdaptor adaptor) {
623+
auto input =
624+
mlir::cast_if_present<mlir::cir::ComplexAttr>(adaptor.getOperand());
625+
if (input)
626+
return input.getReal();
627+
return nullptr;
628+
}
629+
593630
LogicalResult ComplexImagOp::verify() {
594631
if (getType() != getOperand().getType().getElementTy()) {
595632
emitOpError() << "cir.complex.imag result type does not match operand type";
@@ -598,6 +635,14 @@ LogicalResult ComplexImagOp::verify() {
598635
return success();
599636
}
600637

638+
OpFoldResult ComplexImagOp::fold(FoldAdaptor adaptor) {
639+
auto input =
640+
mlir::cast_if_present<mlir::cir::ComplexAttr>(adaptor.getOperand());
641+
if (input)
642+
return input.getImag();
643+
return nullptr;
644+
}
645+
601646
//===----------------------------------------------------------------------===//
602647
// ComplexRealPtrOp and ComplexImagPtrOp
603648
//===----------------------------------------------------------------------===//
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// RUN: cir-opt --canonicalize -o %t.cir %s
2+
// RUN: FileCheck --input-file %t.cir %s
3+
4+
!s32i = !cir.int<s, 32>
5+
6+
module {
7+
cir.func @complex_create_fold() -> !cir.complex<!s32i> {
8+
%0 = cir.const #cir.int<1> : !s32i
9+
%1 = cir.const #cir.int<2> : !s32i
10+
%2 = cir.complex.create %0, %1 : !s32i -> !cir.complex<!s32i>
11+
cir.return %2 : !cir.complex<!s32i>
12+
}
13+
14+
// CHECK-LABEL: cir.func @complex_create_fold() -> !cir.complex<!s32i> {
15+
// CHECK-NEXT: %[[#A:]] = cir.const #cir.complex<#cir.int<1> : !s32i, #cir.int<2> : !s32i> : !cir.complex<!s32i>
16+
// CHECK-NEXT: cir.return %[[#A]] : !cir.complex<!s32i>
17+
// CHECK-NEXT: }
18+
19+
cir.func @fold_complex_real() -> !s32i {
20+
%0 = cir.const #cir.int<1> : !s32i
21+
%1 = cir.const #cir.int<2> : !s32i
22+
%2 = cir.complex.create %0, %1 : !s32i -> !cir.complex<!s32i>
23+
%3 = cir.complex.real %2 : !cir.complex<!s32i> -> !s32i
24+
cir.return %3 : !s32i
25+
}
26+
27+
// CHECK-LABEL: cir.func @fold_complex_real() -> !s32i {
28+
// CHECK-NEXT: %[[#A:]] = cir.const #cir.int<1> : !s32i
29+
// CHECK-NEXT: cir.return %[[#A]] : !s32i
30+
// CHECK-NEXT: }
31+
32+
cir.func @fold_complex_imag() -> !s32i {
33+
%0 = cir.const #cir.int<1> : !s32i
34+
%1 = cir.const #cir.int<2> : !s32i
35+
%2 = cir.complex.create %0, %1 : !s32i -> !cir.complex<!s32i>
36+
%3 = cir.complex.imag %2 : !cir.complex<!s32i> -> !s32i
37+
cir.return %3 : !s32i
38+
}
39+
40+
// CHECK-LABEL: cir.func @fold_complex_imag() -> !s32i {
41+
// CHECK-NEXT: %[[#A:]] = cir.const #cir.int<2> : !s32i
42+
// CHECK-NEXT: cir.return %[[#A]] : !s32i
43+
// CHECK-NEXT: }
44+
}

0 commit comments

Comments
 (0)