Skip to content

Commit 647db1b

Browse files
authored
Reland "[AArch64][SME] Split SMECallAttrs out of SMEAttrs" (#138671)
SMECallAttrs is a new helper class that holds all the SMEAttrs for a call. The interfaces to query actions needed for the call (e.g. change streaming mode) have been moved to the SMECallAttrs class. The main motivation for this change is to make the split between the caller, callee, and callsite attributes more apparent. Before this change, we would always merge callsite and callee attributes. The main reason to do this was to handle indirect calls, however, we also occasionally used callsite attributes on direct calls in tests (mainly to avoid creating multiple function declarations). With this patch, we now explicitly handle indirect calls and disallow incompatible attributes on direct calls (so this patch is not entirely an NFC). Same as #137239, but with a change to avoid inferring SME attributes for function definitions. This allows stubbing the SME ABI routines in C/C++ (and matches the old behaviour).
1 parent ab119ad commit 647db1b

File tree

9 files changed

+320
-210
lines changed

9 files changed

+320
-210
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

+38-37
Original file line numberDiff line numberDiff line change
@@ -8641,6 +8641,16 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI,
86418641
}
86428642
}
86438643

8644+
static SMECallAttrs
8645+
getSMECallAttrs(const Function &Caller,
8646+
const TargetLowering::CallLoweringInfo &CLI) {
8647+
if (CLI.CB)
8648+
return SMECallAttrs(*CLI.CB);
8649+
if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
8650+
return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(ES->getSymbol()));
8651+
return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(SMEAttrs::Normal));
8652+
}
8653+
86448654
bool AArch64TargetLowering::isEligibleForTailCallOptimization(
86458655
const CallLoweringInfo &CLI) const {
86468656
CallingConv::ID CalleeCC = CLI.CallConv;
@@ -8659,12 +8669,10 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
86598669

86608670
// SME Streaming functions are not eligible for TCO as they may require
86618671
// the streaming mode or ZA to be restored after returning from the call.
8662-
SMEAttrs CallerAttrs(MF.getFunction());
8663-
auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
8664-
if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
8665-
CallerAttrs.requiresLazySave(CalleeAttrs) ||
8666-
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
8667-
CallerAttrs.hasStreamingBody())
8672+
SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, CLI);
8673+
if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
8674+
CallAttrs.requiresPreservingAllZAState() ||
8675+
CallAttrs.caller().hasStreamingBody())
86688676
return false;
86698677

86708678
// Functions using the C or Fast calling convention that have an SVE signature
@@ -8956,14 +8964,14 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
89568964
return TLI.LowerCallTo(CLI).second;
89578965
}
89588966

8959-
static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
8960-
const SMEAttrs &CalleeAttrs) {
8961-
if (!CallerAttrs.hasStreamingCompatibleInterface() ||
8962-
CallerAttrs.hasStreamingBody())
8967+
static AArch64SME::ToggleCondition
8968+
getSMToggleCondition(const SMECallAttrs &CallAttrs) {
8969+
if (!CallAttrs.caller().hasStreamingCompatibleInterface() ||
8970+
CallAttrs.caller().hasStreamingBody())
89638971
return AArch64SME::Always;
8964-
if (CalleeAttrs.hasNonStreamingInterface())
8972+
if (CallAttrs.callee().hasNonStreamingInterface())
89658973
return AArch64SME::IfCallerIsStreaming;
8966-
if (CalleeAttrs.hasStreamingInterface())
8974+
if (CallAttrs.callee().hasStreamingInterface())
89678975
return AArch64SME::IfCallerIsNonStreaming;
89688976

89698977
llvm_unreachable("Unsupported attributes");
@@ -9096,11 +9104,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
90969104
}
90979105

90989106
// Determine whether we need any streaming mode changes.
9099-
SMEAttrs CalleeAttrs, CallerAttrs(MF.getFunction());
9100-
if (CLI.CB)
9101-
CalleeAttrs = SMEAttrs(*CLI.CB);
9102-
else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
9103-
CalleeAttrs = SMEAttrs(ES->getSymbol());
9107+
SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
91049108

91059109
auto DescribeCallsite =
91069110
[&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
@@ -9115,9 +9119,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91159119
return R;
91169120
};
91179121

9118-
bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
9119-
bool RequiresSaveAllZA =
9120-
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
9122+
bool RequiresLazySave = CallAttrs.requiresLazySave();
9123+
bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
91219124
if (RequiresLazySave) {
91229125
const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
91239126
MachinePointerInfo MPI =
@@ -9145,18 +9148,18 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91459148
return DescribeCallsite(R) << " sets up a lazy save for ZA";
91469149
});
91479150
} else if (RequiresSaveAllZA) {
9148-
assert(!CalleeAttrs.hasSharedZAInterface() &&
9151+
assert(!CallAttrs.callee().hasSharedZAInterface() &&
91499152
"Cannot share state that may not exist");
91509153
Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
91519154
/*IsSave=*/true);
91529155
}
91539156

91549157
SDValue PStateSM;
9155-
bool RequiresSMChange = CallerAttrs.requiresSMChange(CalleeAttrs);
9158+
bool RequiresSMChange = CallAttrs.requiresSMChange();
91569159
if (RequiresSMChange) {
9157-
if (CallerAttrs.hasStreamingInterfaceOrBody())
9160+
if (CallAttrs.caller().hasStreamingInterfaceOrBody())
91589161
PStateSM = DAG.getConstant(1, DL, MVT::i64);
9159-
else if (CallerAttrs.hasNonStreamingInterface())
9162+
else if (CallAttrs.caller().hasNonStreamingInterface())
91609163
PStateSM = DAG.getConstant(0, DL, MVT::i64);
91619164
else
91629165
PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
@@ -9173,7 +9176,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91739176

91749177
SDValue ZTFrameIdx;
91759178
MachineFrameInfo &MFI = MF.getFrameInfo();
9176-
bool ShouldPreserveZT0 = CallerAttrs.requiresPreservingZT0(CalleeAttrs);
9179+
bool ShouldPreserveZT0 = CallAttrs.requiresPreservingZT0();
91779180

91789181
// If the caller has ZT0 state which will not be preserved by the callee,
91799182
// spill ZT0 before the call.
@@ -9189,7 +9192,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91899192

91909193
// If caller shares ZT0 but the callee is not shared ZA, we need to stop
91919194
// PSTATE.ZA before the call if there is no lazy-save active.
9192-
bool DisableZA = CallerAttrs.requiresDisablingZABeforeCall(CalleeAttrs);
9195+
bool DisableZA = CallAttrs.requiresDisablingZABeforeCall();
91939196
assert((!DisableZA || !RequiresLazySave) &&
91949197
"Lazy-save should have PSTATE.SM=1 on entry to the function");
91959198

@@ -9472,8 +9475,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
94729475
}
94739476

94749477
SDValue NewChain = changeStreamingMode(
9475-
DAG, DL, CalleeAttrs.hasStreamingInterface(), Chain, InGlue,
9476-
getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
9478+
DAG, DL, CallAttrs.callee().hasStreamingInterface(), Chain, InGlue,
9479+
getSMToggleCondition(CallAttrs), PStateSM);
94779480
Chain = NewChain.getValue(0);
94789481
InGlue = NewChain.getValue(1);
94799482
}
@@ -9659,8 +9662,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96599662
if (RequiresSMChange) {
96609663
assert(PStateSM && "Expected a PStateSM to be set");
96619664
Result = changeStreamingMode(
9662-
DAG, DL, !CalleeAttrs.hasStreamingInterface(), Result, InGlue,
9663-
getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
9665+
DAG, DL, !CallAttrs.callee().hasStreamingInterface(), Result, InGlue,
9666+
getSMToggleCondition(CallAttrs), PStateSM);
96649667

96659668
if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
96669669
InGlue = Result.getValue(1);
@@ -9670,7 +9673,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96709673
}
96719674
}
96729675

9673-
if (CallerAttrs.requiresEnablingZAAfterCall(CalleeAttrs))
9676+
if (CallAttrs.requiresEnablingZAAfterCall())
96749677
// Unconditionally resume ZA.
96759678
Result = DAG.getNode(
96769679
AArch64ISD::SMSTART, DL, MVT::Other, Result,
@@ -28559,12 +28562,10 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
2855928562

2856028563
// Checks to allow the use of SME instructions
2856128564
if (auto *Base = dyn_cast<CallBase>(&Inst)) {
28562-
auto CallerAttrs = SMEAttrs(*Inst.getFunction());
28563-
auto CalleeAttrs = SMEAttrs(*Base);
28564-
if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
28565-
CallerAttrs.requiresLazySave(CalleeAttrs) ||
28566-
CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
28567-
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
28565+
auto CallAttrs = SMECallAttrs(*Base);
28566+
if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
28567+
CallAttrs.requiresPreservingZT0() ||
28568+
CallAttrs.requiresPreservingAllZAState())
2856828569
return true;
2856928570
}
2857028571
return false;

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

+13-12
Original file line numberDiff line numberDiff line change
@@ -268,22 +268,21 @@ const FeatureBitset AArch64TTIImpl::InlineInverseFeatures = {
268268

269269
bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
270270
const Function *Callee) const {
271-
SMEAttrs CallerAttrs(*Caller), CalleeAttrs(*Callee);
271+
SMECallAttrs CallAttrs(*Caller, *Callee);
272272

273273
// When inlining, we should consider the body of the function, not the
274274
// interface.
275-
if (CalleeAttrs.hasStreamingBody()) {
276-
CalleeAttrs.set(SMEAttrs::SM_Compatible, false);
277-
CalleeAttrs.set(SMEAttrs::SM_Enabled, true);
275+
if (CallAttrs.callee().hasStreamingBody()) {
276+
CallAttrs.callee().set(SMEAttrs::SM_Compatible, false);
277+
CallAttrs.callee().set(SMEAttrs::SM_Enabled, true);
278278
}
279279

280-
if (CalleeAttrs.isNewZA() || CalleeAttrs.isNewZT0())
280+
if (CallAttrs.callee().isNewZA() || CallAttrs.callee().isNewZT0())
281281
return false;
282282

283-
if (CallerAttrs.requiresLazySave(CalleeAttrs) ||
284-
CallerAttrs.requiresSMChange(CalleeAttrs) ||
285-
CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
286-
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs)) {
283+
if (CallAttrs.requiresLazySave() || CallAttrs.requiresSMChange() ||
284+
CallAttrs.requiresPreservingZT0() ||
285+
CallAttrs.requiresPreservingAllZAState()) {
287286
if (hasPossibleIncompatibleOps(Callee))
288287
return false;
289288
}
@@ -349,12 +348,14 @@ AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
349348
// streaming-mode change, and the call to G from F would also require a
350349
// streaming-mode change, then there is benefit to do the streaming-mode
351350
// change only once and avoid inlining of G into F.
351+
352352
SMEAttrs FAttrs(*F);
353-
SMEAttrs CalleeAttrs(Call);
354-
if (FAttrs.requiresSMChange(CalleeAttrs)) {
353+
SMECallAttrs CallAttrs(Call);
354+
355+
if (SMECallAttrs(FAttrs, CallAttrs.callee()).requiresSMChange()) {
355356
if (F == Call.getCaller()) // (1)
356357
return CallPenaltyChangeSM * DefaultCallPenalty;
357-
if (FAttrs.requiresSMChange(SMEAttrs(*Call.getCaller()))) // (2)
358+
if (SMECallAttrs(FAttrs, CallAttrs.caller()).requiresSMChange()) // (2)
358359
return InlineCallPenaltyChangeSM * DefaultCallPenalty;
359360
}
360361

llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp

+38-29
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,14 @@ void SMEAttrs::set(unsigned M, bool Enable) {
2727
"ZA_New and SME_ABI_Routine are mutually exclusive");
2828

2929
assert(
30-
(!sharesZA() ||
31-
(isNewZA() ^ isInZA() ^ isInOutZA() ^ isOutZA() ^ isPreservesZA())) &&
30+
(isNewZA() + isInZA() + isOutZA() + isInOutZA() + isPreservesZA()) <= 1 &&
3231
"Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', "
3332
"'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive");
3433

3534
// ZT0 Attrs
3635
assert(
37-
(!sharesZT0() || (isNewZT0() ^ isInZT0() ^ isInOutZT0() ^ isOutZT0() ^
38-
isPreservesZT0())) &&
36+
(isNewZT0() + isInZT0() + isOutZT0() + isInOutZT0() + isPreservesZT0()) <=
37+
1 &&
3938
"Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
4039
"'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive");
4140

@@ -44,27 +43,6 @@ void SMEAttrs::set(unsigned M, bool Enable) {
4443
"interface");
4544
}
4645

47-
SMEAttrs::SMEAttrs(const CallBase &CB) {
48-
*this = SMEAttrs(CB.getAttributes());
49-
if (auto *F = CB.getCalledFunction()) {
50-
set(SMEAttrs(*F).Bitmask | SMEAttrs(F->getName()).Bitmask);
51-
}
52-
}
53-
54-
SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
55-
if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state")
56-
Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine);
57-
if (FuncName == "__arm_tpidr2_restore")
58-
Bitmask |= SMEAttrs::SM_Compatible | encodeZAState(StateValue::In) |
59-
SMEAttrs::SME_ABI_Routine;
60-
if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
61-
FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr")
62-
Bitmask |= SMEAttrs::SM_Compatible;
63-
if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" ||
64-
FuncName == "__arm_sme_state_size")
65-
Bitmask |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine;
66-
}
67-
6846
SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
6947
Bitmask = 0;
7048
if (Attrs.hasFnAttr("aarch64_pstate_sm_enabled"))
@@ -99,17 +77,48 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
9977
Bitmask |= encodeZT0State(StateValue::New);
10078
}
10179

102-
bool SMEAttrs::requiresSMChange(const SMEAttrs &Callee) const {
103-
if (Callee.hasStreamingCompatibleInterface())
80+
void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName) {
81+
unsigned KnownAttrs = SMEAttrs::Normal;
82+
if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state")
83+
KnownAttrs |= (SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine);
84+
if (FuncName == "__arm_tpidr2_restore")
85+
KnownAttrs |= SMEAttrs::SM_Compatible | encodeZAState(StateValue::In) |
86+
SMEAttrs::SME_ABI_Routine;
87+
if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
88+
FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr")
89+
KnownAttrs |= SMEAttrs::SM_Compatible;
90+
if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" ||
91+
FuncName == "__arm_sme_state_size")
92+
KnownAttrs |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine;
93+
set(KnownAttrs);
94+
}
95+
96+
bool SMECallAttrs::requiresSMChange() const {
97+
if (callee().hasStreamingCompatibleInterface())
10498
return false;
10599

106100
// Both non-streaming
107-
if (hasNonStreamingInterfaceAndBody() && Callee.hasNonStreamingInterface())
101+
if (caller().hasNonStreamingInterfaceAndBody() &&
102+
callee().hasNonStreamingInterface())
108103
return false;
109104

110105
// Both streaming
111-
if (hasStreamingInterfaceOrBody() && Callee.hasStreamingInterface())
106+
if (caller().hasStreamingInterfaceOrBody() &&
107+
callee().hasStreamingInterface())
112108
return false;
113109

114110
return true;
115111
}
112+
113+
SMECallAttrs::SMECallAttrs(const CallBase &CB)
114+
: CallerFn(*CB.getFunction()), CalledFn(SMEAttrs::Normal),
115+
Callsite(CB.getAttributes()), IsIndirect(CB.isIndirectCall()) {
116+
if (auto *CalledFunction = CB.getCalledFunction())
117+
CalledFn = SMEAttrs(*CalledFunction, SMEAttrs::InferAttrsFromName::Yes);
118+
119+
// FIXME: We probably should not allow SME attributes on direct calls but
120+
// clang duplicates streaming mode attributes at each callsite.
121+
assert((IsIndirect ||
122+
((Callsite.withoutPerCallsiteFlags() | CalledFn) == CalledFn)) &&
123+
"SME attributes at callsite do not match declaration");
124+
}

0 commit comments

Comments
 (0)