Skip to content

[CIR][Transform] Add simplify transformation for select op #816

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1395,11 +1395,20 @@ void TernaryOp::build(OpBuilder &builder, OperationState &result, Value cond,

OpFoldResult SelectOp::fold(FoldAdaptor adaptor) {
auto condition = adaptor.getCondition();
if (!condition)
return nullptr;
if (condition) {
auto conditionValue = mlir::cast<mlir::cir::BoolAttr>(condition).getValue();
return conditionValue ? getTrueValue() : getFalseValue();
}

auto conditionValue = mlir::cast<mlir::cir::BoolAttr>(condition).getValue();
return conditionValue ? getTrueValue() : getFalseValue();
// cir.select if %0 then x else x -> x
auto trueValue = adaptor.getTrueValue();
auto falseValue = adaptor.getFalseValue();
if (trueValue && trueValue == falseValue)
return trueValue;
if (getTrueValue() == getFalseValue())
return getTrueValue();

return nullptr;
}

//===----------------------------------------------------------------------===//
Expand Down
42 changes: 41 additions & 1 deletion clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,45 @@ struct RemoveTrivialTry : public OpRewritePattern<TryOp> {
}
};

struct SimplifySelect : public OpRewritePattern<SelectOp> {
using OpRewritePattern<SelectOp>::OpRewritePattern;

LogicalResult matchAndRewrite(SelectOp op,
PatternRewriter &rewriter) const final {
mlir::Operation *trueValueOp = op.getTrueValue().getDefiningOp();
mlir::Operation *falseValueOp = op.getFalseValue().getDefiningOp();
auto trueValueConstOp =
mlir::dyn_cast_if_present<mlir::cir::ConstantOp>(trueValueOp);
auto falseValueConstOp =
mlir::dyn_cast_if_present<mlir::cir::ConstantOp>(falseValueOp);
if (!trueValueConstOp || !falseValueConstOp)
return mlir::failure();

auto trueValue =
mlir::dyn_cast<mlir::cir::BoolAttr>(trueValueConstOp.getValue());
auto falseValue =
mlir::dyn_cast<mlir::cir::BoolAttr>(falseValueConstOp.getValue());
if (!trueValue || !falseValue)
return mlir::failure();

// cir.select if %0 then #true else #false -> %0
if (trueValue.getValue() && !falseValue.getValue()) {
rewriter.replaceAllUsesWith(op, op.getCondition());
rewriter.eraseOp(op);
return mlir::success();
}

// cir.seleft if %0 then #false else #true -> cir.unary not %0
if (!trueValue.getValue() && falseValue.getValue()) {
rewriter.replaceOpWithNewOp<mlir::cir::UnaryOp>(
op, mlir::cir::UnaryOpKind::Not, op.getCondition());
return mlir::success();
}

return mlir::failure();
}
};

//===----------------------------------------------------------------------===//
// CIRSimplifyPass
//===----------------------------------------------------------------------===//
Expand All @@ -131,7 +170,8 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
RemoveRedundantBranches,
RemoveEmptyScope,
RemoveEmptySwitch,
RemoveTrivialTry
RemoveTrivialTry,
SimplifySelect
>(patterns.getContext());
// clang-format on
}
Expand Down
36 changes: 35 additions & 1 deletion clang/test/CIR/Transforms/select.cir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: cir-opt --canonicalize -o %t.cir %s
// RUN: cir-opt -cir-simplify -o %t.cir %s
// RUN: FileCheck --input-file=%t.cir %s

!s32i = !cir.int<s, 32>
Expand All @@ -23,4 +23,38 @@ module {
// CHECK: cir.func @fold_false(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i {
// CHECK-NEXT: cir.return %[[ARG1]] : !s32i
// CHECK-NEXT: }

cir.func @fold_to_const(%arg0 : !cir.bool) -> !s32i {
%0 = cir.const #cir.int<42> : !s32i
%1 = cir.select if %arg0 then %0 else %0 : (!cir.bool, !s32i, !s32i) -> !s32i
cir.return %1 : !s32i
}

// CHECK: cir.func @fold_to_const(%{{.+}}: !cir.bool) -> !s32i {
// CHECK-NEXT: %[[#A:]] = cir.const #cir.int<42> : !s32i
// CHECK-NEXT: cir.return %[[#A]] : !s32i
// CHECK-NEXT: }

cir.func @simplify_1(%arg0 : !cir.bool) -> !cir.bool {
%0 = cir.const #cir.bool<true> : !cir.bool
%1 = cir.const #cir.bool<false> : !cir.bool
%2 = cir.select if %arg0 then %0 else %1 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool
cir.return %2 : !cir.bool
}

// CHECK: cir.func @simplify_1(%[[ARG0:.+]]: !cir.bool) -> !cir.bool {
// CHECK-NEXT: cir.return %[[ARG0]] : !cir.bool
// CHECK-NEXT: }

cir.func @simplify_2(%arg0 : !cir.bool) -> !cir.bool {
%0 = cir.const #cir.bool<false> : !cir.bool
%1 = cir.const #cir.bool<true> : !cir.bool
%2 = cir.select if %arg0 then %0 else %1 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool
cir.return %2 : !cir.bool
}

// CHECK: cir.func @simplify_2(%[[ARG0:.+]]: !cir.bool) -> !cir.bool {
// CHECK-NEXT: %[[#A:]] = cir.unary(not, %[[ARG0]]) : !cir.bool, !cir.bool
// CHECK-NEXT: cir.return %[[#A]] : !cir.bool
// CHECK-NEXT: }
}