Skip to content

Commit 5711854

Browse files
committed
[CIR][LibOpt] Add a first transformation: std::find to memchr
Inspired by similar work in libc++, pointed to me by Louis Dionne and Nikolas Klauser. This is initial, very conservative and not generalized yet: works for `char`s within a specific version of `std::find`.
1 parent 1b4c9be commit 5711854

File tree

8 files changed

+225
-17
lines changed

8 files changed

+225
-17
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,13 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
4747
getAttr<mlir::cir::IntAttr>(typ, val));
4848
}
4949

50+
/// Build the CIR `void *` pointer type. Only the default address space (0) is
/// handled; any other value aborts through llvm_unreachable.
mlir::cir::PointerType getVoidPtrTy(unsigned AddrSpace = 0) {
  if (AddrSpace != 0)
    llvm_unreachable("address space is NYI");

  auto *ctx = getContext();
  auto voidTy = ::mlir::cir::VoidType::get(ctx);
  return ::mlir::cir::PointerType::get(ctx, voidTy);
}
56+
5057
mlir::Value createNot(mlir::Value value) {
5158
return create<mlir::cir::UnaryOp>(value.getLoc(), value.getType(),
5259
mlir::cir::UnaryOpKind::Not, value);

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -390,12 +390,6 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
390390
return mlir::cir::PointerType::get(getContext(), ty);
391391
}
392392

393-
mlir::cir::PointerType getVoidPtrTy(unsigned AddrSpace = 0) {
394-
if (AddrSpace)
395-
llvm_unreachable("address space is NYI");
396-
return typeCache.VoidPtrTy;
397-
}
398-
399393
/// Get a CIR anonymous struct type.
400394
mlir::cir::StructType
401395
getAnonStructTy(llvm::ArrayRef<mlir::Type> members, bool packed = false,

clang/lib/CIR/Dialect/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ add_clang_library(MLIRCIRTransforms
55
DropAST.cpp
66
IdiomRecognizer.cpp
77
LibOpt.cpp
8+
StdHelpers.cpp
89

910
DEPENDS
1011
MLIRCIRPassIncGen

clang/lib/CIR/Dialect/Transforms/IdiomRecognizer.cpp

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
#include "llvm/Support/ErrorHandling.h"
2525
#include "llvm/Support/Path.h"
2626

27+
#include "StdHelpers.h"
28+
2729
using cir::CIRBaseBuilderTy;
2830
using namespace mlir;
2931
using namespace mlir::cir;
@@ -120,20 +122,10 @@ static bool isIteratorLikeType(mlir::Type t) {
120122
}
121123

122124
static bool isIteratorInStdContainter(mlir::Type t) {
123-
auto sTy = t.dyn_cast<StructType>();
124-
if (!sTy)
125-
return false;
126-
auto recordDecl = sTy.getAst();
127-
if (!recordDecl.isInStdNamespace())
128-
return false;
129-
130125
// TODO: only std::array supported for now, generalize and
131126
// use tablegen. CallDescription.cpp in the static analyzer
132127
// could be a good inspiration source too.
133-
if (recordDecl.getName().compare("array") != 0)
134-
return false;
135-
136-
return true;
128+
return isStdArrayType(t);
137129
}
138130

139131
void IdiomRecognizerPass::raiseIteratorBeginEnd(CallOp call) {

clang/lib/CIR/Dialect/Transforms/LibOpt.cpp

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
#include "llvm/Support/ErrorHandling.h"
2525
#include "llvm/Support/Path.h"
2626

27+
#include "StdHelpers.h"
28+
2729
using cir::CIRBaseBuilderTy;
2830
using namespace mlir;
2931
using namespace mlir::cir;
@@ -33,6 +35,7 @@ namespace {
3335
struct LibOptPass : public LibOptBase<LibOptPass> {
3436
LibOptPass() = default;
3537
void runOnOperation() override;
38+
void xformStdFindIntoMemchr(StdFindOp findOp);
3639

3740
// Handle pass options
3841
struct Options {
@@ -83,12 +86,127 @@ struct LibOptPass : public LibOptBase<LibOptPass> {
8386
};
8487
} // namespace
8588

89+
/// Returns true when `t` is a sequence container this pass can handle; for now
/// that is exactly what isStdArrayType accepts.
static bool isSequentialContainer(mlir::Type t) {
  // TODO: support the other sequence containers: vector, deque, list,
  // forward_list.
  const bool isArray = isStdArrayType(t);
  return isArray;
}
93+
94+
/// Reads the integral non-type template argument at index `pos` of the class
/// template specialization behind `t` and stores it in `size`.
///
/// Returns false and leaves `size` untouched when the declaration behind `t`
/// is not a class template specialization, `pos` is out of range, the argument
/// at `pos` is not integral, or its value is negative or too wide for
/// `unsigned`.
static bool getIntegralNTTPAt(StructType t, size_t pos, unsigned &size) {
  // Tolerate a null raw declaration instead of asserting inside dyn_cast.
  auto *specDecl = dyn_cast_or_null<clang::ClassTemplateSpecializationDecl>(
      t.getAst().getRawDecl());
  if (!specDecl)
    return false;

  const auto &templArgs = specDecl->getTemplateArgs();
  if (pos >= templArgs.size())
    return false;

  const auto &templArg = templArgs[pos];
  if (templArg.getKind() != clang::TemplateArgument::Integral)
    return false;

  // Refuse values that cannot be represented in `unsigned` rather than
  // silently truncating them on assignment.
  const auto value = templArg.getAsIntegral();
  if (value.isNegative() || value.getActiveBits() > sizeof(unsigned) * 8)
    return false;

  size = static_cast<unsigned>(value.getZExtValue());
  return true;
}
111+
112+
/// Reports whether container type `t` carries its element count as a template
/// argument and, if so, writes that count to `size`. Only std::array is
/// recognized so far.
static bool containerHasStaticSize(StructType t, unsigned &size) {
  // TODO: add others.
  if (isStdArrayType(t)) {
    // std::array<T, size>: "size" is the template argument at index 1.
    constexpr unsigned kArraySizeArgIndex = 1;
    return getIntegralNTTPAt(t, kArraySizeArgIndex, size);
  }
  return false;
}
121+
122+
/// Rewrites a `std::find` over a whole, statically sized container of
/// byte-wide integers into a `memchr` call.
///
/// The rewrite only fires when all of the following hold; otherwise `findOp`
/// is left untouched:
///  - operands 0 and 1 are produced by an IterBeginOp and an IterEndOp applied
///    to the same container value,
///  - the container is a std::array whose size template argument is integral,
///  - the iterator is a pointer to an 8-bit CIR integer, and
///  - operand 2 is a pointer to that same integer type.
void LibOptPass::xformStdFindIntoMemchr(StdFindOp findOp) {
  // First and second operands need to be iterators begin() and end().
  // getDefiningOp<T>() yields null both for block arguments and for producers
  // of another kind, whereas dyn_cast/isa on a null Operation* would assert.
  // TODO: look over cir.loads until we have a mem2reg + other passes
  // to help out here.
  auto iterBegin = findOp.getOperand(0).getDefiningOp<IterBeginOp>();
  auto iterEnd = findOp.getOperand(1).getDefiningOp<IterEndOp>();
  if (!iterBegin || !iterEnd)
    return;

  // memchr is handed the container's full static size, so both iterators must
  // originate from the very same container value.
  if (iterBegin->getOperand(0) != iterEnd->getOperand(0))
    return;

  // Look at the `this` pointer to retrieve container information.
  auto thisPtrTy = iterBegin.getOperand().getType().dyn_cast<PointerType>();
  if (!thisPtrTy)
    return;
  auto containerTy = dyn_cast<StructType>(thisPtrTy.getPointee());
  if (!containerTy)
    return;

  if (!isSequentialContainer(containerTy))
    return;

  unsigned staticSize = 0;
  if (!containerHasStaticSize(containerTy, staticSize))
    return;

  // Transformation:
  // - 1st arg: the data pointer
  //   - The iterator must be a pointer to a primitive type; bail out instead
  //     of asserting so release builds never dereference a null type.
  //   - Check IterBeginOp is char sized. TODO: add other types that map to
  //     char size.
  auto iterResTy = iterBegin.getResult().getType().dyn_cast<PointerType>();
  if (!iterResTy)
    return;
  auto underlyingDataTy = iterResTy.getPointee().dyn_cast<mlir::cir::IntType>();
  if (!underlyingDataTy || underlyingDataTy.getWidth() != 8)
    return;

  // - 2nd arg: the pattern
  //   - Check it's a pointer type.
  //   - Load the pattern from memory
  //   - cast it to `int`.
  mlir::Value patternAddr = findOp.getOperand(2);
  auto patternAddrTy = patternAddr.getType().dyn_cast<PointerType>();
  if (!patternAddrTy || patternAddrTy.getPointee() != underlyingDataTy)
    return;

  // - 3rd arg: the size
  //   - Create and pass a cir.const with NTTP value

  // All new operations are emitted right after the std::find being replaced.
  CIRBaseBuilderTy builder(getContext());
  builder.setInsertionPointAfter(findOp.getOperation());
  auto memchrOp0 = builder.createBitcast(
      iterBegin.getLoc(), iterBegin.getResult(), builder.getVoidPtrTy());

  // FIXME: get datalayout based "int" instead of fixed size 4.
  auto loadPattern = builder.create<LoadOp>(patternAddr.getLoc(),
                                            underlyingDataTy, patternAddr);
  auto memchrOp1 = builder.createIntCast(
      loadPattern, IntType::get(builder.getContext(), 32, true));

  // FIXME: get datalayout based "size_t" instead of fixed size 64.
  auto uInt64Ty = IntType::get(builder.getContext(), 64, false);
  auto memchrOp2 = builder.create<ConstantOp>(
      findOp.getLoc(), uInt64Ty, mlir::cir::IntAttr::get(uInt64Ty, staticSize));

  // Build memchr op:
  //  void *memchr(const void *s, int c, size_t n);
  auto memChr = builder.create<MemChrOp>(findOp.getLoc(), memchrOp0, memchrOp1,
                                         memchrOp2);

  // NOTE(review): memchr returns a null pointer when the byte is absent,
  // whereas std::find returns its `last` iterator. The cast below forwards the
  // memchr result unchanged, so users comparing against end() would see a
  // different not-found value -- confirm and map null to the end iterator.
  mlir::Operation *result =
      builder.createBitcast(findOp.getLoc(), memChr.getResult(), iterResTy)
          .getDefiningOp();

  findOp.replaceAllUsesWith(result);
  findOp.erase();
}
197+
86198
void LibOptPass::runOnOperation() {
87199
assert(astCtx && "Missing ASTContext, please construct with the right ctor");
88200
opts.parseOptions(*this);
89201
auto *op = getOperation();
90202
if (isa<::mlir::ModuleOp>(op))
91203
theModule = cast<::mlir::ModuleOp>(op);
204+
205+
SmallVector<StdFindOp> stdFindToTransform;
206+
op->walk([&](StdFindOp findOp) { stdFindToTransform.push_back(findOp); });
207+
208+
for (auto c : stdFindToTransform)
209+
xformStdFindIntoMemchr(c);
92210
}
93211

94212
std::unique_ptr<Pass> mlir::createLibOptPass() {
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
//===- StdHelpers.cpp - Implementation standard related helpers--*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "StdHelpers.h"
10+
11+
namespace mlir {
12+
namespace cir {
13+
14+
/// Returns true if `t` is a CIR struct type whose record declaration is named
/// "array" and is reported as living in the std namespace.
bool isStdArrayType(mlir::Type t) {
  // TODO: only std::array supported for now, generalize and
  // use tablegen. CallDescription.cpp in the static analyzer
  // could be a good inspiration source too.
  if (auto structTy = t.dyn_cast<StructType>()) {
    auto decl = structTy.getAst();
    return decl.isInStdNamespace() && decl.getName().compare("array") == 0;
  }
  return false;
}
30+
31+
} // namespace cir
32+
} // namespace mlir
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
//===- StdHelpers.h - Helpers for standard types/functions ------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "PassDetail.h"
10+
#include "mlir/IR/BuiltinAttributes.h"
11+
#include "mlir/IR/Region.h"
12+
#include "clang/AST/ASTContext.h"
13+
#include "clang/Basic/Module.h"
14+
#include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h"
15+
#include "clang/CIR/Dialect/IR/CIRDialect.h"
16+
#include "clang/CIR/Dialect/Passes.h"
17+
#include "clang/CIR/Interfaces/ASTAttrInterfaces.h"
18+
#include "llvm/ADT/SmallVector.h"
19+
#include "llvm/ADT/StringMap.h"
20+
#include "llvm/ADT/StringRef.h"
21+
#include "llvm/ADT/Twine.h"
22+
#include "llvm/Support/ErrorHandling.h"
23+
#include "llvm/Support/Path.h"
24+
25+
#ifndef DIALECT_CIR_TRANSFORMS_STDHELPERS_H_
26+
#define DIALECT_CIR_TRANSFORMS_STDHELPERS_H_
27+
28+
namespace mlir {
29+
namespace cir {
30+
31+
bool isStdArrayType(mlir::Type t);
32+
33+
} // namespace cir
34+
} // namespace mlir
35+
36+
#endif
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -I%S/../Inputs -clangir-disable-emit-cxx-default -fclangir-enable -fclangir-idiom-recognizer -fclangir-lib-opt -emit-cir %s -o %t.cir
2+
// RUN: FileCheck --input-file=%t.cir %s
3+
4+
#include "std-cxx.h"
5+
6+
// Looks `n` up in a fixed std::array<unsigned char, 9> holding 1..9 and
// returns 1 when it is present, 0 otherwise. The FileCheck directives below
// verify that the std::find call is rewritten into cir.libc.memchr.
int test_find(unsigned char n = 3)
{
  unsigned num_found = 0;
  // CHECK: %[[pattern_addr:.*]] = cir.alloca !u8i, cir.ptr <!u8i>, ["n"
  std::array<unsigned char, 9> v = {1, 2, 3, 4, 5, 6, 7, 8, 9};

  auto f = std::find(v.begin(), v.end(), n);
  // CHECK: %[[begin:.*]] = cir.call @_ZNSt5arrayIhLj9EE5beginEv
  // CHECK: cir.call @_ZNSt5arrayIhLj9EE3endEv
  // CHECK: %[[cast_to_void:.*]] = cir.cast(bitcast, %[[begin]] : !cir.ptr<!u8i>), !cir.ptr<!void>
  // CHECK: %[[load_pattern:.*]] = cir.load %[[pattern_addr]] : cir.ptr <!u8i>, !u8i
  // CHECK: %[[pattern:.*]] = cir.cast(integral, %[[load_pattern]] : !u8i), !s32i

  // CHECK-NOT: {{.*}} cir.call @_ZSt4findIPhhET_S1_S1_RKT0_(
  // CHECK: %[[array_size:.*]] = cir.const(#cir.int<9> : !u64i) : !u64i

  // CHECK: %[[result_cast:.*]] = cir.libc.memchr(%[[cast_to_void]], %[[pattern]], %[[array_size]])
  // CHECK: cir.cast(bitcast, %[[result_cast]] : !cir.ptr<!void>), !cir.ptr<!u8i>
  if (f != v.end())
    num_found++;

  return num_found;
}

0 commit comments

Comments
 (0)