Skip to content

[SPIR-V] Add SPIR-V structurizer #107408

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions clang/test/CodeGenHLSL/convergence/cf.for.plain.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
// RUN: spirv-pc-vulkan-library %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s

int process() {
// CHECK: entry:
// CHECK: %[[#entry_token:]] = call token @llvm.experimental.convergence.entry()
int val = 0;

// CHECK: for.cond:
// CHECK-NEXT: %[[#]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %[[#entry_token]]) ]
// CHECK: br i1 {{.*}}, label %for.body, label %for.end
for (int i = 0; i < 10; ++i) {

// CHECK: for.body:
// CHECK: br label %for.inc
val = i;

// CHECK: for.inc:
// CHECK: br label %for.cond
}

// CHECK: for.end:
// CHECK: br label %for.cond1

// Infinite loop
for ( ; ; ) {
// CHECK: for.cond1:
// CHECK-NEXT: %[[#]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %[[#entry_token]]) ]
// CHECK: br label %for.cond1
val = 0;
}

// CHECK-NEXT: }
// This loop in unreachable. Not generated.
// Null body
for (int j = 0; j < 10; ++j)
;
return val;
}

[numthreads(1, 1, 1)]
void main() {
process();
}
2 changes: 2 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ let TargetPrefix = "spv" in {
def int_spv_bitcast : Intrinsic<[llvm_any_ty], [llvm_any_ty]>;
def int_spv_ptrcast : Intrinsic<[llvm_any_ty], [llvm_any_ty, llvm_metadata_ty, llvm_i32_ty], [ImmArg<ArgIndex<2>>]>;
def int_spv_switch : Intrinsic<[], [llvm_any_ty, llvm_vararg_ty]>;
def int_spv_loop_merge : Intrinsic<[], [llvm_vararg_ty]>;
def int_spv_selection_merge : Intrinsic<[], [llvm_vararg_ty]>;
def int_spv_cmpxchg : Intrinsic<[llvm_i32_ty], [llvm_any_ty, llvm_vararg_ty]>;
def int_spv_unreachable : Intrinsic<[], []>;
def int_spv_alloca : Intrinsic<[llvm_any_ty], []>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ class ConvergenceRegionAnalyzer {

private:
bool isBackEdge(const BasicBlock *From, const BasicBlock *To) const {
assert(From != To && "From == To. This is awkward.");
if (From == To)
return true;

// We only handle loop in the simplified form. This means:
// - a single back-edge, a single latch.
Expand All @@ -230,6 +231,7 @@ class ConvergenceRegionAnalyzer {
auto *Terminator = From->getTerminator();
for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) {
auto *To = Terminator->getSuccessor(i);
// Ignore back edges.
if (isBackEdge(From, To))
continue;

Expand Down Expand Up @@ -276,7 +278,6 @@ class ConvergenceRegionAnalyzer {
while (ToProcess.size() != 0) {
auto *L = ToProcess.front();
ToProcess.pop();
assert(L->isLoopSimplifyForm());

auto CT = getConvergenceToken(L->getHeader());
SmallPtrSet<BasicBlock *, 8> RegionBlocks(L->block_begin(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ class ConvergenceRegionInfo {
}

const ConvergenceRegion *getTopLevelRegion() const { return TopLevelRegion; }
ConvergenceRegion *getWritableTopLevelRegion() const {
return TopLevelRegion;
}
};

} // namespace SPIRV
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/SPIRV/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ add_llvm_target(SPIRVCodeGen
SPIRVMCInstLower.cpp
SPIRVMetadata.cpp
SPIRVModuleAnalysis.cpp
SPIRVStructurizer.cpp
SPIRVPreLegalizer.cpp
SPIRVPostLegalizer.cpp
SPIRVPrepareFunctions.cpp
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/SPIRV/SPIRV.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class InstructionSelector;
class RegisterBankInfo;

ModulePass *createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM);
FunctionPass *createSPIRVStructurizerPass();
FunctionPass *createSPIRVMergeRegionExitTargetsPass();
FunctionPass *createSPIRVStripConvergenceIntrinsicsPass();
FunctionPass *createSPIRVRegularizerPass();
Expand Down
6 changes: 2 additions & 4 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -617,15 +617,13 @@ def OpFwidthCoarse: UnOp<"OpFwidthCoarse", 215>;

def OpPhi: Op<245, (outs ID:$res), (ins TYPE:$type, ID:$var0, ID:$block0, variable_ops),
"$res = OpPhi $type $var0 $block0">;
def OpLoopMerge: Op<246, (outs), (ins ID:$merge, ID:$continue, LoopControl:$lc, variable_ops),
def OpLoopMerge: Op<246, (outs), (ins unknown:$merge, unknown:$continue, LoopControl:$lc, variable_ops),
"OpLoopMerge $merge $continue $lc">;
def OpSelectionMerge: Op<247, (outs), (ins ID:$merge, SelectionControl:$sc),
def OpSelectionMerge: Op<247, (outs), (ins unknown:$merge, SelectionControl:$sc),
"OpSelectionMerge $merge $sc">;
def OpLabel: Op<248, (outs ID:$label), (ins), "$label = OpLabel">;
let isBarrier = 1, isTerminator=1 in {
def OpBranch: Op<249, (outs), (ins unknown:$label), "OpBranch $label">;
}
let isTerminator=1 in {
def OpBranchConditional: Op<250, (outs), (ins ID:$cond, unknown:$true, unknown:$false, variable_ops),
"OpBranchConditional $cond $true $false">;
def OpSwitch: Op<251, (outs), (ins ID:$sel, ID:$dflt, variable_ops), "OpSwitch $sel $dflt">;
Expand Down
13 changes: 13 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2425,6 +2425,19 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
}
return MIB.constrainAllUses(TII, TRI, RBI);
}
case Intrinsic::spv_loop_merge:
case Intrinsic::spv_selection_merge: {
const auto Opcode = IID == Intrinsic::spv_selection_merge
? SPIRV::OpSelectionMerge
: SPIRV::OpLoopMerge;
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode));
for (unsigned i = 1; i < I.getNumExplicitOperands(); ++i) {
assert(I.getOperand(i).isMBB());
MIB.addMBB(I.getOperand(i).getMBB());
}
MIB.addImm(SPIRV::SelectionControl::None);
return MIB.constrainAllUses(TII, TRI, RBI);
}
case Intrinsic::spv_cmpxchg:
return selectAtomicCmpXchg(ResVReg, ResType, I);
case Intrinsic::spv_unreachable:
Expand Down
20 changes: 12 additions & 8 deletions llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
// Run the pass on the given convergence region, ignoring the sub-regions.
// Returns true if the CFG changed, false otherwise.
bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
const SPIRV::ConvergenceRegion *CR) {
SPIRV::ConvergenceRegion *CR) {
// Gather all the exit targets for this region.
SmallPtrSet<BasicBlock *, 4> ExitTargets;
for (BasicBlock *Exit : CR->Exits) {
Expand Down Expand Up @@ -198,14 +198,19 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
for (auto Exit : CR->Exits)
replaceBranchTargets(Exit, ExitTargets, NewExitTarget);

CR = CR->Parent;
while (CR) {
CR->Blocks.insert(NewExitTarget);
CR = CR->Parent;
}

return true;
}

/// Run the pass on the given convergence region and sub-regions (DFS).
/// Returns true if a region/sub-region was modified, false otherwise.
/// This returns as soon as one region/sub-region has been modified.
bool runOnConvergenceRegion(LoopInfo &LI,
const SPIRV::ConvergenceRegion *CR) {
bool runOnConvergenceRegion(LoopInfo &LI, SPIRV::ConvergenceRegion *CR) {
for (auto *Child : CR->Children)
if (runOnConvergenceRegion(LI, Child))
return true;
Expand Down Expand Up @@ -235,20 +240,17 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {

virtual bool runOnFunction(Function &F) override {
LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
const auto *TopLevelRegion =
auto *TopLevelRegion =
getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
.getRegionInfo()
.getTopLevelRegion();
.getWritableTopLevelRegion();

// FIXME: very inefficient method: each time a region is modified, we bubble
// back up, and recompute the whole convergence region tree. Once the
// algorithm is completed and test coverage good enough, rewrite this pass
// to be efficient instead of simple.
bool modified = false;
while (runOnConvergenceRegion(LI, TopLevelRegion)) {
TopLevelRegion = getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
.getRegionInfo()
.getTopLevelRegion();
modified = true;
}

Expand All @@ -262,6 +264,8 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
AU.addRequired<DominatorTreeWrapperPass>();
AU.addRequired<LoopInfoWrapperPass>();
AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();

AU.addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>();
FunctionPass::getAnalysisUsage(AU);
}
};
Expand Down
18 changes: 0 additions & 18 deletions llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,23 +175,6 @@ void visit(MachineFunction &MF, std::function<void(MachineBasicBlock *)> op) {
visit(MF, *MF.begin(), op);
}

// Sorts basic blocks by dominance to respect the SPIR-V spec.
void sortBlocks(MachineFunction &MF) {
MachineDominatorTree MDT(MF);

std::unordered_map<MachineBasicBlock *, size_t> Order;
Order.reserve(MF.size());

size_t Index = 0;
visit(MF, [&Order, &Index](MachineBasicBlock *MBB) { Order[MBB] = Index++; });

auto Comparator = [&Order](MachineBasicBlock &LHS, MachineBasicBlock &RHS) {
return Order[&LHS] < Order[&RHS];
};

MF.sort(Comparator);
}

bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) {
// Initialize the type registry.
const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
Expand All @@ -200,7 +183,6 @@ bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) {
MachineIRBuilder MIB(MF);

processNewInstrs(MF, GR, MIB);
sortBlocks(MF);

return true;
}
Expand Down
Loading