
[AMDGPU] Filter candidates of LiveRegOptimizer for profitable cases #124624


Merged (1 commit) on Mar 5, 2025.
131 changes: 130 additions & 1 deletion llvm/lib/Target/AMDGPU/AMDGPULateCodeGenPrepare.cpp
@@ -20,6 +20,7 @@
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/Utils/Local.h"
@@ -75,6 +76,7 @@ class LiveRegOptimizer {
Module &Mod;
const DataLayout &DL;
const GCNSubtarget &ST;

/// The scalar type to convert to
Type *const ConvertToScalar;
/// The set of visited Instructions
@@ -125,6 +127,131 @@ class LiveRegOptimizer {
return LK.first != TargetLoweringBase::TypeLegal;
}

/// Check if intrinsic natively operates on 8-bit or 16-bit
bool isNativeIntrinsic(Intrinsic::ID ID) {
Contributor:

We don't need all the dot / mfma / wmma variants -- just the 8 and 4 bit ones.

Contributor:

isNativeIntrinsic is redundant; we also really don't want to maintain this table. Is any target intrinsic good enough? If we really must, we should use SearchableTables.

Contributor:

> Is any target intrinsic good enough?

It's probably "good enough" -- it may result in false positives (i.e. applying the coercion to cases where there is no profit), but maybe not many.

What we really want is a way to check if an intrinsic / instruction can natively handle illegal vector types without compiler assistance to scalarize the code -- this query may be useful for future work too.

Contributor:

> compiler assistance to scalarize the code

Not clear to me what this means. All the vector types will be split or scalarized in some way. The target intrinsics, with a handful of exceptions, only operate on legal types. isTargetIntrinsic || isTypeLegal()?
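
As a rough sketch (illustrative only, not part of this patch), the coarser check proposed here might look like the following; it assumes Intrinsic::isTargetIntrinsic is usable in this pass and reuses the existing shouldReplace() helper as the stand-in for "the type is already legal":

```cpp
// Sketch of "isTargetIntrinsic || isTypeLegal()": accept any target intrinsic
// user instead of enumerating dot/mfma/wmma variants, and otherwise accept
// users whose result type needs no coercion to begin with.
bool isOpLegalCoarse(Instruction *I) {
  if (const auto *Intr = dyn_cast<IntrinsicInst>(I))
    if (Intrinsic::isTargetIntrinsic(Intr->getIntrinsicID()))
      return true;
  // shouldReplace() is true only for the vector types this pass coerces, so a
  // false result here approximates "the type is already legal".
  return !shouldReplace(I->getType());
}
```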

Contributor:

> Not clear to me what this means

"intrinsic / instruction can natively handle illegal vector types" = The target intrinsic has either bytewise logic on the 32 bit regs (like mfma) or type doesn't matter (i.e. store).

For a standard instruction (e.g. v4i8 add), the compiler will need to insert scalarization code.

We should probably just use isTargetIntrinsic and narrow it as needed.
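
To make the scalarization point concrete (a sketch, not code from the patch): lowering a plain v4i8 add means unpacking, operating on, and repacking each byte lane, roughly:

```cpp
#include <cstdint>

// Conceptual expansion of a v4i8 add held in one 32-bit register: each byte
// lane is extracted, added with 8-bit wrap-around, and packed back. This is
// the per-element code an mfma/dot operand or a plain store never needs.
uint32_t addV4I8(uint32_t A, uint32_t B) {
  uint32_t Result = 0;
  for (unsigned Lane = 0; Lane < 4; ++Lane) {
    unsigned Shift = 8 * Lane;
    uint8_t LaneA = (A >> Shift) & 0xff; // unpack one lane from each operand
    uint8_t LaneB = (B >> Shift) & 0xff;
    Result |= uint32_t(uint8_t(LaneA + LaneB)) << Shift; // add, then repack
  }
  return Result;
}
```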

switch (ID) {
case Intrinsic::amdgcn_dot4_f32_fp8_bf8:
case Intrinsic::amdgcn_dot4_f32_bf8_fp8:
case Intrinsic::amdgcn_dot4_f32_fp8_fp8:
case Intrinsic::amdgcn_dot4_f32_bf8_bf8:
case Intrinsic::amdgcn_mfma_i32_4x4x4i8:
case Intrinsic::amdgcn_mfma_i32_16x16x4i8:
case Intrinsic::amdgcn_mfma_i32_32x32x4i8:
case Intrinsic::amdgcn_mfma_i32_16x16x16i8:
case Intrinsic::amdgcn_mfma_i32_32x32x8i8:
case Intrinsic::amdgcn_mfma_i32_16x16x64_i8:
case Intrinsic::amdgcn_mfma_i32_32x32x32_i8:
case Intrinsic::amdgcn_mfma_i32_32x32x16_i8:
case Intrinsic::amdgcn_mfma_i32_16x16x32_i8:
case Intrinsic::amdgcn_mfma_f32_16x16x32_bf8_bf8:
case Intrinsic::amdgcn_mfma_f32_16x16x32_bf8_fp8:
case Intrinsic::amdgcn_mfma_f32_16x16x32_fp8_bf8:
case Intrinsic::amdgcn_mfma_f32_16x16x32_fp8_fp8:
case Intrinsic::amdgcn_mfma_f32_32x32x16_bf8_bf8:
case Intrinsic::amdgcn_mfma_f32_32x32x16_bf8_fp8:
case Intrinsic::amdgcn_mfma_f32_32x32x16_fp8_bf8:
case Intrinsic::amdgcn_mfma_f32_32x32x16_fp8_fp8:
case Intrinsic::amdgcn_smfmac_i32_16x16x64_i8:
case Intrinsic::amdgcn_smfmac_i32_32x32x32_i8:
case Intrinsic::amdgcn_smfmac_f32_16x16x64_bf8_bf8:
case Intrinsic::amdgcn_smfmac_f32_16x16x64_bf8_fp8:
case Intrinsic::amdgcn_smfmac_f32_16x16x64_fp8_bf8:
case Intrinsic::amdgcn_smfmac_f32_16x16x64_fp8_fp8:
case Intrinsic::amdgcn_smfmac_f32_32x32x32_bf8_bf8:
case Intrinsic::amdgcn_smfmac_f32_32x32x32_bf8_fp8:
case Intrinsic::amdgcn_smfmac_f32_32x32x32_fp8_bf8:
case Intrinsic::amdgcn_smfmac_f32_32x32x32_fp8_fp8:
case Intrinsic::amdgcn_smfmac_i32_16x16x128_i8:
case Intrinsic::amdgcn_smfmac_i32_32x32x64_i8:
case Intrinsic::amdgcn_smfmac_f32_16x16x128_bf8_bf8:
case Intrinsic::amdgcn_smfmac_f32_16x16x128_bf8_fp8:
case Intrinsic::amdgcn_smfmac_f32_16x16x128_fp8_bf8:
case Intrinsic::amdgcn_smfmac_f32_16x16x128_fp8_fp8:
case Intrinsic::amdgcn_smfmac_f32_32x32x64_bf8_bf8:
case Intrinsic::amdgcn_smfmac_f32_32x32x64_bf8_fp8:
case Intrinsic::amdgcn_smfmac_f32_32x32x64_fp8_bf8:
case Intrinsic::amdgcn_smfmac_f32_32x32x64_fp8_fp8:
case Intrinsic::amdgcn_wmma_f32_16x16x16_fp8_fp8:
case Intrinsic::amdgcn_wmma_f32_16x16x16_fp8_bf8:
case Intrinsic::amdgcn_wmma_f32_16x16x16_bf8_fp8:
case Intrinsic::amdgcn_wmma_f32_16x16x16_bf8_bf8:
case Intrinsic::amdgcn_swmmac_f32_16x16x32_fp8_fp8:
case Intrinsic::amdgcn_swmmac_f32_16x16x32_fp8_bf8:
case Intrinsic::amdgcn_swmmac_f32_16x16x32_bf8_fp8:
case Intrinsic::amdgcn_swmmac_f32_16x16x32_bf8_bf8:
case Intrinsic::amdgcn_wmma_i32_16x16x16_iu8:
case Intrinsic::amdgcn_wmma_i32_16x16x16_iu4:
case Intrinsic::amdgcn_wmma_i32_16x16x32_iu4:
case Intrinsic::amdgcn_swmmac_i32_16x16x32_iu8:
case Intrinsic::amdgcn_swmmac_i32_16x16x32_iu4:
case Intrinsic::amdgcn_swmmac_i32_16x16x64_iu4:
Contributor:

Can you also add the buffer_store intrinsics?

case Intrinsic::amdgcn_raw_buffer_store_format:
case Intrinsic::amdgcn_raw_buffer_store:
case Intrinsic::amdgcn_raw_ptr_buffer_store_format:
case Intrinsic::amdgcn_raw_ptr_buffer_store:
case Intrinsic::amdgcn_struct_buffer_store_format:
case Intrinsic::amdgcn_struct_buffer_store:
case Intrinsic::amdgcn_struct_ptr_buffer_store_format:
case Intrinsic::amdgcn_struct_ptr_buffer_store:
case Intrinsic::amdgcn_raw_tbuffer_store:
case Intrinsic::amdgcn_raw_ptr_tbuffer_store:
case Intrinsic::amdgcn_struct_ptr_tbuffer_store:
case Intrinsic::amdgcn_struct_tbuffer_store:
return true;
default:
return false;
}
}

bool isOpLegal(Instruction *I) {
  if (const auto *Intr = dyn_cast<IntrinsicInst>(I)) {
    Intrinsic::ID ID = Intr->getIntrinsicID();
    if (isNativeIntrinsic(ID))
      return true;
  }
  // Stores
  if (isa<StoreInst>(I))
    return true;
  return false;
}

bool isCoercionProfitable(Instruction *II) {
Contributor:

Something like this:

  bool isCoercionProfitable(Instruction *II) {
    SmallPtrSet<Instruction *, 4> CVisited;
    SmallVector<Instruction *, 4> UserList;

    // Check users for profitable conditions (across block user which can natively
    // handle the illegal vector).
    for (User *V : II->users())
      if (auto *UseInst = dyn_cast<Instruction>(V))
        UserList.push_back(UseInst);

    auto IsLookThru = [](Instruction *II) {
      return isa<PHINode>(II) || isa<ShuffleVectorInst>(II) ||
          isa<InsertElementInst>(II) || isa<ExtractElementInst>(II) || isa<CastInst>(II);
    };

    while (!UserList.empty()) {
      auto CII = UserList.pop_back_val();
      if (!CVisited.insert(CII).second)
        continue;    

      if (CII->getParent() == II->getParent() && !IsLookThru(II))
        continue;
      
      if (isOpLegal(CII))
        return true;

      if (IsLookThru(CII))
        for (User *V : CII->users())
          if (auto *UseInst = dyn_cast<Instruction>(V))
            UserList.push_back(UseInst);
    }
    return false;
  }

  SmallPtrSet<Instruction *, 4> CVisited;
  SmallVector<Instruction *, 4> UserList;

  // Check users for profitable conditions (across block user which can
  // natively handle the illegal vector).
  for (User *V : II->users())
    if (auto *UseInst = dyn_cast<Instruction>(V))
      UserList.push_back(UseInst);

  auto IsLookThru = [](Instruction *II) {
Contributor:

Can we also look thru the v_perm intrinsic?

    if (const auto *Intr = dyn_cast<IntrinsicInst>(II))
      return Intr->getIntrinsicID() == Intrinsic::amdgcn_perm;
    return isa<PHINode>(II) || isa<ShuffleVectorInst>(II) ||
           isa<InsertElementInst>(II) || isa<ExtractElementInst>(II) ||
           isa<CastInst>(II);
  };

Contributor:

Most of this heuristic isn't covered by tests; there are no new tests added.

  while (!UserList.empty()) {
    auto CII = UserList.pop_back_val();
    if (!CVisited.insert(CII).second)
      continue;

    if (CII->getParent() == II->getParent() && !IsLookThru(II))
      continue;

    if (isOpLegal(CII))
      return true;

    if (IsLookThru(CII))
      for (User *V : CII->users())
        if (auto *UseInst = dyn_cast<Instruction>(V))
          UserList.push_back(UseInst);
  }
  return false;
}

LiveRegOptimizer(Module &Mod, const GCNSubtarget &ST)
    : Mod(Mod), DL(Mod.getDataLayout()), ST(ST),
      ConvertToScalar(Type::getInt32Ty(Mod.getContext())) {}
@@ -259,6 +386,9 @@ bool LiveRegOptimizer::optimizeLiveType(
if (!shouldReplace(II->getType()))
  continue;

if (!isCoercionProfitable(II))
  continue;

if (PHINode *Phi = dyn_cast<PHINode>(II)) {
  PhiNodes.insert(Phi);
  // Collect all the incoming values of problematic PHI nodes.
@@ -478,7 +608,6 @@ bool AMDGPULateCodeGenPrepare::visitLoadInst(LoadInst &LI) {
PreservedAnalyses
AMDGPULateCodeGenPreparePass::run(Function &F, FunctionAnalysisManager &FAM) {
const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);

AssumptionCache &AC = FAM.getResult<AssumptionAnalysis>(F);
UniformityInfo &UI = FAM.getResult<UniformityInfoAnalysis>(F);
