@@ -109,6 +109,20 @@ class SPIRVRegularizeLLVMBase {
109
109
void lowerUMulWithOverflow (IntrinsicInst *UMulIntrinsic);
110
110
void buildUMulWithOverflowFunc (Function *UMulFunc);
111
111
112
+ // For some cases Clang emits VectorExtractDynamic as:
113
+ // void @_Z28__spirv_VectorExtractDynamic(<Ty>* sret(<Ty>), jointMatrix, idx);
114
+ // Instead of:
115
+ // <Ty> @_Z28__spirv_VectorExtractDynamic(JointMatrix, Idx);
116
+ // And VectorInsertDynamic as:
117
+ // @_Z27__spirv_VectorInsertDynamic(jointMatrix, <Ty>* byval(<Ty>), idx);
118
+ // Instead of:
119
+ // @_Z27__spirv_VectorInsertDynamic(jointMatrix, <Ty>, idx)
120
+ // Need to add additional GEP, store and load instructions and mutate called
121
+ // function to avoid translation failures
122
+ void expandSYCLHalfUsing (Module *M);
123
+ void expandVEDWithSYCLHalfSRetArg (Function *F);
124
+ void expandVIDWithSYCLHalfByValComp (Function *F);
125
+
112
126
static std::string lowerLLVMIntrinsicName (IntrinsicInst *II);
113
127
void adaptStructTypes (StructType *ST);
114
128
static char ID;
@@ -326,6 +340,83 @@ void SPIRVRegularizeLLVMBase::lowerUMulWithOverflow(
326
340
UMulIntrinsic->setCalledFunction (UMulFunc);
327
341
}
328
342
343
+ void SPIRVRegularizeLLVMBase::expandVEDWithSYCLHalfSRetArg (Function *F) {
344
+ auto Attrs = F->getAttributes ();
345
+ Attrs = Attrs.removeParamAttribute (F->getContext (), 0 , Attribute::StructRet);
346
+ std::string Name = F->getName ().str ();
347
+ CallInst *OldCall = nullptr ;
348
+ mutateFunction (
349
+ F,
350
+ [=, &OldCall](CallInst *CI, std::vector<Value *> &Args, Type *&RetTy) {
351
+ Args.erase (Args.begin ());
352
+ auto *SRetPtrTy = cast<PointerType>(CI->getOperand (0 )->getType ());
353
+ auto *ET = SRetPtrTy->getPointerElementType ();
354
+ RetTy = cast<StructType>(ET)->getElementType (0 );
355
+ OldCall = CI;
356
+ return Name;
357
+ },
358
+ [=, &OldCall](CallInst *NewCI) {
359
+ IRBuilder<> Builder (OldCall);
360
+ auto *SRetPtrTy = cast<PointerType>(OldCall->getOperand (0 )->getType ());
361
+ auto *ET = SRetPtrTy->getPointerElementType ();
362
+ Value *Target = Builder.CreateStructGEP (ET, OldCall->getOperand (0 ), 0 );
363
+ return Builder.CreateStore (NewCI, Target);
364
+ },
365
+ nullptr , &Attrs, true );
366
+ }
367
+
368
+ void SPIRVRegularizeLLVMBase::expandVIDWithSYCLHalfByValComp (Function *F) {
369
+ auto Attrs = F->getAttributes ();
370
+ Attrs = Attrs.removeParamAttribute (F->getContext (), 1 , Attribute::ByVal);
371
+ std::string Name = F->getName ().str ();
372
+ mutateFunction (
373
+ F,
374
+ [=](CallInst *CI, std::vector<Value *> &Args) {
375
+ auto *CompPtrTy = cast<PointerType>(CI->getOperand (1 )->getType ());
376
+ auto *ET = CompPtrTy->getPointerElementType ();
377
+ Type *HalfTy = cast<StructType>(ET)->getElementType (0 );
378
+ IRBuilder<> Builder (CI);
379
+ auto *Target = Builder.CreateStructGEP (ET, CI->getOperand (1 ), 0 );
380
+ Args[1 ] = Builder.CreateLoad (HalfTy, Target);
381
+ return Name;
382
+ },
383
+ nullptr , &Attrs, true );
384
+ }
385
+
386
+ void SPIRVRegularizeLLVMBase::expandSYCLHalfUsing (Module *M) {
387
+ std::vector<Function *> ToExpandVEDWithSYCLHalfSRetArg;
388
+ std::vector<Function *> ToExpandVIDWithSYCLHalfByValComp;
389
+
390
+ for (auto &F : *M) {
391
+ if (F.getName ().startswith (" _Z28__spirv_VectorExtractDynamic" ) &&
392
+ F.hasStructRetAttr ()) {
393
+ auto *SRetPtrTy = cast<PointerType>(F.getArg (0 )->getType ());
394
+ if (isSYCLHalfType (SRetPtrTy->getPointerElementType ()))
395
+ ToExpandVEDWithSYCLHalfSRetArg.push_back (&F);
396
+ else
397
+ llvm_unreachable (" The return type of the VectorExtractDynamic "
398
+ " instruction cannot be a structure other than SYCL "
399
+ " half." );
400
+ }
401
+ if (F.getName ().startswith (" _Z27__spirv_VectorInsertDynamic" ) &&
402
+ F.getArg (1 )->getType ()->isPointerTy ()) {
403
+ auto *CompPtrTy = cast<PointerType>(F.getArg (1 )->getType ());
404
+ auto *ET = CompPtrTy->getPointerElementType ();
405
+ if (isSYCLHalfType (ET))
406
+ ToExpandVIDWithSYCLHalfByValComp.push_back (&F);
407
+ else
408
+ llvm_unreachable (" The component argument type of an "
409
+ " VectorInsertDynamic instruction can't be a "
410
+ " structure other than SYCL half." );
411
+ }
412
+ }
413
+
414
+ for (auto *F : ToExpandVEDWithSYCLHalfSRetArg)
415
+ expandVEDWithSYCLHalfSRetArg (F);
416
+ for (auto *F : ToExpandVIDWithSYCLHalfByValComp)
417
+ expandVIDWithSYCLHalfByValComp (F);
418
+ }
419
+
329
420
void SPIRVRegularizeLLVMBase::adaptStructTypes (StructType *ST) {
330
421
if (!ST->hasName ())
331
422
return ;
@@ -440,6 +531,7 @@ bool SPIRVRegularizeLLVMBase::regularize() {
440
531
eraseUselessFunctions (M);
441
532
lowerFuncPtr (M);
442
533
addKernelEntryPoint (M);
534
+ expandSYCLHalfUsing (M);
443
535
444
536
for (auto I = M->begin (), E = M->end (); I != E;) {
445
537
Function *F = &(*I++);
0 commit comments