Skip to content

Commit b6e5f0d

Browse files
authored
[CIR] Introduce CIR simplification (#696)
This PR introduce cir simplification pass. The idea is to have a pass for the redundant operations removal/update. Right now two pattern implemented, both related to the redundant `bool` operations. First pattern removes redundant casts from `bool` to `int` and back that for some reasons appear in the code. Second pattern removes sequential unary not operations (`!!`) . For example, the code from the test is expanded from the simple check that is common for C code: ``` #define CHECK_PTR(ptr) \ do { \ if (__builtin_expect((!!((ptr) == 0)), 0))\ return -42; \ } while(0) ``` I mark this PR as a draft for the following reasons: 1) I have no idea if it's useful for the community 2) There is a test fail - unfortunately current pattern rewriter run DCE underneath the hood and looks like we can't workaround it. It's enough just to add an operation to the list - in this case `UnaryOp` - and call `applyOpPatternsAndFold`. I could rewrite a test a little in order to make everything non dead or implement a simple fix point algorithm for the particular task (I would do the former).
1 parent e288169 commit b6e5f0d

File tree

13 files changed

+214
-57
lines changed

13 files changed

+214
-57
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -991,6 +991,7 @@ def UnaryOp : CIR_Op<"unary", [Pure, SameOperandsAndResultType]> {
991991
}];
992992

993993
let hasVerifier = 1;
994+
let hasFolder = 1;
994995
}
995996

996997
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/Passes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ std::unique_ptr<Pass> createLifetimeCheckPass(ArrayRef<StringRef> remark,
2626
ArrayRef<StringRef> hist,
2727
unsigned hist_limit,
2828
clang::ASTContext *astCtx);
29-
std::unique_ptr<Pass> createMergeCleanupsPass();
29+
std::unique_ptr<Pass> createCIRSimplifyPass();
3030
std::unique_ptr<Pass> createDropASTPass();
3131
std::unique_ptr<Pass> createSCFPreparePass();
3232
std::unique_ptr<Pass> createLoweringPreparePass();

clang/include/clang/CIR/Dialect/Passes.td

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,24 @@
1111

1212
include "mlir/Pass/PassBase.td"
1313

14-
def MergeCleanups : Pass<"cir-merge-cleanups"> {
15-
let summary = "Remove unnecessary branches to cleanup blocks";
14+
def CIRSimplify : Pass<"cir-simplify"> {
15+
let summary = "Performs CIR simplification";
1616
let description = [{
17-
Canonicalize pass is too aggressive for CIR when the pipeline is
18-
used for C/C++ analysis. This pass runs some rewrites for scopes,
19-
merging some blocks and eliminating unnecessary control-flow.
17+
The pass rewrites CIR and removes some redundant operations.
18+
19+
For example, due to canonicalize pass is too aggressive for CIR when
20+
the pipeline is used for C/C++ analysis, this pass runs some rewrites
21+
for scopes, merging some blocks and eliminating unnecessary control-flow.
22+
23+
Also, the pass removes redundant and/or unneccessary cast and unary not
24+
operation e.g.
25+
```mlir
26+
%1 = cir.cast(bool_to_int, %0 : !cir.bool), !s32i
27+
%2 = cir.cast(int_to_bool, %1 : !s32i), !cir.bool
28+
```
29+
2030
}];
21-
let constructor = "mlir::createMergeCleanupsPass()";
31+
let constructor = "mlir::createCIRSimplifyPass()";
2232
let dependentDialects = ["cir::CIRDialect"];
2333
}
2434

clang/lib/CIR/CodeGen/CIRPasses.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
#include "mlir/Support/LogicalResult.h"
2020
#include "mlir/Transforms/Passes.h"
2121

22-
#include <iostream>
23-
2422
namespace cir {
2523
mlir::LogicalResult runCIRToCIRPasses(
2624
mlir::ModuleOp theModule, mlir::MLIRContext *mlirCtx,
@@ -32,7 +30,7 @@ mlir::LogicalResult runCIRToCIRPasses(
3230
bool enableMem2Reg) {
3331

3432
mlir::PassManager pm(mlirCtx);
35-
pm.addPass(mlir::createMergeCleanupsPass());
33+
pm.addPass(mlir::createCIRSimplifyPass());
3634

3735
// TODO(CIR): Make this actually propagate errors correctly. This is stubbed
3836
// in to get rebases going.

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 77 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -528,26 +528,86 @@ LogicalResult CastOp::verify() {
528528
llvm_unreachable("Unknown CastOp kind?");
529529
}
530530

531-
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
532-
if (getSrc().getType() != getResult().getType())
533-
return {};
534-
switch (getKind()) {
535-
case mlir::cir::CastKind::integral: {
536-
// TODO: for sign differences, it's possible in certain conditions to
537-
// create a new attribute that's capable of representing the source.
538-
SmallVector<mlir::OpFoldResult, 1> foldResults;
539-
auto foldOrder = getSrc().getDefiningOp()->fold(foldResults);
540-
if (foldOrder.succeeded() && foldResults[0].is<mlir::Attribute>())
541-
return foldResults[0].get<mlir::Attribute>();
542-
return {};
543-
}
544-
case mlir::cir::CastKind::bitcast:
545-
case mlir::cir::CastKind::address_space: {
546-
return getSrc();
531+
bool isIntOrBoolCast(mlir::cir::CastOp op) {
532+
auto kind = op.getKind();
533+
return kind == mlir::cir::CastKind::bool_to_int ||
534+
kind == mlir::cir::CastKind::int_to_bool ||
535+
kind == mlir::cir::CastKind::integral;
536+
}
537+
538+
Value tryFoldCastChain(CastOp op) {
539+
CastOp head = op, tail = op;
540+
541+
while(op) {
542+
if (!isIntOrBoolCast(op))
543+
break;
544+
head = op;
545+
op = dyn_cast_or_null<CastOp>(head.getSrc().getDefiningOp());
547546
}
548-
default:
547+
548+
if (head == tail)
549549
return {};
550+
551+
// if bool_to_int -> ... -> int_to_bool: take the bool
552+
// as we had it was before all casts
553+
if (head.getKind() == mlir::cir::CastKind::bool_to_int &&
554+
tail.getKind() == mlir::cir::CastKind::int_to_bool)
555+
return head.getSrc();
556+
557+
// if int_to_bool -> ... -> int_to_bool: take the result
558+
// of the first one, as no other casts (and ext casts as well)
559+
// don't change the first result
560+
if (head.getKind() == mlir::cir::CastKind::int_to_bool &&
561+
tail.getKind() == mlir::cir::CastKind::int_to_bool)
562+
return head.getResult();
563+
564+
return {};
565+
}
566+
567+
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
568+
if (getSrc().getType() == getResult().getType()) {
569+
switch (getKind()) {
570+
case mlir::cir::CastKind::integral: {
571+
// TODO: for sign differences, it's possible in certain conditions to
572+
// create a new attribute that's capable of representing the source.
573+
SmallVector<mlir::OpFoldResult, 1> foldResults;
574+
auto foldOrder = getSrc().getDefiningOp()->fold(foldResults);
575+
if (foldOrder.succeeded() && foldResults[0].is<mlir::Attribute>())
576+
return foldResults[0].get<mlir::Attribute>();
577+
return {};
578+
}
579+
case mlir::cir::CastKind::bitcast:
580+
case mlir::cir::CastKind::address_space: {
581+
return getSrc();
582+
}
583+
default:
584+
return {};
585+
}
550586
}
587+
return tryFoldCastChain(*this);
588+
}
589+
590+
static bool isBoolNot(mlir::cir::UnaryOp op) {
591+
return isa<BoolType>(op.getInput().getType()) &&
592+
op.getKind() == mlir::cir::UnaryOpKind::Not;
593+
}
594+
595+
/* This folder simplifies the sequential boolean not operations.
596+
For instance, the next two unary operations will be eliminated:
597+
598+
```mlir
599+
%1 = cir.unary(not, %0) : !cir.bool, !cir.bool
600+
%2 = cir.unary(not, %1) : !cir.bool, !cir.bool
601+
```
602+
603+
and the argument of the first one (%0) will be used instead. */
604+
OpFoldResult UnaryOp::fold(FoldAdaptor adaptor) {
605+
if (isBoolNot(*this))
606+
if (auto previous = dyn_cast_or_null<UnaryOp>(getInput().getDefiningOp()))
607+
if (isBoolNot(previous))
608+
return previous.getInput();
609+
610+
return {};
551611
}
552612

553613
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/Transforms/MergeCleanups.cpp renamed to clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- MergeCleanups.cpp - merge simple return/yield blocks ---------------===//
1+
//===- CIRSimplify.cpp - performs CIR simplification ----------------------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -108,11 +108,11 @@ struct RemoveTrivialTry : public OpRewritePattern<TryOp> {
108108
};
109109

110110
//===----------------------------------------------------------------------===//
111-
// MergeCleanupsPass
111+
// CIRSimplifyPass
112112
//===----------------------------------------------------------------------===//
113113

114-
struct MergeCleanupsPass : public MergeCleanupsBase<MergeCleanupsPass> {
115-
using MergeCleanupsBase::MergeCleanupsBase;
114+
struct CIRSimplifyPass : public CIRSimplifyBase<CIRSimplifyPass> {
115+
using CIRSimplifyBase::CIRSimplifyBase;
116116

117117
// The same operation rewriting done here could have been performed
118118
// by CanonicalizerPass (adding hasCanonicalizer for target Ops and
@@ -136,7 +136,7 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
136136
// clang-format on
137137
}
138138

139-
void MergeCleanupsPass::runOnOperation() {
139+
void CIRSimplifyPass::runOnOperation() {
140140
// Collect rewrite patterns.
141141
RewritePatternSet patterns(&getContext());
142142
populateMergeCleanupPatterns(patterns);
@@ -146,7 +146,7 @@ void MergeCleanupsPass::runOnOperation() {
146146
getOperation()->walk([&](Operation *op) {
147147
// CastOp here is to perform a manual `fold` in
148148
// applyOpPatternsAndFold
149-
if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp, CastOp, TryOp>(op))
149+
if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp, CastOp, TryOp, UnaryOp>(op))
150150
ops.push_back(op);
151151
});
152152

@@ -157,6 +157,6 @@ void MergeCleanupsPass::runOnOperation() {
157157

158158
} // namespace
159159

160-
std::unique_ptr<Pass> mlir::createMergeCleanupsPass() {
161-
return std::make_unique<MergeCleanupsPass>();
160+
std::unique_ptr<Pass> mlir::createCIRSimplifyPass() {
161+
return std::make_unique<CIRSimplifyPass>();
162162
}

clang/lib/CIR/Dialect/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ add_subdirectory(TargetLowering)
33
add_clang_library(MLIRCIRTransforms
44
LifetimeCheck.cpp
55
LoweringPrepare.cpp
6-
MergeCleanups.cpp
6+
CIRSimplify.cpp
77
DropAST.cpp
88
IdiomRecognizer.cpp
99
LibOpt.cpp

clang/test/CIR/CodeGen/unary.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -155,28 +155,28 @@ int *inc_p(int *i) {
155155

156156
void floats(float f) {
157157
// CHECK: cir.func @{{.+}}floats{{.+}}
158-
+f; // CHECK: %{{[0-9]+}} = cir.unary(plus, %{{[0-9]+}}) : !cir.float, !cir.float
159-
-f; // CHECK: %{{[0-9]+}} = cir.unary(minus, %{{[0-9]+}}) : !cir.float, !cir.float
158+
f = +f; // CHECK: %{{[0-9]+}} = cir.unary(plus, %{{[0-9]+}}) : !cir.float, !cir.float
159+
f = -f; // CHECK: %{{[0-9]+}} = cir.unary(minus, %{{[0-9]+}}) : !cir.float, !cir.float
160160
++f; // CHECK: = cir.unary(inc, %{{[0-9]+}}) : !cir.float, !cir.float
161161
--f; // CHECK: = cir.unary(dec, %{{[0-9]+}}) : !cir.float, !cir.float
162162
f++; // CHECK: = cir.unary(inc, %{{[0-9]+}}) : !cir.float, !cir.float
163163
f--; // CHECK: = cir.unary(dec, %{{[0-9]+}}) : !cir.float, !cir.float
164164

165-
!f;
165+
f = !f;
166166
// CHECK: %[[#F_BOOL:]] = cir.cast(float_to_bool, %{{[0-9]+}} : !cir.float), !cir.bool
167167
// CHECK: = cir.unary(not, %[[#F_BOOL]]) : !cir.bool, !cir.bool
168168
}
169169

170170
void doubles(double d) {
171171
// CHECK: cir.func @{{.+}}doubles{{.+}}
172-
+d; // CHECK: %{{[0-9]+}} = cir.unary(plus, %{{[0-9]+}}) : !cir.double, !cir.double
173-
-d; // CHECK: %{{[0-9]+}} = cir.unary(minus, %{{[0-9]+}}) : !cir.double, !cir.double
172+
d = +d; // CHECK: %{{[0-9]+}} = cir.unary(plus, %{{[0-9]+}}) : !cir.double, !cir.double
173+
d = -d; // CHECK: %{{[0-9]+}} = cir.unary(minus, %{{[0-9]+}}) : !cir.double, !cir.double
174174
++d; // CHECK: = cir.unary(inc, %{{[0-9]+}}) : !cir.double, !cir.double
175175
--d; // CHECK: = cir.unary(dec, %{{[0-9]+}}) : !cir.double, !cir.double
176176
d++; // CHECK: = cir.unary(inc, %{{[0-9]+}}) : !cir.double, !cir.double
177177
d--; // CHECK: = cir.unary(dec, %{{[0-9]+}}) : !cir.double, !cir.double
178178

179-
!d;
179+
d = !d;
180180
// CHECK: %[[#D_BOOL:]] = cir.cast(float_to_bool, %{{[0-9]+}} : !cir.double), !cir.bool
181181
// CHECK: = cir.unary(not, %[[#D_BOOL]]) : !cir.bool, !cir.bool
182182
}
@@ -185,7 +185,7 @@ void pointers(int *p) {
185185
// CHECK: cir.func @{{[^ ]+}}pointers
186186
// CHECK: %[[#P:]] = cir.alloca !cir.ptr<!s32i>, !cir.ptr<!cir.ptr<!s32i>>
187187

188-
+p;
188+
p = +p;
189189
// CHECK: cir.unary(plus, %{{.+}}) : !cir.ptr<!s32i>, !cir.ptr<!s32i>
190190

191191
++p;
@@ -205,28 +205,28 @@ void pointers(int *p) {
205205
// CHECK: %[[#RES:]] = cir.ptr_stride(%{{.+}} : !cir.ptr<!s32i>, %[[#DEC]] : !s32i), !cir.ptr<!s32i>
206206
// CHECK: cir.store %[[#RES]], %[[#P]] : !cir.ptr<!s32i>, !cir.ptr<!cir.ptr<!s32i>>
207207

208-
!p;
208+
bool p1 = !p;
209209
// %[[BOOLPTR:]] = cir.cast(ptr_to_bool, %15 : !cir.ptr<!s32i>), !cir.bool
210210
// cir.unary(not, %[[BOOLPTR]]) : !cir.bool, !cir.bool
211211
}
212212

213213
void chars(char c) {
214214
// CHECK: cir.func @{{.+}}chars{{.+}}
215215

216-
+c;
216+
int c1 = +c;
217217
// CHECK: %[[#PROMO:]] = cir.cast(integral, %{{.+}} : !s8i), !s32i
218218
// CHECK: cir.unary(plus, %[[#PROMO]]) : !s32i, !s32i
219-
-c;
219+
int c2 = -c;
220220
// CHECK: %[[#PROMO:]] = cir.cast(integral, %{{.+}} : !s8i), !s32i
221221
// CHECK: cir.unary(minus, %[[#PROMO]]) : !s32i, !s32i
222222

223223
// Chars can go through some integer promotion codegen paths even when not promoted.
224-
++c; // CHECK: cir.unary(inc, %7) : !s8i, !s8i
225-
--c; // CHECK: cir.unary(dec, %9) : !s8i, !s8i
226-
c++; // CHECK: cir.unary(inc, %11) : !s8i, !s8i
227-
c--; // CHECK: cir.unary(dec, %13) : !s8i, !s8i
224+
++c; // CHECK: cir.unary(inc, %10) : !s8i, !s8i
225+
--c; // CHECK: cir.unary(dec, %12) : !s8i, !s8i
226+
c++; // CHECK: cir.unary(inc, %14) : !s8i, !s8i
227+
c--; // CHECK: cir.unary(dec, %16) : !s8i, !s8i
228228

229-
!c;
229+
bool c3 = !c;
230230
// CHECK: %[[#C_BOOL:]] = cir.cast(int_to_bool, %{{[0-9]+}} : !s8i), !cir.bool
231231
// CHECK: cir.unary(not, %[[#C_BOOL]]) : !cir.bool, !cir.bool
232232
}

clang/test/CIR/Transforms/merge-cleanups.cir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: cir-opt %s -cir-merge-cleanups -o %t.out.cir
1+
// RUN: cir-opt %s -cir-simplify -o %t.out.cir
22
// RUN: FileCheck --input-file=%t.out.cir %s
33

44
#false = #cir.bool<false> : !cir.bool

clang/test/CIR/Transforms/simpl.c

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir -mmlir --mlir-print-ir-before=cir-simplify %s -o %t1.cir 2>&1 | FileCheck -check-prefix=BEFORE %s
2+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir -mmlir --mlir-print-ir-after=cir-simplify %s -o %t2.cir 2>&1 | FileCheck -check-prefix=AFTER %s
3+
4+
5+
#define CHECK_PTR(ptr) \
6+
do { \
7+
if (__builtin_expect((!!((ptr) == 0)), 0))\
8+
return -42; \
9+
} while(0)
10+
11+
int foo(int* ptr) {
12+
CHECK_PTR(ptr);
13+
14+
(*ptr)++;
15+
return 0;
16+
}
17+
18+
// BEFORE: cir.func {{.*@foo}}
19+
// BEFORE: [[X0:%.*]] = cir.load {{.*}} : !cir.ptr<!cir.ptr<!s32i>>, !cir.ptr<!s32i>
20+
// BEFORE: [[X1:%.*]] = cir.const #cir.ptr<null> : !cir.ptr<!s32i>
21+
// BEFORE: [[X2:%.*]] = cir.cmp(eq, [[X0]], [[X1]]) : !cir.ptr<!s32i>, !s32i
22+
// BEFORE: [[X3:%.*]] = cir.cast(int_to_bool, [[X2]] : !s32i), !cir.bool
23+
// BEFORE: [[X4:%.*]] = cir.unary(not, [[X3]]) : !cir.bool, !cir.bool
24+
// BEFORE: [[X5:%.*]] = cir.cast(bool_to_int, [[X4]] : !cir.bool), !s32i
25+
// BEFORE: [[X6:%.*]] = cir.cast(int_to_bool, [[X5]] : !s32i), !cir.bool
26+
// BEFORE: [[X7:%.*]] = cir.unary(not, [[X6]]) : !cir.bool, !cir.bool
27+
// BEFORE: [[X8:%.*]] = cir.cast(bool_to_int, [[X7]] : !cir.bool), !s32i
28+
// BEFORE: [[X9:%.*]] = cir.cast(integral, [[X8]] : !s32i), !s64i
29+
// BEFORE: [[X10:%.*]] = cir.const #cir.int<0> : !s32i
30+
// BEFORE: [[X11:%.*]] = cir.cast(integral, [[X10]] : !s32i), !s64i
31+
// BEFORE: [[X12:%.*]] = cir.cast(int_to_bool, [[X9]] : !s64i), !cir.bool
32+
// BEFORE: cir.if [[X12]]
33+
34+
// AFTER: [[X0:%.*]] = cir.load {{.*}} : !cir.ptr<!cir.ptr<!s32i>>, !cir.ptr<!s32i>
35+
// AFTER: [[X1:%.*]] = cir.const #cir.ptr<null> : !cir.ptr<!s32i>
36+
// AFTER: [[X2:%.*]] = cir.cmp(eq, [[X0]], [[X1]]) : !cir.ptr<!s32i>, !s32i
37+
// AFTER: [[X3:%.*]] = cir.cast(int_to_bool, [[X2]] : !s32i), !cir.bool
38+
// AFTER: cir.if [[X3]]

0 commit comments

Comments
 (0)