Skip to content

Commit 63c749c

Browse files
authored
[SYCL][ESIMD] Fix the crash in sycl-post-link while processing global spirv functions. (#7590)
Currently for ` int i = __spirv_GlobalInvocationId_x();` c++ code followinf IR code is generated: ``` %0 = load <3 x i64>, <3 x i64> addrspace(1)* @__spirv_BuiltInGlobalInvocationId, align 32 %1 = extractelement <3 x i64> %0, i64 0 %conv = trunc i64 %1 to i32 ``` This IR is what sycl-post-link is expecting and is able to process successfully. However, following IR code was generated during the testing: ``` %0 = load i64, i64 addrspace(1)* getelementptr (<3 x i64>, <3 x i64> addrspace(1)* @__spirv_BuiltInGlobalInvocationId, i64 0, i64 0), align 32 %conv = trunc i64 %0 to i32 ``` and it caused sycl-post-link to crash. The fix resolves the issue
1 parent 9297f63 commit 63c749c

File tree

2 files changed

+77
-19
lines changed

2 files changed

+77
-19
lines changed

llvm/lib/SYCLLowerIR/ESIMD/LowerESIMD.cpp

+42-19
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "llvm/SYCLLowerIR/ESIMD/ESIMDUtils.h"
1818
#include "llvm/SYCLLowerIR/SYCLUtils.h"
1919

20+
#include "../../IR/ConstantsContext.h"
2021
#include "llvm/ADT/DenseMap.h"
2122
#include "llvm/ADT/DenseSet.h"
2223
#include "llvm/ADT/SmallVector.h"
@@ -1142,9 +1143,8 @@ static uint64_t getIndexFromExtract(ExtractElementInst *EEI) {
11421143
/// right before the given extract element instruction \p EEI using the result
11431144
/// of vector load. The parameter \p IsVectorCall tells what version of GenX
11441145
/// intrinsic (scalar or vector) to use to lower the load from SPIRV global.
1145-
static Instruction *generateGenXCall(ExtractElementInst *EEI,
1146-
StringRef IntrinName, bool IsVectorCall) {
1147-
uint64_t IndexValue = getIndexFromExtract(EEI);
1146+
static Instruction *generateGenXCall(Instruction *EEI, StringRef IntrinName,
1147+
bool IsVectorCall, uint64_t IndexValue) {
11481148
std::string Suffix =
11491149
IsVectorCall
11501150
? ".v3i32"
@@ -1244,39 +1244,58 @@ translateSpirvGlobalUses(LoadInst *LI, StringRef SpirvGlobalName,
12441244
}
12451245

12461246
// Only loads from _vector_ SPIRV globals reach here now. Their users are
1247-
// expected to be ExtractElementInst only, and they are replaced in this loop.
1248-
// When loads from _scalar_ SPIRV globals are handled here as well, the users
1249-
// will not be replaced by new instructions, but the GenX call replacing the
1250-
// original load 'LI' should be inserted before each user.
1247+
// expected to be ExtractElementInst or TruncInst only, and they are replaced
1248+
// in this loop. When loads from _scalar_ SPIRV globals are handled here as
1249+
// well, the users will not be replaced by new instructions, but the GenX call
1250+
// replacing the original load 'LI' should be inserted before each user.
12511251
for (User *LU : LI->users()) {
1252-
ExtractElementInst *EEI = cast<ExtractElementInst>(LU);
1252+
assert(
1253+
(isa<ExtractElementInst>(LU) || isa<TruncInst>(LU)) &&
1254+
"SPIRV global users should be either ExtractElementInst or TruncInst");
1255+
Instruction *EEI = cast<Instruction>(LU);
12531256
NewInst = nullptr;
12541257

1258+
uint64_t IndexValue = 0;
1259+
if (isa<ExtractElementInst>(EEI)) {
1260+
IndexValue = getIndexFromExtract(cast<ExtractElementInst>(EEI));
1261+
} else {
1262+
auto *GEPCE = cast<GetElementPtrConstantExpr>(LI->getPointerOperand());
1263+
1264+
IndexValue = cast<Constant>(GEPCE->getOperand(2))
1265+
->getUniqueInteger()
1266+
.getZExtValue();
1267+
}
1268+
12551269
if (SpirvGlobalName == "WorkgroupSize") {
1256-
NewInst = generateGenXCall(EEI, "local.size", true);
1270+
NewInst = generateGenXCall(EEI, "local.size", true, IndexValue);
12571271
} else if (SpirvGlobalName == "LocalInvocationId") {
1258-
NewInst = generateGenXCall(EEI, "local.id", true);
1272+
NewInst = generateGenXCall(EEI, "local.id", true, IndexValue);
12591273
} else if (SpirvGlobalName == "WorkgroupId") {
1260-
NewInst = generateGenXCall(EEI, "group.id", false);
1274+
NewInst = generateGenXCall(EEI, "group.id", false, IndexValue);
12611275
} else if (SpirvGlobalName == "GlobalInvocationId") {
12621276
// GlobalId = LocalId + WorkGroupSize * GroupId
1263-
Instruction *LocalIdI = generateGenXCall(EEI, "local.id", true);
1264-
Instruction *WGSizeI = generateGenXCall(EEI, "local.size", true);
1265-
Instruction *GroupIdI = generateGenXCall(EEI, "group.id", false);
1277+
Instruction *LocalIdI =
1278+
generateGenXCall(EEI, "local.id", true, IndexValue);
1279+
Instruction *WGSizeI =
1280+
generateGenXCall(EEI, "local.size", true, IndexValue);
1281+
Instruction *GroupIdI =
1282+
generateGenXCall(EEI, "group.id", false, IndexValue);
12661283
Instruction *MulI =
12671284
BinaryOperator::CreateMul(WGSizeI, GroupIdI, "mul", EEI);
12681285
NewInst = BinaryOperator::CreateAdd(LocalIdI, MulI, "add", EEI);
12691286
} else if (SpirvGlobalName == "GlobalSize") {
12701287
// GlobalSize = WorkGroupSize * NumWorkGroups
1271-
Instruction *WGSizeI = generateGenXCall(EEI, "local.size", true);
1272-
Instruction *NumWGI = generateGenXCall(EEI, "group.count", true);
1288+
Instruction *WGSizeI =
1289+
generateGenXCall(EEI, "local.size", true, IndexValue);
1290+
Instruction *NumWGI =
1291+
generateGenXCall(EEI, "group.count", true, IndexValue);
12731292
NewInst = BinaryOperator::CreateMul(WGSizeI, NumWGI, "mul", EEI);
12741293
} else if (SpirvGlobalName == "GlobalOffset") {
12751294
// TODO: Support GlobalOffset SPIRV intrinsics
12761295
// Currently all users of load of GlobalOffset are replaced with 0.
12771296
NewInst = llvm::Constant::getNullValue(EEI->getType());
12781297
} else if (SpirvGlobalName == "NumWorkgroups") {
1279-
NewInst = generateGenXCall(EEI, "group.count", true);
1298+
NewInst = generateGenXCall(EEI, "group.count", true, IndexValue);
12801299
}
12811300

12821301
llvm::esimd::assert_and_diag(
@@ -1786,9 +1805,13 @@ size_t SYCLLowerESIMDPass::runOnFunction(Function &F,
17861805
if (LI) {
17871806
Value *LoadPtrOp = LI->getPointerOperand();
17881807
Value *SpirvGlobal = nullptr;
1789-
// Look through casts to find SPIRV builtin globals
1808+
// Look through constant expressions to find SPIRV builtin globals
1809+
// It may come with or without cast.
17901810
auto *CE = dyn_cast<ConstantExpr>(LoadPtrOp);
1791-
if (CE) {
1811+
auto *GEPCE = dyn_cast<GetElementPtrConstantExpr>(LoadPtrOp);
1812+
if (GEPCE) {
1813+
SpirvGlobal = GEPCE->getOperand(0);
1814+
} else if (CE) {
17921815
assert(CE->isCast() && "ConstExpr should be a cast");
17931816
SpirvGlobal = CE->getOperand(0);
17941817
} else {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
; RUN: sycl-post-link -split-esimd -lower-esimd -O2 -S %s -o %t.table
2+
; RUN: FileCheck %s -input-file=%t_esimd_0.ll
3+
; This test checks that IR code below can be successfully processed by
4+
; sycl-post-link. In this IR no extractelement instruction and no casting are used
5+
6+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
7+
target triple = "spir64-unknown-unknown"
8+
9+
@__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
10+
11+
; Function Attrs: convergent norecurse
12+
define dso_local spir_kernel void @ESIMD_kernel() #0 !sycl_explicit_simd !3 {
13+
entry:
14+
%0 = load i64, i64 addrspace(1)* getelementptr (<3 x i64>, <3 x i64> addrspace(1)* @__spirv_BuiltInGlobalInvocationId, i64 0, i64 0), align 32
15+
%conv = trunc i64 %0 to i32
16+
ret void
17+
}
18+
19+
attributes #0 = { "sycl-module-id"="a.cpp" }
20+
21+
!llvm.module.flags = !{!0}
22+
!opencl.spir.version = !{!1}
23+
!spirv.Source = !{!2}
24+
25+
!0 = !{i32 1, !"wchar_size", i32 4}
26+
!1 = !{i32 1, i32 2}
27+
!2 = !{i32 0, i32 100000}
28+
!3 = !{}
29+
30+
; CHECK: define dso_local spir_kernel void @ESIMD_kernel()
31+
; CHECK: call <3 x i32> @llvm.genx.local.id.v3i32()
32+
; CHECK: call <3 x i32> @llvm.genx.local.size.v3i32()
33+
; CHECK: call i32 @llvm.genx.group.id.x()
34+
; CHECK: ret void
35+
; CHECK: }

0 commit comments

Comments
 (0)