Skip to content

Commit 2303e93

Browse files
sommerlukasrotateright
authored andcommitted
[Codegen][ReplaceWithVecLib] add pass to replace vector intrinsics with calls to vector library
This patch adds a pass to replace calls to vector intrinsics (i.e., LLVM intrinsics operating on vector operands) with calls to a vector library. Currently, calls to LLVM intrinsics are only replaced with calls to vector libraries when scalar calls to intrinsics are vectorized by the Loop- or SLP-Vectorizer. With this pass, it is now possible to replace calls to LLVM intrinsics already operating on vector operands, e.g., if such code was generated by MLIR. For the replacement, information from the TargetLibraryInfo, e.g., as specified via -vector-library is used. Differential Revision: https://reviews.llvm.org/D95373
1 parent e3c0b0f commit 2303e93

File tree

15 files changed

+422
-0
lines changed

15 files changed

+422
-0
lines changed

llvm/include/llvm/CodeGen/CodeGenPassBuilder.h

+7
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "llvm/CodeGen/MachineModuleInfo.h"
3030
#include "llvm/CodeGen/MachinePassManager.h"
3131
#include "llvm/CodeGen/PreISelIntrinsicLowering.h"
32+
#include "llvm/CodeGen/ReplaceWithVeclib.h"
3233
#include "llvm/CodeGen/UnreachableBlockElim.h"
3334
#include "llvm/IR/IRPrintingPasses.h"
3435
#include "llvm/IR/PassManager.h"
@@ -650,6 +651,12 @@ void CodeGenPassBuilder<Derived>::addIRPasses(AddIRPass &addPass) const {
650651
if (getOptLevel() != CodeGenOpt::None && !Opt.DisableConstantHoisting)
651652
addPass(ConstantHoistingPass());
652653

654+
if (getOptLevel() != CodeGenOpt::None) {
655+
// Replace calls to LLVM intrinsics (e.g., exp, log) operating on vector
656+
// operands with calls to the corresponding functions in a vector library.
657+
addPass(ReplaceWithVeclib());
658+
}
659+
653660
if (getOptLevel() != CodeGenOpt::None && !Opt.DisablePartialLibcallInlining)
654661
addPass(PartiallyInlineLibCallsPass());
655662

llvm/include/llvm/CodeGen/MachinePassRegistry.def

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ FUNCTION_PASS("mergeicmps", MergeICmpsPass, ())
3939
FUNCTION_PASS("lower-constant-intrinsics", LowerConstantIntrinsicsPass, ())
4040
FUNCTION_PASS("unreachableblockelim", UnreachableBlockElimPass, ())
4141
FUNCTION_PASS("consthoist", ConstantHoistingPass, ())
42+
FUNCTION_PASS("replace-with-veclib", ReplaceWithVeclib, ())
4243
FUNCTION_PASS("partially-inline-libcalls", PartiallyInlineLibCallsPass, ())
4344
FUNCTION_PASS("ee-instrument", EntryExitInstrumenterPass, (false))
4445
FUNCTION_PASS("post-inline-ee-instrument", EntryExitInstrumenterPass, (true))

llvm/include/llvm/CodeGen/Passes.h

+4
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,10 @@ namespace llvm {
448448
/// shuffles.
449449
FunctionPass *createExpandReductionsPass();
450450

451+
// This pass replaces intrinsics operating on vector operands with calls to
452+
// the corresponding function in a vector library (e.g., SVML, libmvec).
453+
FunctionPass *createReplaceWithVeclibLegacyPass();
454+
451455
// This pass expands memcmp() to load/stores.
452456
FunctionPass *createExpandMemCmpPass();
453457

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
//===- ReplaceWithVeclib.h - Replace vector intrinsics with veclib calls --===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Replaces calls to LLVM vector intrinsics (i.e., calls to LLVM intrinsics
10+
// with vector operands) with matching calls to functions from a vector
11+
// library (e.g., libmvec, SVML) according to TargetLibraryInfo.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
#ifndef LLVM_CODEGEN_REPLACEWITHVECLIB_H
#define LLVM_CODEGEN_REPLACEWITHVECLIB_H

#include "llvm/IR/PassManager.h"
#include "llvm/InitializePasses.h"
// FunctionPass and PassRegistry are used below; include them directly so the
// header does not depend on what the including file happened to pull in.
#include "llvm/Pass.h"
#include "llvm/PassRegistry.h"

namespace llvm {
/// New pass manager pass: replaces calls to LLVM intrinsics operating on
/// vector operands with calls to the matching vector library function.
class ReplaceWithVeclib : public PassInfoMixin<ReplaceWithVeclib> {
public:
  /// Runs the replacement on \p F, using analyses obtained from \p AM.
  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
};

/// Legacy pass manager wrapper around the same transformation.
class ReplaceWithVeclibLegacy : public FunctionPass {
public:
  static char ID;
  ReplaceWithVeclibLegacy() : FunctionPass(ID) {
    // Register the pass with the global registry on construction.
    initializeReplaceWithVeclibLegacyPass(*PassRegistry::getPassRegistry());
  }
  void getAnalysisUsage(AnalysisUsage &AU) const override;
  bool runOnFunction(Function &F) override;
};

} // End namespace llvm
#endif // LLVM_CODEGEN_REPLACEWITHVECLIB_H

llvm/include/llvm/InitializePasses.h

+1
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,7 @@ void initializeRegionPrinterPass(PassRegistry&);
380380
void initializeRegionViewerPass(PassRegistry&);
381381
void initializeRegisterCoalescerPass(PassRegistry&);
382382
void initializeRenameIndependentSubregsPass(PassRegistry&);
383+
void initializeReplaceWithVeclibLegacyPass(PassRegistry &);
383384
void initializeResetMachineFunctionPass(PassRegistry&);
384385
void initializeReversePostOrderFunctionAttrsLegacyPassPass(PassRegistry&);
385386
void initializeRewriteStatepointsForGCLegacyPassPass(PassRegistry &);

llvm/lib/CodeGen/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ add_llvm_component_library(LLVMCodeGen
147147
RegisterUsageInfo.cpp
148148
RegUsageInfoCollector.cpp
149149
RegUsageInfoPropagate.cpp
150+
ReplaceWithVeclib.cpp
150151
ResetMachineFunctionPass.cpp
151152
SafeStack.cpp
152153
SafeStackLayout.cpp
+256
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
//=== ReplaceWithVeclib.cpp - Replace vector intrinsics with veclib calls -===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Replaces calls to LLVM vector intrinsics (i.e., calls to LLVM intrinsics
10+
// with vector operands) with matching calls to functions from a vector
11+
// library (e.g., libmvec, SVML) according to TargetLibraryInfo.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "llvm/CodeGen/ReplaceWithVeclib.h"
16+
#include "llvm/ADT/STLExtras.h"
17+
#include "llvm/ADT/Statistic.h"
18+
#include "llvm/Analysis/DemandedBits.h"
19+
#include "llvm/Analysis/GlobalsModRef.h"
20+
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
21+
#include "llvm/Analysis/TargetLibraryInfo.h"
22+
#include "llvm/Analysis/VectorUtils.h"
23+
#include "llvm/CodeGen/Passes.h"
24+
#include "llvm/IR/IRBuilder.h"
25+
#include "llvm/IR/InstIterator.h"
26+
#include "llvm/IR/IntrinsicInst.h"
27+
#include "llvm/Transforms/Utils/ModuleUtils.h"
28+
29+
using namespace llvm;
30+
31+
#define DEBUG_TYPE "replace-with-veclib"
32+
33+
STATISTIC(NumCallsReplaced,
34+
"Number of calls to intrinsics that have been replaced.");
35+
36+
STATISTIC(NumTLIFuncDeclAdded,
37+
"Number of vector library function declarations added.");
38+
39+
STATISTIC(NumFuncUsedAdded,
40+
"Number of functions added to `llvm.compiler.used`");
41+
42+
/// Replaces the intrinsic call \p CI with a call to the vector library
/// function named \p TLIName.
///
/// If the module does not declare \p TLIName yet, a declaration with the
/// intrinsic's function type and attributes is created and appended to
/// `llvm.compiler.used`. The new call is inserted right before \p CI and takes
/// over all of its uses; \p CI itself is not erased here.
///
/// \returns true if the call was replaced. Returns false, without modifying
/// anything, if the module already contains a function named \p TLIName whose
/// function type differs from the intrinsic's.
static bool replaceWithTLIFunction(CallInst &CI, const StringRef TLIName) {
  Module *M = CI.getModule();
  Function *OldFunc = CI.getCalledFunction();

  // Check if the vector library function is already declared in this module,
  // otherwise insert it.
  Function *TLIFunc = M->getFunction(TLIName);
  if (TLIFunc && TLIFunc->getFunctionType() != OldFunc->getFunctionType()) {
    // A pre-existing declaration with a different signature cannot be called
    // with the intrinsic's operands. Checking this before creating the call
    // (instead of asserting afterwards) also protects release builds.
    LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Existing function `" << TLIName
                      << "` has a mismatching type, not replacing call.\n");
    return false;
  }

  if (!TLIFunc) {
    TLIFunc = Function::Create(OldFunc->getFunctionType(),
                               Function::ExternalLinkage, TLIName, *M);
    TLIFunc->copyAttributesFrom(OldFunc);

    LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
                      << TLIName << "` of type `" << *(TLIFunc->getType())
                      << "` to module.\n");

    ++NumTLIFuncDeclAdded;

    // Add the freshly created function to llvm.compiler.used,
    // similar to as it is done in InjectTLIMappings
    appendToCompilerUsed(*M, {TLIFunc});

    LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Adding `" << TLIName
                      << "` to `@llvm.compiler.used`.\n");
    ++NumFuncUsedAdded;
  }

  // Replace the call to the vector intrinsic with a call
  // to the corresponding function from the vector library.
  IRBuilder<> IRBuilder{&CI};
  SmallVector<Value *> Args(CI.arg_operands());
  // Preserve the operand bundles.
  SmallVector<OperandBundleDef, 1> OpBundles;
  CI.getOperandBundlesAsDefs(OpBundles);
  CallInst *Replacement = IRBuilder.CreateCall(TLIFunc, Args, OpBundles);
  CI.replaceAllUsesWith(Replacement);
  if (isa<FPMathOperator>(Replacement)) {
    // Preserve fast math flags for FP math.
    Replacement->copyFastMathFlags(&CI);
  }

  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `"
                    << OldFunc->getName() << "` with call to `" << TLIName
                    << "`.\n");
  ++NumCallsReplaced;
  return true;
}
92+
93+
static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
94+
CallInst &CI) {
95+
if (!CI.getCalledFunction()) {
96+
return false;
97+
}
98+
99+
auto IntrinsicID = CI.getCalledFunction()->getIntrinsicID();
100+
if (IntrinsicID == Intrinsic::not_intrinsic) {
101+
// Replacement is only performed for intrinsic functions
102+
return false;
103+
}
104+
105+
// Convert vector arguments to scalar type and check that
106+
// all vector operands have identical vector width.
107+
unsigned VF = 0;
108+
SmallVector<Type *> ScalarTypes;
109+
for (auto Arg : enumerate(CI.arg_operands())) {
110+
auto *ArgType = Arg.value()->getType();
111+
// Vector calls to intrinsics can still have
112+
// scalar operands for specific arguments.
113+
if (hasVectorInstrinsicScalarOpd(IntrinsicID, Arg.index())) {
114+
ScalarTypes.push_back(ArgType);
115+
} else {
116+
// The argument in this place should be a vector if
117+
// this is a call to a vector intrinsic.
118+
auto *VectorArgTy = dyn_cast<VectorType>(ArgType);
119+
if (!VectorArgTy) {
120+
// The argument is not a vector, do not perform
121+
// the replacement.
122+
return false;
123+
}
124+
auto NumElements = VectorArgTy->getElementCount();
125+
if (NumElements.isScalable()) {
126+
// The current implementation does not support
127+
// scalable vectors.
128+
return false;
129+
}
130+
if (VF && VF != NumElements.getFixedValue()) {
131+
// The different arguments differ in vector size.
132+
return false;
133+
} else {
134+
VF = NumElements.getFixedValue();
135+
}
136+
ScalarTypes.push_back(VectorArgTy->getElementType());
137+
}
138+
}
139+
140+
// Try to reconstruct the name for the scalar version of this
141+
// intrinsic using the intrinsic ID and the argument types
142+
// converted to scalar above.
143+
std::string ScalarName;
144+
if (Intrinsic::isOverloaded(IntrinsicID)) {
145+
ScalarName = Intrinsic::getName(IntrinsicID, ScalarTypes);
146+
} else {
147+
ScalarName = Intrinsic::getName(IntrinsicID).str();
148+
}
149+
150+
if (!TLI.isFunctionVectorizable(ScalarName)) {
151+
// The TargetLibraryInfo does not contain a vectorized version of
152+
// the scalar function.
153+
return false;
154+
}
155+
156+
// Try to find the mapping for the scalar version of this intrinsic
157+
// and the exact vector width of the call operands in the
158+
// TargetLibraryInfo.
159+
const std::string TLIName =
160+
std::string(TLI.getVectorizedFunction(ScalarName, VF));
161+
162+
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Looking up TLI mapping for `"
163+
<< ScalarName << "` and vector width " << VF << ".\n");
164+
165+
if (!TLIName.empty()) {
166+
// Found the correct mapping in the TargetLibraryInfo,
167+
// replace the call to the intrinsic with a call to
168+
// the vector library function.
169+
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI function `" << TLIName
170+
<< "`.\n");
171+
return replaceWithTLIFunction(CI, TLIName);
172+
}
173+
174+
return false;
175+
}
176+
177+
static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
178+
bool Changed = false;
179+
SmallVector<CallInst *> ReplacedCalls;
180+
for (auto &I : instructions(F)) {
181+
if (auto *CI = dyn_cast<CallInst>(&I)) {
182+
if (replaceWithCallToVeclib(TLI, *CI)) {
183+
ReplacedCalls.push_back(CI);
184+
Changed = true;
185+
}
186+
}
187+
}
188+
// Erase the calls to the intrinsics that have been replaced
189+
// with calls to the vector library.
190+
for (auto *CI : ReplacedCalls) {
191+
CI->eraseFromParent();
192+
}
193+
return Changed;
194+
}
195+
196+
////////////////////////////////////////////////////////////////////////////////
197+
// New pass manager implementation.
198+
////////////////////////////////////////////////////////////////////////////////
199+
/// New pass manager entry point: fetches the TargetLibraryInfo for \p F,
/// runs the replacement and reports which analyses remain valid.
PreservedAnalyses ReplaceWithVeclib::run(Function &F,
                                         FunctionAnalysisManager &AM) {
  const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);

  // The pass did not replace any calls, hence it preserves all analyses.
  if (!runImpl(TLI, F))
    return PreservedAnalyses::all();

  // Calls were replaced: report the CFG analyses and the analyses listed
  // below as preserved.
  PreservedAnalyses Preserved;
  Preserved.preserveSet<CFGAnalyses>();
  Preserved.preserve<TargetLibraryAnalysis>();
  Preserved.preserve<ScalarEvolutionAnalysis>();
  Preserved.preserve<AAManager>();
  Preserved.preserve<LoopAccessAnalysis>();
  Preserved.preserve<DemandedBitsAnalysis>();
  Preserved.preserve<OptimizationRemarkEmitterAnalysis>();
  Preserved.preserve<GlobalsAA>();
  return Preserved;
}
219+
220+
////////////////////////////////////////////////////////////////////////////////
221+
// Legacy PM Implementation.
222+
////////////////////////////////////////////////////////////////////////////////
223+
/// Legacy pass manager entry point: obtains the TargetLibraryInfo for \p F
/// from the wrapper pass and forwards to the shared implementation.
bool ReplaceWithVeclibLegacy::runOnFunction(Function &F) {
  auto &TLIWrapper = getAnalysis<TargetLibraryInfoWrapperPass>();
  const TargetLibraryInfo &LibInfo = TLIWrapper.getTLI(F);
  return runImpl(LibInfo, F);
}
228+
229+
/// Declares the analyses the legacy pass requires and the ones it preserves.
void ReplaceWithVeclibLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
  // Required input: the vector library mappings come from TargetLibraryInfo.
  AU.addRequired<TargetLibraryInfoWrapperPass>();

  // Preserved: the CFG plus the same analyses the new pass manager
  // implementation reports as preserved.
  AU.setPreservesCFG();
  AU.addPreserved<TargetLibraryInfoWrapperPass>();
  AU.addPreserved<ScalarEvolutionWrapperPass>();
  AU.addPreserved<AAResultsWrapperPass>();
  AU.addPreserved<LoopAccessLegacyAnalysis>();
  AU.addPreserved<DemandedBitsWrapperPass>();
  AU.addPreserved<OptimizationRemarkEmitterWrapperPass>();
  AU.addPreserved<GlobalsAAWrapperPass>();
}
240+
241+
////////////////////////////////////////////////////////////////////////////////
242+
// Legacy Pass manager initialization
243+
////////////////////////////////////////////////////////////////////////////////
244+
char ReplaceWithVeclibLegacy::ID = 0;
245+
246+
INITIALIZE_PASS_BEGIN(ReplaceWithVeclibLegacy, DEBUG_TYPE,
247+
"Replace intrinsics with calls to vector library", false,
248+
false)
249+
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
250+
INITIALIZE_PASS_END(ReplaceWithVeclibLegacy, DEBUG_TYPE,
251+
"Replace intrinsics with calls to vector library", false,
252+
false)
253+
254+
FunctionPass *llvm::createReplaceWithVeclibLegacyPass() {
255+
return new ReplaceWithVeclibLegacy();
256+
}

llvm/lib/CodeGen/TargetPassConfig.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -858,6 +858,9 @@ void TargetPassConfig::addIRPasses() {
858858
if (getOptLevel() != CodeGenOpt::None && !DisableConstantHoisting)
859859
addPass(createConstantHoistingPass());
860860

861+
if (getOptLevel() != CodeGenOpt::None)
862+
addPass(createReplaceWithVeclibLegacyPass());
863+
861864
if (getOptLevel() != CodeGenOpt::None && !DisablePartialLibcallInlining)
862865
addPass(createPartiallyInlineLibCallsPass());
863866

llvm/test/CodeGen/AArch64/O3-pipeline.ll

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
; CHECK-NEXT: Branch Probability Analysis
5555
; CHECK-NEXT: Block Frequency Analysis
5656
; CHECK-NEXT: Constant Hoisting
57+
; CHECK-NEXT: Replace intrinsics with calls to vector library
5758
; CHECK-NEXT: Partially inline calls to library functions
5859
; CHECK-NEXT: Instrument function entry/exit with calls to e.g. mcount() (post inlining)
5960
; CHECK-NEXT: Scalarize Masked Memory Intrinsics

llvm/test/CodeGen/ARM/O3-pipeline.ll

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
; CHECK-NEXT: Branch Probability Analysis
3535
; CHECK-NEXT: Block Frequency Analysis
3636
; CHECK-NEXT: Constant Hoisting
37+
; CHECK-NEXT: Replace intrinsics with calls to vector library
3738
; CHECK-NEXT: Partially inline calls to library functions
3839
; CHECK-NEXT: Instrument function entry/exit with calls to e.g. mcount() (post inlining)
3940
; CHECK-NEXT: Scalarize Masked Memory Intrinsics

0 commit comments

Comments
 (0)