Skip to content

[AArch64][SME] Split SMECallAttrs out of SMEAttrs #137239

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 38 additions & 38 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8653,6 +8653,16 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI,
}
}

/// Assemble the SME attribute bundle describing a call being lowered.
///
/// When \p CLI carries a call instruction, the bundle is built from that
/// instruction alone. Otherwise the caller side is taken from \p Function and
/// the callee side from the external symbol name if the callee is an
/// ExternalSymbolSDNode, falling back to SMEAttrs::Normal.
static SMECallAttrs
getSMECallAttrs(const Function &Function,
                const TargetLowering::CallLoweringInfo &CLI) {
  // A call instruction is available: let SMECallAttrs derive everything.
  if (CLI.CB)
    return SMECallAttrs(*CLI.CB);

  // No call instruction: pick the callee attributes from the symbol name when
  // there is one, otherwise use SMEAttrs::Normal.
  auto *ExtSym = dyn_cast<ExternalSymbolSDNode>(CLI.Callee);
  SMEAttrs CalleeAttrs =
      ExtSym ? SMEAttrs(ExtSym->getSymbol()) : SMEAttrs(SMEAttrs::Normal);
  return SMECallAttrs(SMEAttrs(Function), CalleeAttrs);
}

bool AArch64TargetLowering::isEligibleForTailCallOptimization(
const CallLoweringInfo &CLI) const {
CallingConv::ID CalleeCC = CLI.CallConv;
Expand All @@ -8671,12 +8681,10 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(

// SME Streaming functions are not eligible for TCO as they may require
// the streaming mode or ZA to be restored after returning from the call.
SMEAttrs CallerAttrs(MF.getFunction());
auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
CallerAttrs.requiresLazySave(CalleeAttrs) ||
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
CallerAttrs.hasStreamingBody())
SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, CLI);
if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
CallAttrs.requiresPreservingAllZAState() ||
CallAttrs.caller().hasStreamingBody())
return false;

// Functions using the C or Fast calling convention that have an SVE signature
Expand Down Expand Up @@ -8968,14 +8976,13 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
return TLI.LowerCallTo(CLI).second;
}

static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
const SMEAttrs &CalleeAttrs) {
if (!CallerAttrs.hasStreamingCompatibleInterface() ||
CallerAttrs.hasStreamingBody())
static unsigned getSMCondition(const SMECallAttrs &CallAttrs) {
if (!CallAttrs.caller().hasStreamingCompatibleInterface() ||
CallAttrs.caller().hasStreamingBody())
return AArch64SME::Always;
if (CalleeAttrs.hasNonStreamingInterface())
if (CallAttrs.callee().hasNonStreamingInterface())
return AArch64SME::IfCallerIsStreaming;
if (CalleeAttrs.hasStreamingInterface())
if (CallAttrs.callee().hasStreamingInterface())
return AArch64SME::IfCallerIsNonStreaming;

llvm_unreachable("Unsupported attributes");
Expand Down Expand Up @@ -9108,11 +9115,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
}

// Determine whether we need any streaming mode changes.
SMEAttrs CalleeAttrs, CallerAttrs(MF.getFunction());
if (CLI.CB)
CalleeAttrs = SMEAttrs(*CLI.CB);
else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
CalleeAttrs = SMEAttrs(ES->getSymbol());
SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);

auto DescribeCallsite =
[&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
Expand All @@ -9127,9 +9130,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
return R;
};

bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
bool RequiresSaveAllZA =
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
bool RequiresLazySave = CallAttrs.requiresLazySave();
bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
if (RequiresLazySave) {
const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
MachinePointerInfo MPI =
Expand Down Expand Up @@ -9157,18 +9159,18 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
return DescribeCallsite(R) << " sets up a lazy save for ZA";
});
} else if (RequiresSaveAllZA) {
assert(!CalleeAttrs.hasSharedZAInterface() &&
assert(!CallAttrs.callee().hasSharedZAInterface() &&
"Cannot share state that may not exist");
Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
/*IsSave=*/true);
}

SDValue PStateSM;
bool RequiresSMChange = CallerAttrs.requiresSMChange(CalleeAttrs);
bool RequiresSMChange = CallAttrs.requiresSMChange();
if (RequiresSMChange) {
if (CallerAttrs.hasStreamingInterfaceOrBody())
if (CallAttrs.caller().hasStreamingInterfaceOrBody())
PStateSM = DAG.getConstant(1, DL, MVT::i64);
else if (CallerAttrs.hasNonStreamingInterface())
else if (CallAttrs.caller().hasNonStreamingInterface())
PStateSM = DAG.getConstant(0, DL, MVT::i64);
else
PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
Expand All @@ -9185,7 +9187,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,

SDValue ZTFrameIdx;
MachineFrameInfo &MFI = MF.getFrameInfo();
bool ShouldPreserveZT0 = CallerAttrs.requiresPreservingZT0(CalleeAttrs);
bool ShouldPreserveZT0 = CallAttrs.requiresPreservingZT0();

// If the caller has ZT0 state which will not be preserved by the callee,
// spill ZT0 before the call.
Expand All @@ -9201,7 +9203,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,

// If caller shares ZT0 but the callee is not shared ZA, we need to stop
// PSTATE.ZA before the call if there is no lazy-save active.
bool DisableZA = CallerAttrs.requiresDisablingZABeforeCall(CalleeAttrs);
bool DisableZA = CallAttrs.requiresDisablingZABeforeCall();
assert((!DisableZA || !RequiresLazySave) &&
"Lazy-save should have PSTATE.SM=1 on entry to the function");

Expand Down Expand Up @@ -9483,9 +9485,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
InGlue = Chain.getValue(1);
}

SDValue NewChain = changeStreamingMode(
DAG, DL, CalleeAttrs.hasStreamingInterface(), Chain, InGlue,
getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
SDValue NewChain =
changeStreamingMode(DAG, DL, CallAttrs.callee().hasStreamingInterface(),
Chain, InGlue, getSMCondition(CallAttrs), PStateSM);
Chain = NewChain.getValue(0);
InGlue = NewChain.getValue(1);
}
Expand Down Expand Up @@ -9664,8 +9666,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
if (RequiresSMChange) {
assert(PStateSM && "Expected a PStateSM to be set");
Result = changeStreamingMode(
DAG, DL, !CalleeAttrs.hasStreamingInterface(), Result, InGlue,
getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
DAG, DL, !CallAttrs.callee().hasStreamingInterface(), Result, InGlue,
getSMCondition(CallAttrs), PStateSM);

if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
InGlue = Result.getValue(1);
Expand All @@ -9675,7 +9677,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
}
}

if (CallerAttrs.requiresEnablingZAAfterCall(CalleeAttrs))
if (CallAttrs.requiresEnablingZAAfterCall())
// Unconditionally resume ZA.
Result = DAG.getNode(
AArch64ISD::SMSTART, DL, MVT::Other, Result,
Expand Down Expand Up @@ -28552,12 +28554,10 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {

// Checks to allow the use of SME instructions
if (auto *Base = dyn_cast<CallBase>(&Inst)) {
auto CallerAttrs = SMEAttrs(*Inst.getFunction());
auto CalleeAttrs = SMEAttrs(*Base);
if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
CallerAttrs.requiresLazySave(CalleeAttrs) ||
CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
auto CallAttrs = SMECallAttrs(*Base);
if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
CallAttrs.requiresPreservingZT0() ||
CallAttrs.requiresPreservingAllZAState())
return true;
}
return false;
Expand Down
25 changes: 13 additions & 12 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,22 +268,21 @@ const FeatureBitset AArch64TTIImpl::InlineInverseFeatures = {

bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
const Function *Callee) const {
SMEAttrs CallerAttrs(*Caller), CalleeAttrs(*Callee);
SMECallAttrs CallAttrs(*Caller, *Callee);

// When inlining, we should consider the body of the function, not the
// interface.
if (CalleeAttrs.hasStreamingBody()) {
CalleeAttrs.set(SMEAttrs::SM_Compatible, false);
CalleeAttrs.set(SMEAttrs::SM_Enabled, true);
if (CallAttrs.callee().hasStreamingBody()) {
CallAttrs.callee().set(SMEAttrs::SM_Compatible, false);
CallAttrs.callee().set(SMEAttrs::SM_Enabled, true);
}

if (CalleeAttrs.isNewZA() || CalleeAttrs.isNewZT0())
if (CallAttrs.callee().isNewZA() || CallAttrs.callee().isNewZT0())
return false;

if (CallerAttrs.requiresLazySave(CalleeAttrs) ||
CallerAttrs.requiresSMChange(CalleeAttrs) ||
CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs)) {
if (CallAttrs.requiresLazySave() || CallAttrs.requiresSMChange() ||
CallAttrs.requiresPreservingZT0() ||
CallAttrs.requiresPreservingAllZAState()) {
if (hasPossibleIncompatibleOps(Callee))
return false;
}
Expand Down Expand Up @@ -349,12 +348,14 @@ AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
// streaming-mode change, and the call to G from F would also require a
// streaming-mode change, then there is benefit to do the streaming-mode
// change only once and avoid inlining of G into F.

SMEAttrs FAttrs(*F);
SMEAttrs CalleeAttrs(Call);
if (FAttrs.requiresSMChange(CalleeAttrs)) {
SMECallAttrs CallAttrs(Call);

if (SMECallAttrs(FAttrs, CallAttrs.callee()).requiresSMChange()) {
if (F == Call.getCaller()) // (1)
return CallPenaltyChangeSM * DefaultCallPenalty;
if (FAttrs.requiresSMChange(SMEAttrs(*Call.getCaller()))) // (2)
if (SMECallAttrs(FAttrs, CallAttrs.caller()).requiresSMChange()) // (2)
return InlineCallPenaltyChangeSM * DefaultCallPenalty;
}

Expand Down
64 changes: 35 additions & 29 deletions llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,14 @@ void SMEAttrs::set(unsigned M, bool Enable) {
"ZA_New and SME_ABI_Routine are mutually exclusive");

assert(
(!sharesZA() ||
(isNewZA() ^ isInZA() ^ isInOutZA() ^ isOutZA() ^ isPreservesZA())) &&
(isNewZA() + isInZA() + isOutZA() + isInOutZA() + isPreservesZA()) <= 1 &&
"Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', "
"'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive");

// ZT0 Attrs
assert(
(!sharesZT0() || (isNewZT0() ^ isInZT0() ^ isInOutZT0() ^ isOutZT0() ^
isPreservesZT0())) &&
(isNewZT0() + isInZT0() + isOutZT0() + isInOutZT0() + isPreservesZT0()) <=
1 &&
"Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
"'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive");

Expand All @@ -44,27 +43,6 @@ void SMEAttrs::set(unsigned M, bool Enable) {
"interface");
}

// Builds the SME attributes for the call CB.
// Starts from the attributes attached to the call itself; when the called
// function is statically known, the bits derived from that function and from
// its name are additionally passed to set().
SMEAttrs::SMEAttrs(const CallBase &CB) {
*this = SMEAttrs(CB.getAttributes());
// getCalledFunction() may yield null (checked here); in that case only the
// call's own attributes are used.
if (auto *F = CB.getCalledFunction()) {
set(SMEAttrs(*F).Bitmask | SMEAttrs(F->getName()).Bitmask);
}
}

// Builds SME attributes purely from a function name. The mask starts at 0 and
// bits are OR-ed in only for the specific names matched below; any other name
// leaves the mask at 0.
SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
// These two names get SM_Compatible and SME_ABI_Routine.
if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state")
Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine);
// Same two bits plus the encoded ZA state StateValue::In.
if (FuncName == "__arm_tpidr2_restore")
Bitmask |= SMEAttrs::SM_Compatible | encodeZAState(StateValue::In) |
SMEAttrs::SME_ABI_Routine;
// The __arm_sc_* names get SM_Compatible only.
if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr")
Bitmask |= SMEAttrs::SM_Compatible;
// These names get SM_Compatible and SME_ABI_Routine.
if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" ||
FuncName == "__arm_sme_state_size")
Bitmask |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine;
}

SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
Bitmask = 0;
if (Attrs.hasFnAttr("aarch64_pstate_sm_enabled"))
Expand Down Expand Up @@ -99,17 +77,45 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
Bitmask |= encodeZT0State(StateValue::New);
}

bool SMEAttrs::requiresSMChange(const SMEAttrs &Callee) const {
if (Callee.hasStreamingCompatibleInterface())
/// Enable the attribute bits associated with a recognised function name.
///
/// The mask starts at SMEAttrs::Normal and gains bits only for the names
/// tested below. set() is always called with the resulting mask, even when no
/// name matched.
void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName) {
  // Names that receive SM_Compatible together with SME_ABI_Routine.
  const bool IsCompatibleABIRoutine =
      FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state" ||
      FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" ||
      FuncName == "__arm_sme_state_size";
  // Receives the same two bits plus the encoded ZA state StateValue::In.
  const bool IsTPIDR2Restore = FuncName == "__arm_tpidr2_restore";
  // Names that receive SM_Compatible only.
  const bool IsCompatibleMemRoutine =
      FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
      FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr";

  unsigned Mask = SMEAttrs::Normal;
  if (IsCompatibleABIRoutine || IsTPIDR2Restore)
    Mask |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine;
  if (IsTPIDR2Restore)
    Mask |= encodeZAState(StateValue::In);
  if (IsCompatibleMemRoutine)
    Mask |= SMEAttrs::SM_Compatible;

  set(Mask, /*Enable=*/true);
}

bool SMECallAttrs::requiresSMChange() const {
if (callee().hasStreamingCompatibleInterface())
return false;

// Both non-streaming
if (hasNonStreamingInterfaceAndBody() && Callee.hasNonStreamingInterface())
if (caller().hasNonStreamingInterfaceAndBody() &&
callee().hasNonStreamingInterface())
return false;

// Both streaming
if (hasStreamingInterfaceOrBody() && Callee.hasStreamingInterface())
if (caller().hasStreamingInterfaceOrBody() &&
callee().hasStreamingInterface())
return false;

return true;
}

// Builds the attribute bundle for the call CB:
//  - CallerFn from the function containing the call (CB.getFunction()),
//  - CalledFn from CB.getCalledFunction(),
//  - Callsite from the attributes attached to the call itself,
//  - IsIndirect from CB.isIndirectCall().
// NOTE(review): getCalledFunction() is passed through unchecked; presumably
// the SMEAttrs constructor taking a pointer copes with null -- confirm.
SMECallAttrs::SMECallAttrs(const CallBase &CB)
: CallerFn(*CB.getFunction()), CalledFn(CB.getCalledFunction()),
Callsite(CB.getAttributes()), IsIndirect(CB.isIndirectCall()) {
// FIXME: We probably should not allow SME attributes on direct calls but
// clang duplicates streaming mode attributes at each callsite.
// For direct calls, combining the callsite attributes (after
// withoutPerCallsiteFlags()) with CalledFn must compare equal to CalledFn,
// i.e. the callsite may not add anything beyond the callee's attributes.
// Indirect calls are exempt from this check.
assert((IsIndirect ||
((Callsite.withoutPerCallsiteFlags() | CalledFn) == CalledFn)) &&
"SME attributes at callsite do not match declaration");
}
Loading
Loading