Skip to content

Commit ea5ef3f

Browse files
authored
Fix VectorExtract&Insert Dynamic with sycl half translation (#1387)
This is a patch to fix forward translation of VectorExtractDynamic with sret argument of SYCL half type & fix reverse translation of VectorInsertDynamic with component of SYCL half type by adding extra instructions.
1 parent 7f168b2 commit ea5ef3f

File tree

4 files changed

+269
-0
lines changed

4 files changed

+269
-0
lines changed

lib/SPIRV/SPIRVInternal.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,8 @@ bool isOCLImageType(llvm::Type *Ty, StringRef *Name = nullptr);
662662
/// type name as spirv.BaseTyName.Postfixes.
663663
bool isSPIRVType(llvm::Type *Ty, StringRef BaseTyName, StringRef *Postfix = 0);
664664

665+
bool isSYCLHalfType(llvm::Type *Ty);
666+
665667
/// Decorate a function name as __spirv_{Name}_
666668
std::string decorateSPIRVFunction(const std::string &S);
667669

@@ -766,6 +768,19 @@ void mutateFunction(
766768
BuiltinFuncMangleInfo *Mangle = nullptr, AttributeList *Attrs = nullptr,
767769
bool TakeName = true);
768770

771+
/// Mutate function by change the arguments & the return type.
772+
/// \param ArgMutate mutates the function arguments.
773+
/// \param RetMutate mutates the function return value.
774+
/// \param TakeName Take the original function's name if a new function with
775+
/// different type needs to be created.
776+
void mutateFunction(
777+
Function *F,
778+
std::function<std::string(CallInst *, std::vector<Value *> &, Type *&RetTy)>
779+
ArgMutate,
780+
std::function<Instruction *(CallInst *)> RetMutate,
781+
BuiltinFuncMangleInfo *Mangle = nullptr, AttributeList *Attrs = nullptr,
782+
bool TakeName = true);
783+
769784
/// Add a call instruction at \p Pos.
770785
CallInst *addCallInst(Module *M, StringRef FuncName, Type *RetTy,
771786
ArrayRef<Value *> Args, AttributeList *Attrs,

lib/SPIRV/SPIRVRegularizeLLVM.cpp

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,20 @@ class SPIRVRegularizeLLVMBase {
109109
void lowerUMulWithOverflow(IntrinsicInst *UMulIntrinsic);
110110
void buildUMulWithOverflowFunc(Function *UMulFunc);
111111

112+
// For some cases Clang emits VectorExtractDynamic as:
113+
// void @_Z28__spirv_VectorExtractDynamic(<Ty>* sret(<Ty>), jointMatrix, idx);
114+
// Instead of:
115+
// <Ty> @_Z28__spirv_VectorExtractDynamic(JointMatrix, Idx);
116+
// And VectorInsertDynamic as:
117+
// @_Z27__spirv_VectorInsertDynamic(jointMatrix, <Ty>* byval(<Ty>), idx);
118+
// Instead of:
119+
// @_Z27__spirv_VectorInsertDynamic(jointMatrix, <Ty>, idx)
120+
// Need to add additional GEP, store and load instructions and mutate called
121+
// function to avoid translation failures
122+
void expandSYCLHalfUsing(Module *M);
123+
void expandVEDWithSYCLHalfSRetArg(Function *F);
124+
void expandVIDWithSYCLHalfByValComp(Function *F);
125+
112126
static std::string lowerLLVMIntrinsicName(IntrinsicInst *II);
113127
void adaptStructTypes(StructType *ST);
114128
static char ID;
@@ -326,6 +340,83 @@ void SPIRVRegularizeLLVMBase::lowerUMulWithOverflow(
326340
UMulIntrinsic->setCalledFunction(UMulFunc);
327341
}
328342

343+
void SPIRVRegularizeLLVMBase::expandVEDWithSYCLHalfSRetArg(Function *F) {
344+
auto Attrs = F->getAttributes();
345+
Attrs = Attrs.removeParamAttribute(F->getContext(), 0, Attribute::StructRet);
346+
std::string Name = F->getName().str();
347+
CallInst *OldCall = nullptr;
348+
mutateFunction(
349+
F,
350+
[=, &OldCall](CallInst *CI, std::vector<Value *> &Args, Type *&RetTy) {
351+
Args.erase(Args.begin());
352+
auto *SRetPtrTy = cast<PointerType>(CI->getOperand(0)->getType());
353+
auto *ET = SRetPtrTy->getPointerElementType();
354+
RetTy = cast<StructType>(ET)->getElementType(0);
355+
OldCall = CI;
356+
return Name;
357+
},
358+
[=, &OldCall](CallInst *NewCI) {
359+
IRBuilder<> Builder(OldCall);
360+
auto *SRetPtrTy = cast<PointerType>(OldCall->getOperand(0)->getType());
361+
auto *ET = SRetPtrTy->getPointerElementType();
362+
Value *Target = Builder.CreateStructGEP(ET, OldCall->getOperand(0), 0);
363+
return Builder.CreateStore(NewCI, Target);
364+
},
365+
nullptr, &Attrs, true);
366+
}
367+
368+
void SPIRVRegularizeLLVMBase::expandVIDWithSYCLHalfByValComp(Function *F) {
369+
auto Attrs = F->getAttributes();
370+
Attrs = Attrs.removeParamAttribute(F->getContext(), 1, Attribute::ByVal);
371+
std::string Name = F->getName().str();
372+
mutateFunction(
373+
F,
374+
[=](CallInst *CI, std::vector<Value *> &Args) {
375+
auto *CompPtrTy = cast<PointerType>(CI->getOperand(1)->getType());
376+
auto *ET = CompPtrTy->getPointerElementType();
377+
Type *HalfTy = cast<StructType>(ET)->getElementType(0);
378+
IRBuilder<> Builder(CI);
379+
auto *Target = Builder.CreateStructGEP(ET, CI->getOperand(1), 0);
380+
Args[1] = Builder.CreateLoad(HalfTy, Target);
381+
return Name;
382+
},
383+
nullptr, &Attrs, true);
384+
}
385+
386+
void SPIRVRegularizeLLVMBase::expandSYCLHalfUsing(Module *M) {
387+
std::vector<Function *> ToExpandVEDWithSYCLHalfSRetArg;
388+
std::vector<Function *> ToExpandVIDWithSYCLHalfByValComp;
389+
390+
for (auto &F : *M) {
391+
if (F.getName().startswith("_Z28__spirv_VectorExtractDynamic") &&
392+
F.hasStructRetAttr()) {
393+
auto *SRetPtrTy = cast<PointerType>(F.getArg(0)->getType());
394+
if (isSYCLHalfType(SRetPtrTy->getPointerElementType()))
395+
ToExpandVEDWithSYCLHalfSRetArg.push_back(&F);
396+
else
397+
llvm_unreachable("The return type of the VectorExtractDynamic "
398+
"instruction cannot be a structure other than SYCL "
399+
"half.");
400+
}
401+
if (F.getName().startswith("_Z27__spirv_VectorInsertDynamic") &&
402+
F.getArg(1)->getType()->isPointerTy()) {
403+
auto *CompPtrTy = cast<PointerType>(F.getArg(1)->getType());
404+
auto *ET = CompPtrTy->getPointerElementType();
405+
if (isSYCLHalfType(ET))
406+
ToExpandVIDWithSYCLHalfByValComp.push_back(&F);
407+
else
408+
llvm_unreachable("The component argument type of an "
409+
"VectorInsertDynamic instruction can't be a "
410+
"structure other than SYCL half.");
411+
}
412+
}
413+
414+
for (auto *F : ToExpandVEDWithSYCLHalfSRetArg)
415+
expandVEDWithSYCLHalfSRetArg(F);
416+
for (auto *F : ToExpandVIDWithSYCLHalfByValComp)
417+
expandVIDWithSYCLHalfByValComp(F);
418+
}
419+
329420
void SPIRVRegularizeLLVMBase::adaptStructTypes(StructType *ST) {
330421
if (!ST->hasName())
331422
return;
@@ -440,6 +531,7 @@ bool SPIRVRegularizeLLVMBase::regularize() {
440531
eraseUselessFunctions(M);
441532
lowerFuncPtr(M);
442533
addKernelEntryPoint(M);
534+
expandSYCLHalfUsing(M);
443535

444536
for (auto I = M->begin(), E = M->end(); I != E;) {
445537
Function *F = &(*I++);

lib/SPIRV/SPIRVUtil.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,21 @@ bool isSPIRVType(llvm::Type *Ty, StringRef BaseTyName, StringRef *Postfix) {
310310
return false;
311311
}
312312

313+
bool isSYCLHalfType(llvm::Type *Ty) {
314+
if (auto *ST = dyn_cast<StructType>(Ty)) {
315+
if (!ST->hasName())
316+
return false;
317+
StringRef Name = ST->getName();
318+
Name.consume_front("class.");
319+
if ((Name.startswith("cl::sycl::") ||
320+
Name.startswith("__sycl_internal::")) &&
321+
Name.endswith("::half")) {
322+
return true;
323+
}
324+
}
325+
return false;
326+
}
327+
313328
Function *getOrCreateFunction(Module *M, Type *RetTy, ArrayRef<Type *> ArgTypes,
314329
StringRef Name, BuiltinFuncMangleInfo *Mangle,
315330
AttributeList *Attrs, bool TakeName) {
@@ -344,6 +359,8 @@ Function *getOrCreateFunction(Module *M, Type *RetTy, ArrayRef<Type *> ArgTypes,
344359
}
345360
LLVM_DEBUG(dbgs() << "[getOrCreateFunction] ";
346361
if (F) dbgs() << *F << " => "; dbgs() << *NewF << '\n';);
362+
if (F)
363+
NewF->setDSOLocal(F->isDSOLocal());
347364
F = NewF;
348365
F->setCallingConv(CallingConv::SPIR_FUNC);
349366
if (Attrs)
@@ -703,6 +720,21 @@ void mutateFunction(
703720
F->eraseFromParent();
704721
}
705722

723+
void mutateFunction(
724+
Function *F,
725+
std::function<std::string(CallInst *, std::vector<Value *> &, Type *&RetTy)>
726+
ArgMutate,
727+
std::function<Instruction *(CallInst *)> RetMutate,
728+
BuiltinFuncMangleInfo *Mangle, AttributeList *Attrs, bool TakeName) {
729+
auto *M = F->getParent();
730+
for (auto I = F->user_begin(), E = F->user_end(); I != E;) {
731+
if (auto *CI = dyn_cast<CallInst>(*I++))
732+
mutateCallInst(M, CI, ArgMutate, RetMutate, Mangle, Attrs, TakeName);
733+
}
734+
if (F->use_empty())
735+
F->eraseFromParent();
736+
}
737+
706738
CallInst *mutateCallInstSPIRV(
707739
Module *M, CallInst *CI,
708740
std::function<std::string(CallInst *, std::vector<Value *> &)> ArgMutate,

0 commit comments

Comments
 (0)