@@ -8641,6 +8641,16 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI,
8641
8641
}
8642
8642
}
8643
8643
8644
+ static SMECallAttrs // Builds the caller/callee SME attribute pair for the call described by CLI.
8645
+ getSMECallAttrs(const Function &Caller,
8646
+ const TargetLowering::CallLoweringInfo &CLI) {
8647
+ if (CLI.CB) // A call instruction is attached to CLI: construct the pair from it (Caller is not consulted on this path).
8648
+ return SMECallAttrs(*CLI.CB);
8649
+ if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee)) // No call instruction, but the callee node is an external symbol.
8650
+ return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(ES->getSymbol())); // Callee attributes are built from ES->getSymbol().
8651
+ return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(SMEAttrs::Normal)); // Fallback: pair Caller's attributes with SMEAttrs::Normal for the callee.
8652
+ }
8653
+
8644
8654
bool AArch64TargetLowering::isEligibleForTailCallOptimization(
8645
8655
const CallLoweringInfo &CLI) const {
8646
8656
CallingConv::ID CalleeCC = CLI.CallConv;
@@ -8659,12 +8669,10 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
8659
8669
8660
8670
// SME Streaming functions are not eligible for TCO as they may require
8661
8671
// the streaming mode or ZA to be restored after returning from the call.
8662
- SMEAttrs CallerAttrs(MF.getFunction());
8663
- auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
8664
- if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
8665
- CallerAttrs.requiresLazySave(CalleeAttrs) ||
8666
- CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
8667
- CallerAttrs.hasStreamingBody())
8672
+ SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, CLI);
8673
+ if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
8674
+ CallAttrs.requiresPreservingAllZAState() ||
8675
+ CallAttrs.caller().hasStreamingBody())
8668
8676
return false;
8669
8677
8670
8678
// Functions using the C or Fast calling convention that have an SVE signature
@@ -8956,14 +8964,14 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
8956
8964
return TLI.LowerCallTo(CLI).second;
8957
8965
}
8958
8966
8959
- static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
8960
- const SMEAttrs &CalleeAttrs ) {
8961
- if (!CallerAttrs .hasStreamingCompatibleInterface() ||
8962
- CallerAttrs .hasStreamingBody())
8967
+ static AArch64SME::ToggleCondition
8968
+ getSMToggleCondition( const SMECallAttrs &CallAttrs ) {
8969
+ if (!CallAttrs.caller() .hasStreamingCompatibleInterface() ||
8970
+ CallAttrs.caller() .hasStreamingBody())
8963
8971
return AArch64SME::Always;
8964
- if (CalleeAttrs .hasNonStreamingInterface())
8972
+ if (CallAttrs.callee() .hasNonStreamingInterface())
8965
8973
return AArch64SME::IfCallerIsStreaming;
8966
- if (CalleeAttrs .hasStreamingInterface())
8974
+ if (CallAttrs.callee() .hasStreamingInterface())
8967
8975
return AArch64SME::IfCallerIsNonStreaming;
8968
8976
8969
8977
llvm_unreachable("Unsupported attributes");
@@ -9096,11 +9104,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9096
9104
}
9097
9105
9098
9106
// Determine whether we need any streaming mode changes.
9099
- SMEAttrs CalleeAttrs, CallerAttrs(MF.getFunction());
9100
- if (CLI.CB)
9101
- CalleeAttrs = SMEAttrs(*CLI.CB);
9102
- else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
9103
- CalleeAttrs = SMEAttrs(ES->getSymbol());
9107
+ SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
9104
9108
9105
9109
auto DescribeCallsite =
9106
9110
[&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
@@ -9115,9 +9119,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9115
9119
return R;
9116
9120
};
9117
9121
9118
- bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
9119
- bool RequiresSaveAllZA =
9120
- CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
9122
+ bool RequiresLazySave = CallAttrs.requiresLazySave();
9123
+ bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
9121
9124
if (RequiresLazySave) {
9122
9125
const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
9123
9126
MachinePointerInfo MPI =
@@ -9145,18 +9148,18 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9145
9148
return DescribeCallsite(R) << " sets up a lazy save for ZA";
9146
9149
});
9147
9150
} else if (RequiresSaveAllZA) {
9148
- assert(!CalleeAttrs .hasSharedZAInterface() &&
9151
+ assert(!CallAttrs.callee() .hasSharedZAInterface() &&
9149
9152
"Cannot share state that may not exist");
9150
9153
Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
9151
9154
/*IsSave=*/true);
9152
9155
}
9153
9156
9154
9157
SDValue PStateSM;
9155
- bool RequiresSMChange = CallerAttrs .requiresSMChange(CalleeAttrs );
9158
+ bool RequiresSMChange = CallAttrs .requiresSMChange();
9156
9159
if (RequiresSMChange) {
9157
- if (CallerAttrs .hasStreamingInterfaceOrBody())
9160
+ if (CallAttrs.caller() .hasStreamingInterfaceOrBody())
9158
9161
PStateSM = DAG.getConstant(1, DL, MVT::i64);
9159
- else if (CallerAttrs .hasNonStreamingInterface())
9162
+ else if (CallAttrs.caller() .hasNonStreamingInterface())
9160
9163
PStateSM = DAG.getConstant(0, DL, MVT::i64);
9161
9164
else
9162
9165
PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
@@ -9173,7 +9176,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9173
9176
9174
9177
SDValue ZTFrameIdx;
9175
9178
MachineFrameInfo &MFI = MF.getFrameInfo();
9176
- bool ShouldPreserveZT0 = CallerAttrs .requiresPreservingZT0(CalleeAttrs );
9179
+ bool ShouldPreserveZT0 = CallAttrs .requiresPreservingZT0();
9177
9180
9178
9181
// If the caller has ZT0 state which will not be preserved by the callee,
9179
9182
// spill ZT0 before the call.
@@ -9189,7 +9192,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9189
9192
9190
9193
// If caller shares ZT0 but the callee is not shared ZA, we need to stop
9191
9194
// PSTATE.ZA before the call if there is no lazy-save active.
9192
- bool DisableZA = CallerAttrs .requiresDisablingZABeforeCall(CalleeAttrs );
9195
+ bool DisableZA = CallAttrs .requiresDisablingZABeforeCall();
9193
9196
assert((!DisableZA || !RequiresLazySave) &&
9194
9197
"Lazy-save should have PSTATE.SM=1 on entry to the function");
9195
9198
@@ -9472,8 +9475,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9472
9475
}
9473
9476
9474
9477
SDValue NewChain = changeStreamingMode(
9475
- DAG, DL, CalleeAttrs .hasStreamingInterface(), Chain, InGlue,
9476
- getSMCondition(CallerAttrs, CalleeAttrs ), PStateSM);
9478
+ DAG, DL, CallAttrs.callee() .hasStreamingInterface(), Chain, InGlue,
9479
+ getSMToggleCondition(CallAttrs ), PStateSM);
9477
9480
Chain = NewChain.getValue(0);
9478
9481
InGlue = NewChain.getValue(1);
9479
9482
}
@@ -9659,8 +9662,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9659
9662
if (RequiresSMChange) {
9660
9663
assert(PStateSM && "Expected a PStateSM to be set");
9661
9664
Result = changeStreamingMode(
9662
- DAG, DL, !CalleeAttrs .hasStreamingInterface(), Result, InGlue,
9663
- getSMCondition(CallerAttrs, CalleeAttrs ), PStateSM);
9665
+ DAG, DL, !CallAttrs.callee() .hasStreamingInterface(), Result, InGlue,
9666
+ getSMToggleCondition(CallAttrs ), PStateSM);
9664
9667
9665
9668
if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
9666
9669
InGlue = Result.getValue(1);
@@ -9670,7 +9673,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9670
9673
}
9671
9674
}
9672
9675
9673
- if (CallerAttrs .requiresEnablingZAAfterCall(CalleeAttrs ))
9676
+ if (CallAttrs .requiresEnablingZAAfterCall())
9674
9677
// Unconditionally resume ZA.
9675
9678
Result = DAG.getNode(
9676
9679
AArch64ISD::SMSTART, DL, MVT::Other, Result,
@@ -28559,12 +28562,10 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
28559
28562
28560
28563
// Checks to allow the use of SME instructions
28561
28564
if (auto *Base = dyn_cast<CallBase>(&Inst)) {
28562
- auto CallerAttrs = SMEAttrs(*Inst.getFunction());
28563
- auto CalleeAttrs = SMEAttrs(*Base);
28564
- if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
28565
- CallerAttrs.requiresLazySave(CalleeAttrs) ||
28566
- CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
28567
- CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
28565
+ auto CallAttrs = SMECallAttrs(*Base);
28566
+ if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
28567
+ CallAttrs.requiresPreservingZT0() ||
28568
+ CallAttrs.requiresPreservingAllZAState())
28568
28569
return true;
28569
28570
}
28570
28571
return false;
0 commit comments