Skip to content

Commit 297166f

Browse files
[SYCL] [ESIMD] Remove one of the uses of __SYCL_EXPLICIT_SIMD__ (#3311)
This patch is part of the effort to allow ESIMD and regular SYCL kernels to coexist in the same translation unit and in the same program. Previously, ESIMD device code contained calls to SPIRV intrinsics that had no definitions. With the change in spirv_vars.hpp, SYCL optimization passes convert calls to SPIRV intrinsics into loads from globals (SPIRV builtins). Thus, the LowerESIMD pass has to be changed to lower these new constructs. Example:
// @__spirv_BuiltInLocalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
// ...
// %0 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast
//   (<3 x i64> addrspace(1)* @__spirv_BuiltInLocalInvocationId
//   to <3 x i64> addrspace(4)*), align 32
// %1 = extractelement <3 x i64> %0, i64 0
//
// =>
//
// %.esimd = call <3 x i32> @llvm.genx.local.id.v3i32()
// %local_id.x = extractelement <3 x i32> %.esimd, i32 0
// %local_id.x.cast.ty = zext i32 %local_id.x to i64
The existing tests in sycl/test/esimd/spirv_intrins_trans.cpp check that there is no regression in how SPIRV intrinsics are lowered into their GenX counterparts. I also added some more tests.
1 parent 2e9d33c commit 297166f

File tree

6 files changed

+290
-171
lines changed

6 files changed

+290
-171
lines changed

llvm/lib/SYCLLowerIR/LowerESIMD.cpp

Lines changed: 169 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ namespace {
7979
// /^_Z(\d+)__esimd_\w+/
8080
static constexpr char ESIMD_INTRIN_PREF0[] = "_Z";
8181
static constexpr char ESIMD_INTRIN_PREF1[] = "__esimd_";
82-
static constexpr char SPIRV_INTRIN_PREF[] = "__spirv_";
82+
static constexpr char SPIRV_INTRIN_PREF[] = "__spirv_BuiltIn";
8383

8484
static constexpr char GENX_KERNEL_METADATA[] = "genx.kernels";
8585

@@ -778,108 +778,122 @@ static int getIndexForSuffix(StringRef Suff) {
778778
.Default(-1);
779779
}
780780

781-
// Helper function to convert SPIRV intrinsic into GenX intrinsic,
782-
// that returns vector of coordinates.
783-
// Example:
784-
// %call = call spir_func i64 @_Z23__spirv_WorkgroupSize_xv()
785-
// =>
786-
// %call.esimd = tail call <3 x i32> @llvm.genx.local.size.v3i32()
787-
// %wgsize.x = extractelement <3 x i32> %call.esimd, i32 0
788-
// %wgsize.x.cast.ty = zext i32 %wgsize.x to i64
789-
static Instruction *generateVectorGenXForSpirv(CallInst &CI, StringRef Suff,
781+
// Helper function to convert extractelement instruction associated with the
782+
// load from SPIRV builtin global, into the GenX intrinsic that returns vector
783+
// of coordinates. It also generates required extractelement and cast
784+
// instructions. Example:
785+
// %0 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast
786+
// (<3 x i64> addrspace(1)* @__spirv_BuiltInLocalInvocationId
787+
// to <3 x i64> addrspace(4)*), align 32
788+
// %1 = extractelement <3 x i64> %0, i64 0
789+
//
790+
// =>
791+
//
792+
// %.esimd = call <3 x i32> @llvm.genx.local.id.v3i32()
793+
// %local_id.x = extractelement <3 x i32> %.esimd, i32 0
794+
// %local_id.x.cast.ty = zext i32 %local_id.x to i64
795+
static Instruction *generateVectorGenXForSpirv(ExtractElementInst *EEI,
796+
StringRef Suff,
790797
const std::string &IntrinName,
791798
StringRef ValueName) {
792799
std::string IntrName =
793800
std::string(GenXIntrinsic::getGenXIntrinsicPrefix()) + IntrinName;
794801
auto ID = GenXIntrinsic::lookupGenXIntrinsicID(IntrName);
795-
LLVMContext &Ctx = CI.getModule()->getContext();
802+
LLVMContext &Ctx = EEI->getModule()->getContext();
796803
Type *I32Ty = Type::getInt32Ty(Ctx);
797804
Function *NewFDecl = GenXIntrinsic::getGenXDeclaration(
798-
CI.getModule(), ID, {FixedVectorType::get(I32Ty, 3)});
805+
EEI->getModule(), ID, {FixedVectorType::get(I32Ty, 3)});
799806
Instruction *IntrI =
800-
IntrinsicInst::Create(NewFDecl, {}, CI.getName() + ".esimd", &CI);
807+
IntrinsicInst::Create(NewFDecl, {}, EEI->getName() + ".esimd", EEI);
801808
int ExtractIndex = getIndexForSuffix(Suff);
802809
assert(ExtractIndex != -1 && "Extract index is invalid.");
803810
Twine ExtractName = ValueName + Suff;
811+
804812
Instruction *ExtrI = ExtractElementInst::Create(
805-
IntrI, ConstantInt::get(I32Ty, ExtractIndex), ExtractName, &CI);
806-
Instruction *CastI = addCastInstIfNeeded(&CI, ExtrI);
813+
IntrI, ConstantInt::get(I32Ty, ExtractIndex), ExtractName, EEI);
814+
Instruction *CastI = addCastInstIfNeeded(EEI, ExtrI);
807815
return CastI;
808816
}
809817

810-
// Helper function to convert SPIRV intrinsic into GenX intrinsic,
811-
// that has exact mapping.
812-
// Example:
813-
// %call = call spir_func i64 @_Z21__spirv_WorkgroupId_xv()
814-
// =>
815-
// %group.id.x = tail call i32 @llvm.genx.group.id.x()
816-
// %group.id.x.cast.ty = zext i32 %group.id.x to i64
817-
static Instruction *generateGenXForSpirv(CallInst &CI, StringRef Suff,
818+
// Helper function to convert extractelement instruction associated with the
819+
// load from SPIRV builtin global, into the GenX intrinsic. It also generates
820+
// required cast instructions. Example:
821+
// %0 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast (<3 x i64>
822+
// addrspace(1)* @__spirv_BuiltInWorkgroupId to <3 x i64> addrspace(4)*), align 32
823+
// %1 = extractelement <3 x i64> %0, i64 0
824+
// =>
825+
// %0 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast (<3 x i64>
826+
// addrspace(1)* @__spirv_BuiltInWorkgroupId to <3 x i64> addrspace(4)*), align 32
827+
// %group.id.x = call i32 @llvm.genx.group.id.x()
828+
// %group.id.x.cast.ty = zext i32 %group.id.x to i64
829+
static Instruction *generateGenXForSpirv(ExtractElementInst *EEI,
830+
StringRef Suff,
818831
const std::string &IntrinName) {
819832
std::string IntrName = std::string(GenXIntrinsic::getGenXIntrinsicPrefix()) +
820833
IntrinName + Suff.str();
821834
auto ID = GenXIntrinsic::lookupGenXIntrinsicID(IntrName);
822835
Function *NewFDecl =
823-
GenXIntrinsic::getGenXDeclaration(CI.getModule(), ID, {});
836+
GenXIntrinsic::getGenXDeclaration(EEI->getModule(), ID, {});
837+
824838
Instruction *IntrI =
825-
IntrinsicInst::Create(NewFDecl, {}, IntrinName + Suff.str(), &CI);
826-
Instruction *CastI = addCastInstIfNeeded(&CI, IntrI);
839+
IntrinsicInst::Create(NewFDecl, {}, IntrinName + Suff.str(), EEI);
840+
Instruction *CastI = addCastInstIfNeeded(EEI, IntrI);
827841
return CastI;
828842
}
829843

830-
// This function translates SPIRV intrinsic into GenX intrinsic.
831-
// TODO: Currently, we do not support mixing SYCL and ESIMD kernels.
832-
// Later for ESIMD and SYCL kernels to coexist, we likely need to
833-
// clone call graph that lead from ESIMD kernel to SPIRV intrinsic and
834-
// translate SPIRV intrinsics to GenX intrinsics only in cloned subgraph.
835-
static void
836-
translateSpirvIntrinsic(CallInst *CI, StringRef SpirvIntrName,
837-
SmallVector<Instruction *, 8> &ESIMDToErases) {
838-
auto translateSpirvIntr = [&SpirvIntrName, &ESIMDToErases,
839-
CI](StringRef SpvIName, auto TranslateFunc) {
840-
if (SpirvIntrName.consume_front(SpvIName)) {
841-
Value *TranslatedV = TranslateFunc(*CI, SpirvIntrName.substr(1, 1));
842-
CI->replaceAllUsesWith(TranslatedV);
843-
ESIMDToErases.push_back(CI);
844-
}
845-
};
844+
// This function translates one occurrence of SPIRV builtin use into GenX
845+
// intrinsic.
846+
static Value *translateSpirvGlobalUse(ExtractElementInst *EEI,
847+
StringRef SpirvGlobalName) {
848+
Value *IndexV = EEI->getIndexOperand();
849+
assert(isa<ConstantInt>(IndexV) &&
850+
"Extract element index should be a constant");
846851

847-
translateSpirvIntr("WorkgroupSize", [](CallInst &CI, StringRef Suff) {
848-
return generateVectorGenXForSpirv(CI, Suff, "local.size.v3i32", "wgsize.");
849-
});
850-
translateSpirvIntr("LocalInvocationId", [](CallInst &CI, StringRef Suff) {
851-
return generateVectorGenXForSpirv(CI, Suff, "local.id.v3i32", "local_id.");
852-
});
853-
translateSpirvIntr("WorkgroupId", [](CallInst &CI, StringRef Suff) {
854-
return generateGenXForSpirv(CI, Suff, "group.id.");
855-
});
856-
translateSpirvIntr("GlobalInvocationId", [](CallInst &CI, StringRef Suff) {
852+
// Get the suffix based on the index of extractelement instruction
853+
ConstantInt *IndexC = cast<ConstantInt>(IndexV);
854+
std::string Suff;
855+
if (IndexC->equalsInt(0))
856+
Suff = 'x';
857+
else if (IndexC->equalsInt(1))
858+
Suff = 'y';
859+
else if (IndexC->equalsInt(2))
860+
Suff = 'z';
861+
else
862+
assert(false && "Extract element index should be either 0, 1, or 2");
863+
864+
// Translate SPIRV into GenX intrinsic.
865+
if (SpirvGlobalName == "WorkgroupSize") {
866+
return generateVectorGenXForSpirv(EEI, Suff, "local.size.v3i32", "wgsize.");
867+
} else if (SpirvGlobalName == "LocalInvocationId") {
868+
return generateVectorGenXForSpirv(EEI, Suff, "local.id.v3i32", "local_id.");
869+
} else if (SpirvGlobalName == "WorkgroupId") {
870+
return generateGenXForSpirv(EEI, Suff, "group.id.");
871+
} else if (SpirvGlobalName == "GlobalInvocationId") {
857872
// GlobalId = LocalId + WorkGroupSize * GroupId
858873
Instruction *LocalIdI =
859-
generateVectorGenXForSpirv(CI, Suff, "local.id.v3i32", "local_id.");
874+
generateVectorGenXForSpirv(EEI, Suff, "local.id.v3i32", "local_id.");
860875
Instruction *WGSizeI =
861-
generateVectorGenXForSpirv(CI, Suff, "local.size.v3i32", "wgsize.");
862-
Instruction *GroupIdI = generateGenXForSpirv(CI, Suff, "group.id.");
876+
generateVectorGenXForSpirv(EEI, Suff, "local.size.v3i32", "wgsize.");
877+
Instruction *GroupIdI = generateGenXForSpirv(EEI, Suff, "group.id.");
863878
Instruction *MulI =
864-
BinaryOperator::CreateMul(WGSizeI, GroupIdI, "mul", &CI);
865-
return BinaryOperator::CreateAdd(LocalIdI, MulI, "add", &CI);
866-
});
867-
translateSpirvIntr("GlobalSize", [](CallInst &CI, StringRef Suff) {
879+
BinaryOperator::CreateMul(WGSizeI, GroupIdI, "mul", EEI);
880+
return BinaryOperator::CreateAdd(LocalIdI, MulI, "add", EEI);
881+
} else if (SpirvGlobalName == "GlobalSize") {
868882
// GlobalSize = WorkGroupSize * NumWorkGroups
869883
Instruction *WGSizeI =
870-
generateVectorGenXForSpirv(CI, Suff, "local.size.v3i32", "wgsize.");
884+
generateVectorGenXForSpirv(EEI, Suff, "local.size.v3i32", "wgsize.");
871885
Instruction *NumWGI = generateVectorGenXForSpirv(
872-
CI, Suff, "group.count.v3i32", "group_count.");
873-
return BinaryOperator::CreateMul(WGSizeI, NumWGI, "mul", &CI);
874-
});
875-
// TODO: Support GlobalOffset SPIRV intrinsics
876-
translateSpirvIntr("GlobalOffset", [](CallInst &CI, StringRef Suff) {
877-
return llvm::Constant::getNullValue(CI.getType());
878-
});
879-
translateSpirvIntr("NumWorkgroups", [](CallInst &CI, StringRef Suff) {
880-
return generateVectorGenXForSpirv(CI, Suff, "group.count.v3i32",
886+
EEI, Suff, "group.count.v3i32", "group_count.");
887+
return BinaryOperator::CreateMul(WGSizeI, NumWGI, "mul", EEI);
888+
} else if (SpirvGlobalName == "GlobalOffset") {
889+
// TODO: Support GlobalOffset SPIRV intrinsics
890+
return llvm::Constant::getNullValue(EEI->getType());
891+
} else if (SpirvGlobalName == "NumWorkgroups") {
892+
return generateVectorGenXForSpirv(EEI, Suff, "group.count.v3i32",
881893
"group_count.");
882-
});
894+
}
895+
896+
return nullptr;
883897
}
884898

885899
static void createESIMDIntrinsicArgs(const ESIMDIntrinDesc &Desc,
@@ -1280,68 +1294,102 @@ size_t SYCLLowerESIMDPass::runOnFunction(Function &F,
12801294

12811295
auto *CI = dyn_cast<CallInst>(&I);
12821296
Function *Callee = nullptr;
1283-
if (!CI || !(Callee = CI->getCalledFunction()))
1284-
continue;
1285-
StringRef Name = Callee->getName();
1297+
if (CI && (Callee = CI->getCalledFunction())) {
12861298

1287-
// See if the Name represents an ESIMD intrinsic and demangle only if it
1288-
// does.
1289-
if (!Name.consume_front(ESIMD_INTRIN_PREF0))
1290-
continue;
1291-
// now skip the digits
1292-
Name = Name.drop_while([](char C) { return std::isdigit(C); });
1293-
1294-
// process ESIMD builtins that go through special handling instead of
1295-
// the translation procedure
1296-
if (Name.startswith("N2cl4sycl5INTEL3gpu8slm_init")) {
1297-
// tag the kernel with meta-data SLMSize, and remove this builtin
1298-
translateSLMInit(*CI);
1299-
ESIMDToErases.push_back(CI);
1300-
continue;
1301-
}
1302-
if (Name.startswith("__esimd_pack_mask")) {
1303-
translatePackMask(*CI);
1304-
ESIMDToErases.push_back(CI);
1305-
continue;
1306-
}
1307-
if (Name.startswith("__esimd_unpack_mask")) {
1308-
translateUnPackMask(*CI);
1309-
ESIMDToErases.push_back(CI);
1310-
continue;
1311-
}
1312-
// If vload/vstore is not about the vector-types used by
1313-
// those globals marked as genx_volatile, We can translate
1314-
// them directly into generic load/store inst. In this way
1315-
// those insts can be optimized by llvm ASAP.
1316-
if (Name.startswith("__esimd_vload")) {
1317-
if (translateVLoad(*CI, GVTS)) {
1299+
StringRef Name = Callee->getName();
1300+
1301+
// See if the Name represents an ESIMD intrinsic and demangle only if it
1302+
// does.
1303+
if (!Name.consume_front(ESIMD_INTRIN_PREF0))
1304+
continue;
1305+
// now skip the digits
1306+
Name = Name.drop_while([](char C) { return std::isdigit(C); });
1307+
1308+
// process ESIMD builtins that go through special handling instead of
1309+
// the translation procedure
1310+
if (Name.startswith("N2cl4sycl5INTEL3gpu8slm_init")) {
1311+
// tag the kernel with meta-data SLMSize, and remove this builtin
1312+
translateSLMInit(*CI);
13181313
ESIMDToErases.push_back(CI);
13191314
continue;
13201315
}
1321-
}
1322-
if (Name.startswith("__esimd_vstore")) {
1323-
if (translateVStore(*CI, GVTS)) {
1316+
if (Name.startswith("__esimd_pack_mask")) {
1317+
translatePackMask(*CI);
13241318
ESIMDToErases.push_back(CI);
13251319
continue;
13261320
}
1327-
}
1321+
if (Name.startswith("__esimd_unpack_mask")) {
1322+
translateUnPackMask(*CI);
1323+
ESIMDToErases.push_back(CI);
1324+
continue;
1325+
}
1326+
// If vload/vstore is not about the vector-types used by
1327+
// those globals marked as genx_volatile, We can translate
1328+
// them directly into generic load/store inst. In this way
1329+
// those insts can be optimized by llvm ASAP.
1330+
if (Name.startswith("__esimd_vload")) {
1331+
if (translateVLoad(*CI, GVTS)) {
1332+
ESIMDToErases.push_back(CI);
1333+
continue;
1334+
}
1335+
}
1336+
if (Name.startswith("__esimd_vstore")) {
1337+
if (translateVStore(*CI, GVTS)) {
1338+
ESIMDToErases.push_back(CI);
1339+
continue;
1340+
}
1341+
}
13281342

1329-
if (Name.startswith("__esimd_get_value")) {
1330-
translateGetValue(*CI);
1331-
ESIMDToErases.push_back(CI);
1332-
continue;
1333-
}
1343+
if (Name.startswith("__esimd_get_value")) {
1344+
translateGetValue(*CI);
1345+
ESIMDToErases.push_back(CI);
1346+
continue;
1347+
}
13341348

1335-
if (Name.consume_front(SPIRV_INTRIN_PREF)) {
1336-
translateSpirvIntrinsic(CI, Name, ESIMDToErases);
1337-
// For now: if no match, just let it go untranslated.
1338-
continue;
1349+
if (Name.empty() || !Name.startswith(ESIMD_INTRIN_PREF1))
1350+
continue;
1351+
// this is ESIMD intrinsic - record for later translation
1352+
ESIMDIntrCalls.push_back(CI);
13391353
}
13401354

1341-
if (Name.empty() || !Name.startswith(ESIMD_INTRIN_PREF1))
1342-
continue;
1343-
// this is ESIMD intrinsic - record for later translation
1344-
ESIMDIntrCalls.push_back(CI);
1355+
// Translate loads from SPIRV builtin globals into GenX intrinsics
1356+
auto *LI = dyn_cast<LoadInst>(&I);
1357+
if (LI) {
1358+
Value *LoadPtrOp = LI->getPointerOperand();
1359+
Value *SpirvGlobal = nullptr;
1360+
// Look through casts to find SPIRV builtin globals
1361+
auto *CE = dyn_cast<ConstantExpr>(LoadPtrOp);
1362+
if (CE) {
1363+
assert(CE->isCast() && "ConstExpr should be a cast");
1364+
SpirvGlobal = CE->getOperand(0);
1365+
} else {
1366+
SpirvGlobal = LoadPtrOp;
1367+
}
1368+
1369+
if (!isa<GlobalVariable>(SpirvGlobal) ||
1370+
!SpirvGlobal->getName().startswith(SPIRV_INTRIN_PREF))
1371+
continue;
1372+
1373+
auto PrefLen = StringRef(SPIRV_INTRIN_PREF).size();
1374+
1375+
// Go through all the uses of the load instruction from SPIRV builtin
1376+
// globals, which are required to be extractelement instructions.
1377+
// Translate each of them.
1378+
for (auto *LU : LI->users()) {
1379+
auto *EEI = dyn_cast<ExtractElementInst>(LU);
1380+
assert(EEI && "User of load from global SPIRV builtin is not an "
1381+
"extractelement instruction");
1382+
Value *TranslatedVal = translateSpirvGlobalUse(
1383+
EEI, SpirvGlobal->getName().drop_front(PrefLen));
1384+
assert(TranslatedVal &&
1385+
"Load from global SPIRV builtin was not translated");
1386+
EEI->replaceAllUsesWith(TranslatedVal);
1387+
ESIMDToErases.push_back(EEI);
1388+
}
1389+
// After all users of load were translated, we get rid of the load
1390+
// itself.
1391+
ESIMDToErases.push_back(LI);
1392+
}
13451393
}
13461394
// Now demangle and translate found ESIMD intrinsic calls
13471395
for (auto *CI : ESIMDIntrCalls) {

llvm/test/SYCLLowerIR/esimd_lower_intrins.ll

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -172,16 +172,6 @@ define dso_local spir_kernel void @FUNC_30() {
172172
; CHECK-NEXT: ret void
173173
}
174174

175-
define dso_local spir_kernel void @FUNC_31() {
176-
; CHECK: define dso_local spir_kernel void @FUNC_31()
177-
%call = call spir_func i64 @_Z27__spirv_LocalInvocationId_xv()
178-
; CHECK-NEXT: %call.esimd = call <3 x i32> @llvm.genx.local.id.v3i32()
179-
; CHECK-NEXT: %local_id.x = extractelement <3 x i32> %call.esimd, i32 0
180-
; CHECK-NEXT: %local_id.x.cast.ty = zext i32 %local_id.x to i64
181-
ret void
182-
; CHECK-NEXT: ret void
183-
}
184-
185175
define dso_local spir_func <16 x i32> @FUNC_32() {
186176
%a_1 = alloca <16 x i32>
187177
%1 = load <16 x i32>, <16 x i32>* %a_1
@@ -318,7 +308,6 @@ define dso_local spir_func <16 x i32> @FUNC_44() {
318308
ret <16 x i32> %ret_val
319309
}
320310

321-
declare dso_local spir_func i64 @_Z27__spirv_LocalInvocationId_xv()
322311
declare dso_local spir_func <32 x i32> @_Z20__esimd_flat_atomic0ILN2cm3gen14CmAtomicOpTypeE2EjLi32ELNS1_9CacheHintE0ELS3_0EENS1_13__vector_typeIT0_XT1_EE4typeENS4_IyXT1_EE4typeENS4_ItXT1_EE4typeE(<32 x i64> %0, <32 x i16> %1)
323312
declare dso_local spir_func <32 x i32> @_Z20__esimd_flat_atomic1ILN2cm3gen14CmAtomicOpTypeE0EjLi32ELNS1_9CacheHintE0ELS3_0EENS1_13__vector_typeIT0_XT1_EE4typeENS4_IyXT1_EE4typeES7_NS4_ItXT1_EE4typeE(<32 x i64> %0, <32 x i32> %1, <32 x i16> %2)
324313
declare dso_local spir_func <32 x i32> @_Z20__esimd_flat_atomic2ILN2cm3gen14CmAtomicOpTypeE7EjLi32ELNS1_9CacheHintE0ELS3_0EENS1_13__vector_typeIT0_XT1_EE4typeENS4_IyXT1_EE4typeES7_S7_NS4_ItXT1_EE4typeE(<32 x i64> %0, <32 x i32> %1, <32 x i32> %2, <32 x i16> %3)

0 commit comments

Comments
 (0)