Skip to content

Commit c8a678b

Browse files
authored
[ctx_prof] Add analysis utility to fetch ID of a callsite (#104491)
This will be needed when maintaining the contextual profile for ICP or inlining - we'll need to first fetch the ID of a callsite, which is in an instrumentation instruction (intrinsic) preceding the callsite.
1 parent 3565332 commit c8a678b

File tree

4 files changed

+147
-0
lines changed

4 files changed

+147
-0
lines changed

llvm/include/llvm/Analysis/CtxProfAnalysis.h

+4
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
#ifndef LLVM_ANALYSIS_CTXPROFANALYSIS_H
1010
#define LLVM_ANALYSIS_CTXPROFANALYSIS_H
1111

12+
#include "llvm/IR/InstrTypes.h"
13+
#include "llvm/IR/IntrinsicInst.h"
1214
#include "llvm/IR/PassManager.h"
1315
#include "llvm/ProfileData/PGOCtxProfReader.h"
1416

@@ -82,6 +84,8 @@ class CtxProfAnalysis : public AnalysisInfoMixin<CtxProfAnalysis> {
8284
using Result = PGOContextualProfile;
8385

8486
PGOContextualProfile run(Module &M, ModuleAnalysisManager &MAM);
87+
88+
static InstrProfCallsite *getCallsiteInstrumentation(CallBase &CB);
8589
};
8690

8791
class CtxProfAnalysisPrinterPass

llvm/lib/Analysis/CtxProfAnalysis.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,10 @@ PreservedAnalyses CtxProfAnalysisPrinterPass::run(Module &M,
186186
OS << "\n";
187187
return PreservedAnalyses::all();
188188
}
189+
190+
InstrProfCallsite *CtxProfAnalysis::getCallsiteInstrumentation(CallBase &CB) {
191+
while (auto *Prev = CB.getPrevNode())
192+
if (auto *IPC = dyn_cast<InstrProfCallsite>(Prev))
193+
return IPC;
194+
return nullptr;
195+
}

llvm/unittests/Analysis/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ set(ANALYSIS_TEST_SOURCES
2222
CFGTest.cpp
2323
CGSCCPassManagerTest.cpp
2424
ConstraintSystemTest.cpp
25+
CtxProfAnalysisTest.cpp
2526
DDGTest.cpp
2627
DomTreeUpdaterTest.cpp
2728
DXILResourceTest.cpp
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
//===--- CtxProfAnalysisTest.cpp ------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "llvm/Analysis/CtxProfAnalysis.h"
10+
#include "llvm/Analysis/BlockFrequencyInfo.h"
11+
#include "llvm/Analysis/BranchProbabilityInfo.h"
12+
#include "llvm/Analysis/CGSCCPassManager.h"
13+
#include "llvm/Analysis/LoopAnalysisManager.h"
14+
#include "llvm/AsmParser/Parser.h"
15+
#include "llvm/IR/Analysis.h"
16+
#include "llvm/IR/Module.h"
17+
#include "llvm/IR/PassInstrumentation.h"
18+
#include "llvm/IR/PassManager.h"
19+
#include "llvm/Passes/PassBuilder.h"
20+
#include "llvm/Support/SourceMgr.h"
21+
#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
22+
#include "gmock/gmock.h"
23+
#include "gtest/gtest.h"
24+
25+
using namespace llvm;
26+
27+
namespace {
28+
29+
class CtxProfAnalysisTest : public testing::Test {
30+
static constexpr auto *IR = R"IR(
31+
declare void @bar()
32+
33+
define private void @foo(i32 %a, ptr %fct) #0 !guid !0 {
34+
%t = icmp eq i32 %a, 0
35+
br i1 %t, label %yes, label %no
36+
yes:
37+
call void %fct(i32 %a)
38+
br label %exit
39+
no:
40+
call void @bar()
41+
br label %exit
42+
exit:
43+
ret void
44+
}
45+
46+
define void @an_entrypoint(i32 %a) {
47+
%t = icmp eq i32 %a, 0
48+
br i1 %t, label %yes, label %no
49+
50+
yes:
51+
call void @foo(i32 1, ptr null)
52+
ret void
53+
no:
54+
ret void
55+
}
56+
57+
define void @another_entrypoint_no_callees(i32 %a) {
58+
%t = icmp eq i32 %a, 0
59+
br i1 %t, label %yes, label %no
60+
61+
yes:
62+
ret void
63+
no:
64+
ret void
65+
}
66+
67+
attributes #0 = { noinline }
68+
!0 = !{ i64 11872291593386833696 }
69+
)IR";
70+
71+
protected:
72+
LLVMContext C;
73+
PassBuilder PB;
74+
ModuleAnalysisManager MAM;
75+
FunctionAnalysisManager FAM;
76+
CGSCCAnalysisManager CGAM;
77+
LoopAnalysisManager LAM;
78+
std::unique_ptr<Module> M;
79+
80+
void SetUp() override {
81+
SMDiagnostic Err;
82+
M = parseAssemblyString(IR, Err, C);
83+
ASSERT_TRUE(!!M);
84+
}
85+
86+
public:
87+
CtxProfAnalysisTest() {
88+
PB.registerModuleAnalyses(MAM);
89+
PB.registerCGSCCAnalyses(CGAM);
90+
PB.registerFunctionAnalyses(FAM);
91+
PB.registerLoopAnalyses(LAM);
92+
PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
93+
}
94+
};
95+
96+
TEST_F(CtxProfAnalysisTest, GetCallsiteIDTest) {
97+
ModulePassManager MPM;
98+
MPM.addPass(PGOInstrumentationGen(PGOInstrumentationType::CTXPROF));
99+
EXPECT_FALSE(MPM.run(*M, MAM).areAllPreserved());
100+
auto *F = M->getFunction("foo");
101+
ASSERT_NE(F, nullptr);
102+
std::vector<uint32_t> InsValues;
103+
104+
for (auto &BB : *F)
105+
for (auto &I : BB)
106+
if (auto *CB = dyn_cast<CallBase>(&I)) {
107+
// Skip instrumentation inserted intrinsics.
108+
if (CB->getCalledFunction() && CB->getCalledFunction()->isIntrinsic())
109+
continue;
110+
auto *Ins = CtxProfAnalysis::getCallsiteInstrumentation(*CB);
111+
ASSERT_NE(Ins, nullptr);
112+
InsValues.push_back(Ins->getIndex()->getZExtValue());
113+
}
114+
115+
EXPECT_THAT(InsValues, testing::ElementsAre(0, 1));
116+
}
117+
118+
TEST_F(CtxProfAnalysisTest, GetCallsiteIDNegativeTest) {
119+
auto *F = M->getFunction("foo");
120+
ASSERT_NE(F, nullptr);
121+
CallBase *FirstCall = nullptr;
122+
for (auto &BB : *F)
123+
for (auto &I : BB)
124+
if (auto *CB = dyn_cast<CallBase>(&I)) {
125+
if (CB->isIndirectCall() || !CB->getCalledFunction()->isIntrinsic()) {
126+
FirstCall = CB;
127+
break;
128+
}
129+
}
130+
ASSERT_NE(FirstCall, nullptr);
131+
auto *IndIns = CtxProfAnalysis::getCallsiteInstrumentation(*FirstCall);
132+
EXPECT_EQ(IndIns, nullptr);
133+
}
134+
135+
} // namespace

0 commit comments

Comments
 (0)