Skip to content

[SandboxIR] Implement GlobalIFunc #108622

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 88 additions & 2 deletions llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class DSOLocalEquivalent;
class ConstantTokenNone;
class GlobalValue;
class GlobalObject;
class GlobalIFunc;
class Context;
class Function;
class Instruction;
Expand Down Expand Up @@ -332,6 +333,7 @@ class Value {
friend class GlobalValue; // For `Val`.
friend class DSOLocalEquivalent; // For `Val`.
friend class GlobalObject; // For `Val`.
friend class GlobalIFunc; // For `Val`.

/// All values point to the context.
Context &Ctx;
Expand Down Expand Up @@ -1128,6 +1130,7 @@ class GlobalValue : public Constant {
friend class Context; // For constructor.

public:
using LinkageTypes = llvm::GlobalValue::LinkageTypes;
/// For isa/dyn_cast.
static bool classof(const sandboxir::Value *From) {
switch (From->getSubclassID()) {
Expand Down Expand Up @@ -1285,6 +1288,88 @@ class GlobalObject : public GlobalValue {
}
};

/// Provides API functions, like getIterator() and getReverseIterator() to
/// GlobalIFunc, Function, GlobalVariable and GlobalAlias. In LLVM IR these are
/// provided by ilist_node.
template <typename GlobalT, typename LLVMGlobalT, typename ParentT,
typename LLVMParentT>
class GlobalWithNodeAPI : public ParentT {
/// Helper for mapped_iterator.
struct LLVMGVToGV {
Context &Ctx;
LLVMGVToGV(Context &Ctx) : Ctx(Ctx) {}
GlobalT &operator()(LLVMGlobalT &LLVMGV) const;
};

public:
GlobalWithNodeAPI(Value::ClassID ID, LLVMParentT *C, Context &Ctx)
: ParentT(ID, C, Ctx) {}

// TODO: Missing getParent(). Should be added once Module is available.

using iterator = mapped_iterator<
decltype(static_cast<LLVMGlobalT *>(nullptr)->getIterator()), LLVMGVToGV>;
using reverse_iterator = mapped_iterator<
decltype(static_cast<LLVMGlobalT *>(nullptr)->getReverseIterator()),
LLVMGVToGV>;
iterator getIterator() const {
auto *LLVMGV = cast<LLVMGlobalT>(this->Val);
LLVMGVToGV ToGV(this->Ctx);
return map_iterator(LLVMGV->getIterator(), ToGV);
}
reverse_iterator getReverseIterator() const {
auto *LLVMGV = cast<LLVMGlobalT>(this->Val);
LLVMGVToGV ToGV(this->Ctx);
return map_iterator(LLVMGV->getReverseIterator(), ToGV);
}
};

class GlobalIFunc final
: public GlobalWithNodeAPI<GlobalIFunc, llvm::GlobalIFunc, GlobalObject,
llvm::GlobalObject> {
GlobalIFunc(llvm::GlobalObject *C, Context &Ctx)
: GlobalWithNodeAPI(ClassID::GlobalIFunc, C, Ctx) {}
friend class Context; // For constructor.

public:
/// For isa/dyn_cast.
static bool classof(const sandboxir::Value *From) {
return From->getSubclassID() == ClassID::GlobalIFunc;
}

// TODO: Missing create() because we don't have a sandboxir::Module yet.

// TODO: Missing functions: copyAttributesFrom(), removeFromParent(),
// eraseFromParent()

void setResolver(Constant *Resolver);

Constant *getResolver() const;

// Return the resolver function after peeling off potential ConstantExpr
// indirection.
Function *getResolverFunction();
const Function *getResolverFunction() const {
return const_cast<GlobalIFunc *>(this)->getResolverFunction();
}

static bool isValidLinkage(LinkageTypes L) {
return llvm::GlobalIFunc::isValidLinkage(L);
}

// TODO: Missing applyAlongResolverPath().

#ifndef NDEBUG
void verify() const override {
assert(isa<llvm::GlobalIFunc>(Val) && "Expected a GlobalIFunc!");
}
void dumpOS(raw_ostream &OS) const override {
dumpCommonPrefix(OS);
dumpCommonSuffix(OS);
}
#endif
};

class BlockAddress final : public Constant {
BlockAddress(llvm::BlockAddress *C, Context &Ctx)
: Constant(ClassID::BlockAddress, C, Ctx) {}
Expand Down Expand Up @@ -4219,7 +4304,8 @@ class Context {
size_t getNumValues() const { return LLVMValueToValueMap.size(); }
};

class Function : public GlobalObject {
class Function : public GlobalWithNodeAPI<Function, llvm::Function,
GlobalObject, llvm::GlobalObject> {
/// Helper for mapped_iterator.
struct LLVMBBToBB {
Context &Ctx;
Expand All @@ -4230,7 +4316,7 @@ class Function : public GlobalObject {
};
/// Use Context::createFunction() instead.
Function(llvm::Function *F, sandboxir::Context &Ctx)
: GlobalObject(ClassID::Function, F, Ctx) {}
: GlobalWithNodeAPI(ClassID::Function, F, Ctx) {}
friend class Context; // For constructor.

public:
Expand Down
37 changes: 37 additions & 0 deletions llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2519,6 +2519,39 @@ void GlobalObject::setSection(StringRef S) {
cast<llvm::GlobalObject>(Val)->setSection(S);
}

template <typename GlobalT, typename LLVMGlobalT, typename ParentT,
typename LLVMParentT>
GlobalT &GlobalWithNodeAPI<GlobalT, LLVMGlobalT, ParentT, LLVMParentT>::
LLVMGVToGV::operator()(LLVMGlobalT &LLVMGV) const {
return cast<GlobalT>(*Ctx.getValue(&LLVMGV));
}

namespace llvm::sandboxir {
// Explicit instantiations.
template class GlobalWithNodeAPI<GlobalIFunc, llvm::GlobalIFunc, GlobalObject,
llvm::GlobalObject>;
template class GlobalWithNodeAPI<Function, llvm::Function, GlobalObject,
llvm::GlobalObject>;
} // namespace llvm::sandboxir

void GlobalIFunc::setResolver(Constant *Resolver) {
Ctx.getTracker()
.emplaceIfTracking<
GenericSetter<&GlobalIFunc::getResolver, &GlobalIFunc::setResolver>>(
this);
cast<llvm::GlobalIFunc>(Val)->setResolver(
cast<llvm::Constant>(Resolver->Val));
}

Constant *GlobalIFunc::getResolver() const {
return Ctx.getOrCreateConstant(cast<llvm::GlobalIFunc>(Val)->getResolver());
}

Function *GlobalIFunc::getResolverFunction() {
return cast<Function>(Ctx.getOrCreateConstant(
cast<llvm::GlobalIFunc>(Val)->getResolverFunction()));
}

void GlobalValue::setUnnamedAddr(UnnamedAddr V) {
Ctx.getTracker()
.emplaceIfTracking<GenericSetter<&GlobalValue::getUnnamedAddr,
Expand Down Expand Up @@ -2727,6 +2760,10 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
It->second = std::unique_ptr<Function>(
new Function(cast<llvm::Function>(C), *this));
break;
case llvm::Value::GlobalIFuncVal:
It->second = std::unique_ptr<GlobalIFunc>(
new GlobalIFunc(cast<llvm::GlobalIFunc>(C), *this));
break;
default:
It->second = std::unique_ptr<Constant>(new Constant(C, *this));
break;
Expand Down
137 changes: 122 additions & 15 deletions llvm/unittests/SandboxIR/SandboxIRTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,84 @@ define void @foo() {
EXPECT_EQ(GO->canIncreaseAlignment(), LLVMGO->canIncreaseAlignment());
}

TEST_F(SandboxIRTest, GlobalIFunc) {
parseIR(C, R"IR(
declare external void @bar()
@ifunc0 = ifunc void(), ptr @foo
@ifunc1 = ifunc void(), ptr @foo
define void @foo() {
call void @ifunc0()
call void @ifunc1()
call void @bar()
ret void
}
)IR");
Function &LLVMF = *M->getFunction("foo");
auto *LLVMBB = &*LLVMF.begin();
auto LLVMIt = LLVMBB->begin();
auto *LLVMCall0 = cast<llvm::CallInst>(&*LLVMIt++);
auto *LLVMIFunc0 = cast<llvm::GlobalIFunc>(LLVMCall0->getCalledOperand());

sandboxir::Context Ctx(C);

auto &F = *Ctx.createFunction(&LLVMF);
auto *BB = &*F.begin();
auto It = BB->begin();
auto *Call0 = cast<sandboxir::CallInst>(&*It++);
auto *Call1 = cast<sandboxir::CallInst>(&*It++);
auto *CallBar = cast<sandboxir::CallInst>(&*It++);
// Check classof(), creation.
auto *IFunc0 = cast<sandboxir::GlobalIFunc>(Call0->getCalledOperand());
auto *IFunc1 = cast<sandboxir::GlobalIFunc>(Call1->getCalledOperand());
auto *Bar = cast<sandboxir::Function>(CallBar->getCalledOperand());

// Check getIterator().
{
auto It0 = IFunc0->getIterator();
auto It1 = IFunc1->getIterator();
EXPECT_EQ(&*It0, IFunc0);
EXPECT_EQ(&*It1, IFunc1);
EXPECT_EQ(std::next(It0), It1);
EXPECT_EQ(std::prev(It1), It0);
EXPECT_EQ(&*std::next(It0), IFunc1);
EXPECT_EQ(&*std::prev(It1), IFunc0);
}
// Check getReverseIterator().
{
auto RevIt0 = IFunc0->getReverseIterator();
auto RevIt1 = IFunc1->getReverseIterator();
EXPECT_EQ(&*RevIt0, IFunc0);
EXPECT_EQ(&*RevIt1, IFunc1);
EXPECT_EQ(std::prev(RevIt0), RevIt1);
EXPECT_EQ(std::next(RevIt1), RevIt0);
EXPECT_EQ(&*std::prev(RevIt0), IFunc1);
EXPECT_EQ(&*std::next(RevIt1), IFunc0);
}

// Check setResolver(), getResolver().
EXPECT_EQ(IFunc0->getResolver(), Ctx.getValue(LLVMIFunc0->getResolver()));
auto *OrigResolver = IFunc0->getResolver();
auto *NewResolver = Bar;
EXPECT_NE(NewResolver, OrigResolver);
IFunc0->setResolver(NewResolver);
EXPECT_EQ(IFunc0->getResolver(), NewResolver);
IFunc0->setResolver(OrigResolver);
EXPECT_EQ(IFunc0->getResolver(), OrigResolver);
// Check getResolverFunction().
EXPECT_EQ(IFunc0->getResolverFunction(),
Ctx.getValue(LLVMIFunc0->getResolverFunction()));
// Check isValidLinkage().
for (auto L :
{GlobalValue::ExternalLinkage, GlobalValue::AvailableExternallyLinkage,
GlobalValue::LinkOnceAnyLinkage, GlobalValue::LinkOnceODRLinkage,
GlobalValue::WeakAnyLinkage, GlobalValue::WeakODRLinkage,
GlobalValue::AppendingLinkage, GlobalValue::InternalLinkage,
GlobalValue::PrivateLinkage, GlobalValue::ExternalWeakLinkage,
GlobalValue::CommonLinkage}) {
EXPECT_EQ(IFunc0->isValidLinkage(L), LLVMIFunc0->isValidLinkage(L));
}
}

TEST_F(SandboxIRTest, BlockAddress) {
parseIR(C, R"IR(
define void @foo(ptr %ptr) {
Expand Down Expand Up @@ -1200,29 +1278,58 @@ define void @foo(i8 %v) {

TEST_F(SandboxIRTest, Function) {
parseIR(C, R"IR(
define void @foo(i32 %arg0, i32 %arg1) {
define void @foo0(i32 %arg0, i32 %arg1) {
bb0:
br label %bb1
bb1:
ret void
}
define void @foo1() {
ret void
}

)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
llvm::Argument *LLVMArg0 = LLVMF->getArg(0);
llvm::Argument *LLVMArg1 = LLVMF->getArg(1);
llvm::Function *LLVMF0 = &*M->getFunction("foo0");
llvm::Function *LLVMF1 = &*M->getFunction("foo1");
llvm::Argument *LLVMArg0 = LLVMF0->getArg(0);
llvm::Argument *LLVMArg1 = LLVMF0->getArg(1);

sandboxir::Context Ctx(C);
sandboxir::Function *F = Ctx.createFunction(LLVMF);
sandboxir::Function *F0 = Ctx.createFunction(LLVMF0);
sandboxir::Function *F1 = Ctx.createFunction(LLVMF1);

// Check getIterator().
{
auto It0 = F0->getIterator();
auto It1 = F1->getIterator();
EXPECT_EQ(&*It0, F0);
EXPECT_EQ(&*It1, F1);
EXPECT_EQ(std::next(It0), It1);
EXPECT_EQ(std::prev(It1), It0);
EXPECT_EQ(&*std::next(It0), F1);
EXPECT_EQ(&*std::prev(It1), F0);
}
// Check getReverseIterator().
{
auto RevIt0 = F0->getReverseIterator();
auto RevIt1 = F1->getReverseIterator();
EXPECT_EQ(&*RevIt0, F0);
EXPECT_EQ(&*RevIt1, F1);
EXPECT_EQ(std::prev(RevIt0), RevIt1);
EXPECT_EQ(std::next(RevIt1), RevIt0);
EXPECT_EQ(&*std::prev(RevIt0), F1);
EXPECT_EQ(&*std::next(RevIt1), F0);
}

// Check F arguments
EXPECT_EQ(F->arg_size(), 2u);
EXPECT_FALSE(F->arg_empty());
EXPECT_EQ(F->getArg(0), Ctx.getValue(LLVMArg0));
EXPECT_EQ(F->getArg(1), Ctx.getValue(LLVMArg1));
EXPECT_EQ(F0->arg_size(), 2u);
EXPECT_FALSE(F0->arg_empty());
EXPECT_EQ(F0->getArg(0), Ctx.getValue(LLVMArg0));
EXPECT_EQ(F0->getArg(1), Ctx.getValue(LLVMArg1));

// Check F.begin(), F.end(), Function::iterator
llvm::BasicBlock *LLVMBB = &*LLVMF->begin();
for (sandboxir::BasicBlock &BB : *F) {
llvm::BasicBlock *LLVMBB = &*LLVMF0->begin();
for (sandboxir::BasicBlock &BB : *F0) {
EXPECT_EQ(&BB, Ctx.getValue(LLVMBB));
LLVMBB = LLVMBB->getNextNode();
}
Expand All @@ -1232,17 +1339,17 @@ define void @foo(i32 %arg0, i32 %arg1) {
// Check F.dumpNameAndArgs()
std::string Buff;
raw_string_ostream BS(Buff);
F->dumpNameAndArgs(BS);
EXPECT_EQ(Buff, "void @foo(i32 %arg0, i32 %arg1)");
F0->dumpNameAndArgs(BS);
EXPECT_EQ(Buff, "void @foo0(i32 %arg0, i32 %arg1)");
}
{
// Check F.dump()
std::string Buff;
raw_string_ostream BS(Buff);
BS << "\n";
F->dumpOS(BS);
F0->dumpOS(BS);
EXPECT_EQ(Buff, R"IR(
void @foo(i32 %arg0, i32 %arg1) {
void @foo0(i32 %arg0, i32 %arg1) {
bb0:
br label %bb1 ; SB4. (Br)

Expand Down
32 changes: 32 additions & 0 deletions llvm/unittests/SandboxIR/TrackerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1558,6 +1558,38 @@ define void @foo() {
EXPECT_EQ(GV->getVisibility(), OrigVisibility);
}

TEST_F(TrackerTest, GlobalIFuncSetters) {
parseIR(C, R"IR(
declare external void @bar()
@ifunc = ifunc void(), ptr @foo
define void @foo() {
call void @ifunc()
call void @bar()
ret void
}
)IR");
Function &LLVMF = *M->getFunction("foo");
sandboxir::Context Ctx(C);

auto &F = *Ctx.createFunction(&LLVMF);
auto *BB = &*F.begin();
auto It = BB->begin();
auto *Call0 = cast<sandboxir::CallInst>(&*It++);
auto *Call1 = cast<sandboxir::CallInst>(&*It++);
// Check classof(), creation.
auto *IFunc = cast<sandboxir::GlobalIFunc>(Call0->getCalledOperand());
auto *Bar = cast<sandboxir::Function>(Call1->getCalledOperand());
// Check setResolver().
auto *OrigResolver = IFunc->getResolver();
auto *NewResolver = Bar;
EXPECT_NE(NewResolver, OrigResolver);
Ctx.save();
IFunc->setResolver(NewResolver);
EXPECT_EQ(IFunc->getResolver(), NewResolver);
Ctx.revert();
EXPECT_EQ(IFunc->getResolver(), OrigResolver);
}

TEST_F(TrackerTest, SetVolatile) {
parseIR(C, R"IR(
define void @foo(ptr %arg0, i8 %val) {
Expand Down
Loading