Skip to content

[SPIR-V] Improve correctness of emitted MIR between passes for branching instructions #106966

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,6 @@ SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
ConstantInt *CI = ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
Register Res = DT.find(CI, CurMF);
if (!Res.isValid()) {
// TODO: handle cases where the type is not 32bit wide
// TODO: https://github.com/llvm/llvm-project/issues/88129
Res =
CurMF->getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
Expand Down Expand Up @@ -197,8 +195,6 @@ SPIRVGlobalRegistry::getOrCreateConstFloatReg(APFloat Val, SPIRVType *SpvType,
auto *const CI = ConstantFP::get(Ctx, Val);
Register Res = DT.find(CI, CurMF);
if (!Res.isValid()) {
// TODO: handle cases where the type is not 32bit wide
// TODO: https://github.com/llvm/llvm-project/issues/88129
Res =
CurMF->getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
CurMF->getRegInfo().setRegClass(Res, &SPIRV::fIDRegClass);
Expand Down Expand Up @@ -391,8 +387,6 @@ Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
SpvScalConst =
getOrCreateBaseRegister(Val, I, SpvType, TII, BitWidth, ZeroAsNull);

// TODO: handle cases where the type is not 32bit wide
// TODO: https://github.com/llvm/llvm-project/issues/88129
LLT LLTy = LLT::scalar(64);
Register SpvVecConst =
CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
Expand Down
16 changes: 16 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
GR.setCurrentFunc(MF);
for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) {
MachineBasicBlock *MBB = &*I;
SmallPtrSet<MachineInstr *, 8> ToMove;
for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end();
MBBI != MBBE;) {
MachineInstr &MI = *MBBI++;
Expand Down Expand Up @@ -456,8 +457,23 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
MI.removeOperand(i);
}
} break;
case SPIRV::OpPhi: {
// Phi refers to a type definition that goes after the Phi
// instruction, so that the virtual register definition of the type
// doesn't dominate all uses. Let's place the type definition
// instruction at the end of the predecessor.
MachineBasicBlock *Curr = MI.getParent();
SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
if (Type->getParent() == Curr && !Curr->pred_empty())
ToMove.insert(const_cast<MachineInstr *>(Type));
} break;
}
}
for (MachineInstr *MI : ToMove) {
MachineBasicBlock *Curr = MI->getParent();
MachineBasicBlock *Pred = *Curr->pred_begin();
Pred->insert(Pred->getFirstTerminator(), Curr->remove_instr(MI));
}
}
ProcessedMF.insert(&MF);
TargetLowering::finalizeLowering(MF);
Expand Down
6 changes: 4 additions & 2 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -622,9 +622,11 @@ def OpLoopMerge: Op<246, (outs), (ins ID:$merge, ID:$continue, LoopControl:$lc,
def OpSelectionMerge: Op<247, (outs), (ins ID:$merge, SelectionControl:$sc),
"OpSelectionMerge $merge $sc">;
def OpLabel: Op<248, (outs ID:$label), (ins), "$label = OpLabel">;
let isBarrier = 1, isTerminator=1 in {
def OpBranch: Op<249, (outs), (ins unknown:$label), "OpBranch $label">;
}
let isTerminator=1 in {
def OpBranch: Op<249, (outs), (ins ID:$label), "OpBranch $label">;
def OpBranchConditional: Op<250, (outs), (ins ID:$cond, ID:$true, ID:$false, variable_ops),
def OpBranchConditional: Op<250, (outs), (ins ID:$cond, unknown:$true, unknown:$false, variable_ops),
"OpBranchConditional $cond $true $false">;
def OpSwitch: Op<251, (outs), (ins ID:$sel, ID:$dflt, variable_ops), "OpSwitch $sel $dflt">;
}
Expand Down
15 changes: 13 additions & 2 deletions llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -778,8 +778,10 @@ static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
}

SmallPtrSet<MachineInstr *, 8> ToEraseMI;
SmallPtrSet<MachineBasicBlock *, 8> ClearAddressTaken;
for (auto &SwIt : Switches) {
MachineInstr &MI = *SwIt.first;
MachineBasicBlock *MBB = MI.getParent();
SmallVector<MachineInstr *, 8> &Ins = SwIt.second;
SmallVector<MachineOperand, 8> NewOps;
for (unsigned i = 0; i < Ins.size(); ++i) {
Expand All @@ -790,8 +792,11 @@ static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
if (It == BB2MBB.end())
report_fatal_error("cannot find a machine basic block by a basic "
"block in a switch statement");
NewOps.push_back(MachineOperand::CreateMBB(It->second));
MI.getParent()->addSuccessor(It->second);
MachineBasicBlock *Succ = It->second;
ClearAddressTaken.insert(Succ);
NewOps.push_back(MachineOperand::CreateMBB(Succ));
if (!llvm::is_contained(MBB->successors(), Succ))
MBB->addSuccessor(Succ);
ToEraseMI.insert(Ins[i]);
} else {
NewOps.push_back(
Expand Down Expand Up @@ -830,6 +835,12 @@ static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
}
BlockAddrI->eraseFromParent();
}

// BlockAddress operands were used to keep information between passes,
// let's undo the "address taken" status to reflect that Succ doesn't
// actually correspond to an IR-level basic block.
for (MachineBasicBlock *Succ : ClearAddressTaken)
Succ->setAddressTakenIRBlock(nullptr);
}

static bool isImplicitFallthrough(MachineBasicBlock &MBB) {
Expand Down
6 changes: 5 additions & 1 deletion llvm/test/CodeGen/SPIRV/branching/OpSwitchBranches.ll
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}

define i32 @test_switch_branches(i32 %a) {
entry:
Expand Down
6 changes: 5 additions & 1 deletion llvm/test/CodeGen/SPIRV/branching/OpSwitchEmpty.ll
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
;; Command:
;; clang -cc1 -triple spir -emit-llvm -o test/SPIRV/OpSwitchEmpty.ll OpSwitchEmpty.cl -disable-llvm-passes

; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK-SPIRV: %[[#X:]] = OpFunctionParameter %[[#]]
; CHECK-SPIRV: OpSwitch %[[#X]] %[[#DEFAULT:]]{{$}}
Expand Down
5 changes: 4 additions & 1 deletion llvm/test/CodeGen/SPIRV/branching/OpSwitchUnreachable.ll
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}

define void @test_switch_with_unreachable_block(i1 %a) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}

define spir_kernel void @test_two_switch_same_register(i32 %value) {
; CHECK-SPIRV: OpSwitch %[[#REGISTER:]] %[[#DEFAULT1:]] 1 %[[#CASE1:]] 0 %[[#CASE2:]]
switch i32 %value, label %default1 [
Expand Down
2 changes: 1 addition & 1 deletion llvm/test/CodeGen/SPIRV/transcoding/GlobalFunAnnotate.ll
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-linux %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-linux %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK-SPIRV: OpDecorate %[[#]] UserSemantic "annotation_on_function"
Expand Down
Loading