Skip to content

Commit cc4845a

Browse files
authored
[SYCL][ESIMD] Pass to replace simd* parameters with native llvm vectors. (#2097)
* [SYCL][ESIMD] Pass to replace simd* parameters with native llvm vectors. This pass is needed for the ESIMD backend to generate correct code. * Add test for ESIMDLowerVecArg pass. Refactor code as per code comments. Signed-off-by: Konstantin S Bobrovsky <[email protected]> Signed-off-by: Ashar, Pratik J <[email protected]>
1 parent ac3de67 commit cc4845a

File tree

8 files changed

+600
-0
lines changed

8 files changed

+600
-0
lines changed

llvm/include/llvm/InitializePasses.h

+1
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,7 @@ void initializeStructurizeCFGPass(PassRegistry&);
418418
void initializeSYCLLowerWGScopeLegacyPassPass(PassRegistry &);
419419
void initializeSYCLLowerESIMDLegacyPassPass(PassRegistry &);
420420
void initializeESIMDLowerLoadStorePass(PassRegistry &);
421+
void initializeESIMDLowerVecArgLegacyPassPass(PassRegistry &);
421422
void initializeTailCallElimPass(PassRegistry&);
422423
void initializeTailDuplicatePass(PassRegistry&);
423424
void initializeTargetLibraryInfoWrapperPassPass(PassRegistry&);

llvm/include/llvm/LinkAllPasses.h

+1
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ namespace {
204204
(void)llvm::createSYCLLowerWGScopePass();
205205
(void)llvm::createSYCLLowerESIMDPass();
206206
(void)llvm::createESIMDLowerLoadStorePass();
207+
(void)llvm::createESIMDLowerVecArgPass();
207208
std::string buf;
208209
llvm::raw_string_ostream os(buf);
209210
(void) llvm::createPrintModulePass(os);

llvm/include/llvm/SYCLLowerIR/LowerESIMD.h

+3
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ class ESIMDLowerLoadStorePass : public PassInfoMixin<ESIMDLowerLoadStorePass> {
4242
FunctionPass *createESIMDLowerLoadStorePass();
4343
void initializeESIMDLowerLoadStorePass(PassRegistry &);
4444

45+
ModulePass *createESIMDLowerVecArgPass();
46+
void initializeESIMDLowerVecArgLegacyPassPass(PassRegistry &);
47+
4548
} // namespace llvm
4649

4750
#endif // LLVM_SYCLLOWERIR_LOWERESIMD_H

llvm/lib/SYCLLowerIR/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ add_llvm_component_library(LLVMSYCLLowerIR
3535
LowerWGScope.cpp
3636
LowerESIMD.cpp
3737
LowerESIMDVLoadVStore.cpp
38+
LowerESIMDVecArg.cpp
3839

3940
ADDITIONAL_HEADER_DIRS
4041
${LLVM_MAIN_INCLUDE_DIR}/llvm/SYCLLowerIR
+320
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,320 @@
1+
//===-- ESIMDVecArgPass.cpp - lower Close To Metal (CM) constructs --------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// Change in function parameter type from simd* to native llvm vector type for
9+
// cmc compiler to generate correct code for subroutine parameter passing and
10+
// globals:
11+
//
12+
// Old IR:
13+
// ======
14+
//
15+
// Parameter %0 is of type simd*
16+
// define dso_local spir_func void @_Z3fooPiN2cm3gen4simdIiLi16EEE(i32
17+
// addrspace(4)* %C,
18+
// "class._ZTSN2cm3gen4simdIiLi16EEE.cm::gen::simd" * %0)
19+
// local_unnamed_addr #2 {
20+
//
21+
// New IR:
22+
// ======
23+
//
24+
// Translate simd* parameter (#1) to vector <16 x 32>* type and insert bitcast.
25+
// All users of old parameter will use result of the bitcast.
26+
//
27+
// define dso_local spir_func void @_Z3fooPiN2cm3gen4simdIiLi16EEE(i32
28+
// addrspace(4)* %C,
29+
// <16 x i32>* %0) local_unnamed_addr #2 {
30+
// entry:
31+
// % 1 = bitcast<16 x i32> * % 0 to %
32+
// "class._ZTSN2cm3gen4simdIiLi16EEE.cm::gen::simd" *
33+
//
34+
//
35+
// Change in global variables:
36+
//
37+
// Old IR:
38+
// ======
39+
// @vc = global %"class._ZTSN2cm3gen4simdIiLi16EEE.cm::gen::simd"
40+
// zeroinitializer, align 64 #0
41+
//
42+
// % call.cm.i.i = tail call<16 x i32> @llvm.genx.vload.v16i32.p4v16i32(
43+
// <16 x i32> addrspace(4) * getelementptr(
44+
// % "class._ZTSN2cm3gen4simdIiLi16EEE.cm::gen::simd",
45+
// % "class._ZTSN2cm3gen4simdIiLi16EEE.cm::gen::simd" addrspace(4) *
46+
// addrspacecast(% "class._ZTSN2cm3gen4simdIiLi16EEE.cm::gen::simd" * @vc to
47+
// % "class._ZTSN2cm3gen4simdIiLi16EEE.cm::gen::simd" addrspace(4) *), i64 0,
48+
// i32 0))
49+
//
50+
// New IR:
51+
// ======
52+
//
53+
// @0 = dso_local global <16 x i32> zeroinitializer, align 64 #0 <-- New Global
54+
// Variable
55+
//
56+
// % call.cm.i.i = tail call<16 x i32> @llvm.genx.vload.v16i32.p4v16i32(
57+
// <16 x i32> addrspace(4) * getelementptr(
58+
// % "class._ZTSN2cm3gen4simdIiLi16EEE.cm::gen::simd",
59+
// % "class._ZTSN2cm3gen4simdIiLi16EEE.cm::gen::simd" addrspace(4) *
60+
// addrspacecast(% "class._ZTSN2cm3gen4simdIiLi16EEE.cm::gen::simd" *
61+
// bitcast(<16 x i32> * @0 to
62+
// %"class._ZTSN2cm3gen4simdIiLi16EEE.cm::gen::simd" *) to %
63+
// "class._ZTSN2cm3gen4simdIiLi16EEE.cm::gen::simd" addrspace(4) *),
64+
// i64 0, i32 0))
65+
//===----------------------------------------------------------------------===//
66+
67+
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
68+
#include "llvm/Transforms/Utils/Cloning.h"
69+
70+
using namespace llvm;
71+
72+
#define DEBUG_TYPE "ESIMDLowerVecArg"
73+
74+
namespace llvm {
75+
76+
// Forward declarations
77+
void initializeESIMDLowerVecArgLegacyPassPass(PassRegistry &);
78+
ModulePass *createESIMDLowerVecArgPass();
79+
80+
// Pass converts simd* function parameters and globals to
81+
// llvm's first-class vector* type.
82+
class ESIMDLowerVecArgPass {
83+
public:
84+
bool run(Module &M);
85+
86+
private:
87+
DenseMap<GlobalVariable *, GlobalVariable *> OldNewGlobal;
88+
89+
Function *rewriteFunc(Function &F);
90+
Type *getSimdArgPtrTyOrNull(Value *arg);
91+
void fixGlobals(Module &M);
92+
void replaceConstExprWithGlobals(Module &M);
93+
ConstantExpr *createNewConstantExpr(GlobalVariable *newGlobalVar,
94+
Type *oldGlobalType, Value *old);
95+
void removeOldGlobals();
96+
};
97+
98+
} // namespace llvm
99+
100+
namespace {
101+
class ESIMDLowerVecArgLegacyPass : public ModulePass {
102+
public:
103+
static char ID;
104+
ESIMDLowerVecArgLegacyPass() : ModulePass(ID) {
105+
initializeESIMDLowerVecArgLegacyPassPass(*PassRegistry::getPassRegistry());
106+
}
107+
108+
bool runOnModule(Module &M) override {
109+
auto Modified = Impl.run(M);
110+
return Modified;
111+
}
112+
113+
bool doInitialization(Module &M) override { return false; }
114+
115+
private:
116+
ESIMDLowerVecArgPass Impl;
117+
};
118+
} // namespace
119+
120+
char ESIMDLowerVecArgLegacyPass::ID = 0;
121+
INITIALIZE_PASS(ESIMDLowerVecArgLegacyPass, "ESIMDLowerVecArg",
122+
"Translate simd ptr to native vector type", false, false)
123+
124+
// Public interface to VecArgPass
125+
ModulePass *llvm::createESIMDLowerVecArgPass() {
126+
return new ESIMDLowerVecArgLegacyPass();
127+
}
128+
129+
// Return ptr to first-class vector type if Value is a simd*, else return
130+
// nullptr.
131+
Type *ESIMDLowerVecArgPass::getSimdArgPtrTyOrNull(Value *arg) {
132+
auto ArgType = dyn_cast<PointerType>(arg->getType());
133+
if (!ArgType || !ArgType->getElementType()->isStructTy())
134+
return nullptr;
135+
auto ContainedType = ArgType->getElementType();
136+
if ((ContainedType->getStructNumElements() != 1) ||
137+
!ContainedType->getStructElementType(0)->isVectorTy())
138+
return nullptr;
139+
return PointerType::get(ContainedType->getStructElementType(0),
140+
ArgType->getPointerAddressSpace());
141+
}
142+
143+
// F may have multiple arguments of type simd*. This
144+
// function updates all parameters along with call
145+
// call sites of F.
146+
Function *ESIMDLowerVecArgPass::rewriteFunc(Function &F) {
147+
FunctionType *FTy = F.getFunctionType();
148+
Type *RetTy = FTy->getReturnType();
149+
SmallVector<Type *, 8> ArgTys;
150+
151+
for (unsigned int i = 0; i != F.arg_size(); i++) {
152+
auto Arg = F.getArg(i);
153+
Type *NewTy = getSimdArgPtrTyOrNull(Arg);
154+
if (NewTy) {
155+
// Copy over byval type for simd* type
156+
ArgTys.push_back(NewTy);
157+
} else {
158+
// Transfer all non-simd ptr arguments
159+
ArgTys.push_back(Arg->getType());
160+
}
161+
}
162+
163+
FunctionType *NFTy = FunctionType::get(RetTy, ArgTys, false);
164+
165+
// Create new function body and insert into the module
166+
Function *NF = Function::Create(NFTy, F.getLinkage(), F.getName());
167+
F.getParent()->getFunctionList().insert(F.getIterator(), NF);
168+
169+
SmallVector<ReturnInst *, 8> Returns;
170+
SmallVector<BitCastInst *, 8> BitCasts;
171+
ValueToValueMapTy VMap;
172+
for (unsigned int I = 0; I != F.arg_size(); I++) {
173+
auto Arg = F.getArg(I);
174+
Type *newTy = getSimdArgPtrTyOrNull(Arg);
175+
if (newTy) {
176+
// bitcast vector* -> simd*
177+
auto BitCast = new BitCastInst(NF->getArg(I), Arg->getType());
178+
BitCasts.push_back(BitCast);
179+
VMap.insert(std::make_pair(Arg, BitCast));
180+
continue;
181+
}
182+
VMap.insert(std::make_pair(Arg, NF->getArg(I)));
183+
}
184+
185+
llvm::CloneFunctionInto(NF, &F, VMap, F.getSubprogram() != nullptr, Returns);
186+
187+
for (auto &B : BitCasts) {
188+
NF->begin()->getInstList().push_front(B);
189+
}
190+
191+
NF->takeName(&F);
192+
193+
// Fix call sites
194+
SmallVector<std::pair<Instruction *, Instruction *>, 10> OldNewInst;
195+
for (auto &use : F.uses()) {
196+
// Use must be a call site
197+
SmallVector<Value *, 10> Params;
198+
auto Call = cast<CallInst>(use.getUser());
199+
// Variadic functions not supported
200+
assert(!Call->getFunction()->isVarArg() &&
201+
"Variadic functions not supported");
202+
for (unsigned int I = 0; I < Call->getNumArgOperands(); I++) {
203+
auto SrcOpnd = Call->getOperand(I);
204+
auto NewTy = getSimdArgPtrTyOrNull(SrcOpnd);
205+
if (NewTy) {
206+
auto BitCast = new BitCastInst(SrcOpnd, NewTy, "", Call);
207+
Params.push_back(BitCast);
208+
} else {
209+
if (SrcOpnd != &F)
210+
Params.push_back(SrcOpnd);
211+
else
212+
Params.push_back(NF);
213+
}
214+
}
215+
// create new call instruction
216+
auto NewCallInst = CallInst::Create(NFTy, NF, Params, "");
217+
NewCallInst->setCallingConv(F.getCallingConv());
218+
OldNewInst.push_back(std::make_pair(Call, NewCallInst));
219+
}
220+
221+
for (auto InstPair : OldNewInst) {
222+
auto OldInst = InstPair.first;
223+
auto NewInst = InstPair.second;
224+
ReplaceInstWithInst(OldInst, NewInst);
225+
}
226+
227+
F.eraseFromParent();
228+
229+
return NF;
230+
}
231+
232+
// Replace ConstantExpr if it contains old global variable.
233+
ConstantExpr *
234+
ESIMDLowerVecArgPass::createNewConstantExpr(GlobalVariable *NewGlobalVar,
235+
Type *OldGlobalType, Value *Old) {
236+
ConstantExpr *NewConstantExpr = nullptr;
237+
238+
if (isa<GlobalVariable>(Old)) {
239+
NewConstantExpr = cast<ConstantExpr>(
240+
ConstantExpr::getBitCast(NewGlobalVar, OldGlobalType));
241+
return NewConstantExpr;
242+
}
243+
244+
auto InnerMost = createNewConstantExpr(
245+
NewGlobalVar, OldGlobalType, cast<ConstantExpr>(Old)->getOperand(0));
246+
247+
NewConstantExpr = cast<ConstantExpr>(
248+
cast<ConstantExpr>(Old)->getWithOperandReplaced(0, InnerMost));
249+
250+
return NewConstantExpr;
251+
}
252+
253+
// Globals are part of ConstantExpr. This loop iterates over
254+
// all such instances and replaces them with a new ConstantExpr
255+
// consisting of new global vector* variable.
256+
void ESIMDLowerVecArgPass::replaceConstExprWithGlobals(Module &M) {
257+
for (auto &GlobalVars : OldNewGlobal) {
258+
auto &G = *GlobalVars.first;
259+
for (auto UseOfG : G.users()) {
260+
auto NewGlobal = GlobalVars.second;
261+
auto NewConstExpr = createNewConstantExpr(NewGlobal, G.getType(), UseOfG);
262+
UseOfG->replaceAllUsesWith(NewConstExpr);
263+
}
264+
}
265+
}
266+
267+
// This function creates new global variables of type vector* type
268+
// when old one is of simd* type.
269+
void ESIMDLowerVecArgPass::fixGlobals(Module &M) {
270+
for (auto &G : M.getGlobalList()) {
271+
auto NewTy = getSimdArgPtrTyOrNull(&G);
272+
if (NewTy && !G.user_empty()) {
273+
// Peel off ptr type that getSimdArgPtrTyOrNull applies
274+
NewTy = NewTy->getPointerElementType();
275+
auto ZeroInit = ConstantAggregateZero::get(NewTy);
276+
auto NewGlobalVar =
277+
new GlobalVariable(NewTy, G.isConstant(), G.getLinkage(), ZeroInit,
278+
"", G.getThreadLocalMode(), G.getAddressSpace());
279+
NewGlobalVar->setExternallyInitialized(G.isExternallyInitialized());
280+
NewGlobalVar->copyAttributesFrom(&G);
281+
NewGlobalVar->takeName(&G);
282+
NewGlobalVar->copyMetadata(&G, 0);
283+
M.getGlobalList().push_back(NewGlobalVar);
284+
OldNewGlobal.insert(std::make_pair(&G, NewGlobalVar));
285+
}
286+
}
287+
288+
replaceConstExprWithGlobals(M);
289+
290+
removeOldGlobals();
291+
}
292+
293+
// Remove old global variables from the program.
294+
void ESIMDLowerVecArgPass::removeOldGlobals() {
295+
for (auto &G : OldNewGlobal) {
296+
G.first->removeDeadConstantUsers();
297+
G.first->eraseFromParent();
298+
}
299+
}
300+
301+
bool ESIMDLowerVecArgPass::run(Module &M) {
302+
fixGlobals(M);
303+
304+
SmallVector<Function *, 10> functions;
305+
for (auto &F : M) {
306+
functions.push_back(&F);
307+
}
308+
309+
for (auto F : functions) {
310+
for (unsigned int I = 0; I != F->arg_size(); I++) {
311+
auto Arg = F->getArg(I);
312+
if (getSimdArgPtrTyOrNull(Arg)) {
313+
rewriteFunc(*F);
314+
break;
315+
}
316+
}
317+
}
318+
319+
return true;
320+
}

0 commit comments

Comments
 (0)