Skip to content

[SYCL][ESIMD] Fix the crash in sycl-post-link while processing global spirv functions. #7590

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 42 additions & 19 deletions llvm/lib/SYCLLowerIR/ESIMD/LowerESIMD.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "llvm/SYCLLowerIR/ESIMD/ESIMDUtils.h"
#include "llvm/SYCLLowerIR/SYCLUtils.h"

#include "../../IR/ConstantsContext.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallVector.h"
Expand Down Expand Up @@ -1142,9 +1143,8 @@ static uint64_t getIndexFromExtract(ExtractElementInst *EEI) {
/// right before the given extract element instruction \p EEI using the result
/// of vector load. The parameter \p IsVectorCall tells what version of GenX
/// intrinsic (scalar or vector) to use to lower the load from SPIRV global.
static Instruction *generateGenXCall(ExtractElementInst *EEI,
StringRef IntrinName, bool IsVectorCall) {
uint64_t IndexValue = getIndexFromExtract(EEI);
static Instruction *generateGenXCall(Instruction *EEI, StringRef IntrinName,
bool IsVectorCall, uint64_t IndexValue) {
std::string Suffix =
IsVectorCall
? ".v3i32"
Expand Down Expand Up @@ -1244,39 +1244,58 @@ translateSpirvGlobalUses(LoadInst *LI, StringRef SpirvGlobalName,
}

// Only loads from _vector_ SPIRV globals reach here now. Their users are
// expected to be ExtractElementInst only, and they are replaced in this loop.
// When loads from _scalar_ SPIRV globals are handled here as well, the users
// will not be replaced by new instructions, but the GenX call replacing the
// original load 'LI' should be inserted before each user.
// expected to be ExtractElementInst or TruncInst only, and they are replaced
// in this loop. When loads from _scalar_ SPIRV globals are handled here as
// well, the users will not be replaced by new instructions, but the GenX call
// replacing the original load 'LI' should be inserted before each user.
for (User *LU : LI->users()) {
ExtractElementInst *EEI = cast<ExtractElementInst>(LU);
assert(
(isa<ExtractElementInst>(LU) || isa<TruncInst>(LU)) &&
"SPIRV global users should be either ExtractElementInst or TruncInst");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for being late to the review. I don't see such restrictions in the SPIR-V spec; could you please help me find them?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These restrictions are not meant to enforce the SPIR-V spec. They guard the code, making sure it processes only the IR patterns it is expected to process.

Instruction *EEI = cast<Instruction>(LU);
NewInst = nullptr;

uint64_t IndexValue = 0;
if (isa<ExtractElementInst>(EEI)) {
IndexValue = getIndexFromExtract(cast<ExtractElementInst>(EEI));
} else {
auto *GEPCE = cast<GetElementPtrConstantExpr>(LI->getPointerOperand());

IndexValue = cast<Constant>(GEPCE->getOperand(2))
->getUniqueInteger()
.getZExtValue();
}

if (SpirvGlobalName == "WorkgroupSize") {
NewInst = generateGenXCall(EEI, "local.size", true);
NewInst = generateGenXCall(EEI, "local.size", true, IndexValue);
} else if (SpirvGlobalName == "LocalInvocationId") {
NewInst = generateGenXCall(EEI, "local.id", true);
NewInst = generateGenXCall(EEI, "local.id", true, IndexValue);
} else if (SpirvGlobalName == "WorkgroupId") {
NewInst = generateGenXCall(EEI, "group.id", false);
NewInst = generateGenXCall(EEI, "group.id", false, IndexValue);
} else if (SpirvGlobalName == "GlobalInvocationId") {
// GlobalId = LocalId + WorkGroupSize * GroupId
Instruction *LocalIdI = generateGenXCall(EEI, "local.id", true);
Instruction *WGSizeI = generateGenXCall(EEI, "local.size", true);
Instruction *GroupIdI = generateGenXCall(EEI, "group.id", false);
Instruction *LocalIdI =
generateGenXCall(EEI, "local.id", true, IndexValue);
Instruction *WGSizeI =
generateGenXCall(EEI, "local.size", true, IndexValue);
Instruction *GroupIdI =
generateGenXCall(EEI, "group.id", false, IndexValue);
Instruction *MulI =
BinaryOperator::CreateMul(WGSizeI, GroupIdI, "mul", EEI);
NewInst = BinaryOperator::CreateAdd(LocalIdI, MulI, "add", EEI);
} else if (SpirvGlobalName == "GlobalSize") {
// GlobalSize = WorkGroupSize * NumWorkGroups
Instruction *WGSizeI = generateGenXCall(EEI, "local.size", true);
Instruction *NumWGI = generateGenXCall(EEI, "group.count", true);
Instruction *WGSizeI =
generateGenXCall(EEI, "local.size", true, IndexValue);
Instruction *NumWGI =
generateGenXCall(EEI, "group.count", true, IndexValue);
NewInst = BinaryOperator::CreateMul(WGSizeI, NumWGI, "mul", EEI);
} else if (SpirvGlobalName == "GlobalOffset") {
// TODO: Support GlobalOffset SPIRV intrinsics
// Currently all users of load of GlobalOffset are replaced with 0.
NewInst = llvm::Constant::getNullValue(EEI->getType());
} else if (SpirvGlobalName == "NumWorkgroups") {
NewInst = generateGenXCall(EEI, "group.count", true);
NewInst = generateGenXCall(EEI, "group.count", true, IndexValue);
}

llvm::esimd::assert_and_diag(
Expand Down Expand Up @@ -1786,9 +1805,13 @@ size_t SYCLLowerESIMDPass::runOnFunction(Function &F,
if (LI) {
Value *LoadPtrOp = LI->getPointerOperand();
Value *SpirvGlobal = nullptr;
// Look through casts to find SPIRV builtin globals
// Look through constant expressions to find SPIRV builtin globals
// It may come with or without cast.
auto *CE = dyn_cast<ConstantExpr>(LoadPtrOp);
if (CE) {
auto *GEPCE = dyn_cast<GetElementPtrConstantExpr>(LoadPtrOp);
if (GEPCE) {
SpirvGlobal = GEPCE->getOperand(0);
} else if (CE) {
assert(CE->isCast() && "ConstExpr should be a cast");
SpirvGlobal = CE->getOperand(0);
} else {
Expand Down
35 changes: 35 additions & 0 deletions llvm/test/tools/sycl-post-link/sycl-post-link-test.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
; RUN: sycl-post-link -split-esimd -lower-esimd -O2 -S %s -o %t.table
; RUN: FileCheck %s -input-file=%t_esimd_0.ll
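; This test checks that the IR code below can be successfully processed by
; sycl-post-link. In this IR the SPIRV builtin global is loaded through a
; getelementptr constant expression: no extractelement instruction is used on
; the loaded value and no cast is applied to the global's address.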

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

@__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32

; Kernel under test, tagged with !sycl_explicit_simd metadata. It reads one
; component of the <3 x i64> @__spirv_BuiltInGlobalInvocationId global directly
; through a getelementptr constant expression, so the loaded value is a scalar
; i64 rather than a vector.
; Function Attrs: convergent norecurse
define dso_local spir_kernel void @ESIMD_kernel() #0 !sycl_explicit_simd !3 {
entry:
; Scalar load of element 0 of the builtin; the pointer operand is the GEP
; constant expression (indices i64 0, i64 0), not the global itself.
%0 = load i64, i64 addrspace(1)* getelementptr (<3 x i64>, <3 x i64> addrspace(1)* @__spirv_BuiltInGlobalInvocationId, i64 0, i64 0), align 32
; The only user of the load is this trunc (i64 -> i32); there is no
; extractelement instruction.
%conv = trunc i64 %0 to i32
ret void
}

attributes #0 = { "sycl-module-id"="a.cpp" }

!llvm.module.flags = !{!0}
!opencl.spir.version = !{!1}
!spirv.Source = !{!2}

!0 = !{i32 1, !"wchar_size", i32 4}
!1 = !{i32 1, i32 2}
!2 = !{i32 0, i32 100000}
!3 = !{}

; CHECK: define dso_local spir_kernel void @ESIMD_kernel()
; CHECK: call <3 x i32> @llvm.genx.local.id.v3i32()
; CHECK: call <3 x i32> @llvm.genx.local.size.v3i32()
; CHECK: call i32 @llvm.genx.group.id.x()
; CHECK: ret void
; CHECK: }