Skip to content

Commit 14e9772

Browse files
committed
[AArch64][SME] Disallow SME attributes on direct function calls
This was only used in a handful of tests (mainly to avoid making multiple function declarations). These tests can easily be updated to use indirect calls or attributes on declarations. This allows us to remove checks that looked at both the "callee" and "callsite" attributes, which makes the API of SMECallAttrs clearer and less error-prone (as you can't accidentally use .callee() when you should have used .calleeOrCallsite()). Note: This currently still allows non-conflicting attributes on direct calls (as clang currently duplicates streaming mode attributes at each callsite).
1 parent 20fc9f1 commit 14e9772

File tree

7 files changed

+100
-88
lines changed

7 files changed

+100
-88
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

+8-8
Original file line numberDiff line numberDiff line change
@@ -8979,9 +8979,9 @@ static unsigned getSMCondition(const SMECallAttrs &CallAttrs) {
89798979
if (!CallAttrs.caller().hasStreamingCompatibleInterface() ||
89808980
CallAttrs.caller().hasStreamingBody())
89818981
return AArch64SME::Always;
8982-
if (CallAttrs.calleeOrCallsite().hasNonStreamingInterface())
8982+
if (CallAttrs.callee().hasNonStreamingInterface())
89838983
return AArch64SME::IfCallerIsStreaming;
8984-
if (CallAttrs.calleeOrCallsite().hasStreamingInterface())
8984+
if (CallAttrs.callee().hasStreamingInterface())
89858985
return AArch64SME::IfCallerIsNonStreaming;
89868986

89878987
llvm_unreachable("Unsupported attributes");
@@ -9158,7 +9158,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91589158
return DescribeCallsite(R) << " sets up a lazy save for ZA";
91599159
});
91609160
} else if (RequiresSaveAllZA) {
9161-
assert(!CallAttrs.calleeOrCallsite().hasSharedZAInterface() &&
9161+
assert(!CallAttrs.callee().hasSharedZAInterface() &&
91629162
"Cannot share state that may not exist");
91639163
Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
91649164
/*IsSave=*/true);
@@ -9484,9 +9484,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
94849484
InGlue = Chain.getValue(1);
94859485
}
94869486

9487-
SDValue NewChain = changeStreamingMode(
9488-
DAG, DL, CallAttrs.calleeOrCallsite().hasStreamingInterface(), Chain,
9489-
InGlue, getSMCondition(CallAttrs), PStateSM);
9487+
SDValue NewChain =
9488+
changeStreamingMode(DAG, DL, CallAttrs.callee().hasStreamingInterface(),
9489+
Chain, InGlue, getSMCondition(CallAttrs), PStateSM);
94909490
Chain = NewChain.getValue(0);
94919491
InGlue = NewChain.getValue(1);
94929492
}
@@ -9665,8 +9665,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96659665
if (RequiresSMChange) {
96669666
assert(PStateSM && "Expected a PStateSM to be set");
96679667
Result = changeStreamingMode(
9668-
DAG, DL, !CallAttrs.calleeOrCallsite().hasStreamingInterface(), Result,
9669-
InGlue, getSMCondition(CallAttrs), PStateSM);
9668+
DAG, DL, !CallAttrs.callee().hasStreamingInterface(), Result, InGlue,
9669+
getSMCondition(CallAttrs), PStateSM);
96709670

96719671
if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
96729672
InGlue = Result.getValue(1);

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
352352
SMEAttrs FAttrs(*F);
353353
SMECallAttrs CallAttrs(Call);
354354

355-
if (SMECallAttrs(FAttrs, CallAttrs.calleeOrCallsite()).requiresSMChange()) {
355+
if (SMECallAttrs(FAttrs, CallAttrs.callee()).requiresSMChange()) {
356356
if (F == Call.getCaller()) // (1)
357357
return CallPenaltyChangeSM * DefaultCallPenalty;
358358
if (SMECallAttrs(FAttrs, CallAttrs.caller()).requiresSMChange()) // (2)

llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp

+13-7
Original file line numberDiff line numberDiff line change
@@ -94,22 +94,28 @@ void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName) {
9494
}
9595

9696
bool SMECallAttrs::requiresSMChange() const {
97-
if ((Callsite | Callee).hasStreamingCompatibleInterface())
97+
if (callee().hasStreamingCompatibleInterface())
9898
return false;
9999

100100
// Both non-streaming
101-
if (Caller.hasNonStreamingInterfaceAndBody() &&
102-
(Callsite | Callee).hasNonStreamingInterface())
101+
if (caller().hasNonStreamingInterfaceAndBody() &&
102+
callee().hasNonStreamingInterface())
103103
return false;
104104

105105
// Both streaming
106-
if (Caller.hasStreamingInterfaceOrBody() &&
107-
(Callsite | Callee).hasStreamingInterface())
106+
if (caller().hasStreamingInterfaceOrBody() &&
107+
callee().hasStreamingInterface())
108108
return false;
109109

110110
return true;
111111
}
112112

113113
SMECallAttrs::SMECallAttrs(const CallBase &CB)
114-
: SMECallAttrs(*CB.getFunction(), CB.getCalledFunction(),
115-
CB.getAttributes()) {}
114+
: CallerFn(*CB.getFunction()), CalledFn(CB.getCalledFunction()),
115+
Callsite(CB.getAttributes()), IsIndirect(CB.isIndirectCall()) {
116+
// FIXME: We probably should not allow SME attributes on direct calls but
117+
// clang currently copies attributes from the declaration to each callsite.
118+
assert((IsIndirect || Callsite.withoutPerCallsiteFlags().isNormal() ||
119+
Callsite.withoutPerCallsiteFlags() == CalledFn) &&
120+
"SME attributes at callsite do not match declaration");
121+
}

llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h

+31-24
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ class SMEAttrs {
4444
ZA_Shift = 6,
4545
ZA_Mask = 0b111 << ZA_Shift,
4646
ZT0_Shift = 9,
47-
ZT0_Mask = 0b111 << ZT0_Shift
47+
ZT0_Mask = 0b111 << ZT0_Shift,
48+
Callsite_Flags = ZT0_Undef
4849
};
4950

5051
SMEAttrs() = default;
@@ -129,10 +130,13 @@ class SMEAttrs {
129130
}
130131
bool hasZT0State() const { return isNewZT0() || sharesZT0(); }
131132

132-
SMEAttrs operator|(SMEAttrs Other) const {
133-
SMEAttrs Merged(*this);
134-
Merged.set(Other.Bitmask, /*Enable=*/true);
135-
return Merged;
133+
bool isNormal() const { return Bitmask == Normal; }
134+
SMEAttrs withoutPerCallsiteFlags() const {
135+
return (Bitmask & ~Callsite_Flags);
136+
}
137+
138+
bool operator==(SMEAttrs const &Other) const {
139+
return Bitmask == Other.Bitmask;
136140
}
137141

138142
private:
@@ -143,54 +147,57 @@ class SMEAttrs {
143147
/// interfaces to query whether a streaming mode change or lazy-save mechanism
144148
/// is required when going from one function to another (e.g. through a call).
145149
class SMECallAttrs {
146-
SMEAttrs Caller;
147-
SMEAttrs Callee;
150+
SMEAttrs CallerFn;
151+
SMEAttrs CalledFn;
148152
SMEAttrs Callsite;
153+
bool IsIndirect = false;
149154

150155
public:
151156
SMECallAttrs(SMEAttrs Caller, SMEAttrs Callee,
152157
SMEAttrs Callsite = SMEAttrs::Normal)
153-
: Caller(Caller), Callee(Callee), Callsite(Callsite) {}
158+
: CallerFn(Caller), CalledFn(Callee), Callsite(Callsite) {}
154159

155160
SMECallAttrs(const CallBase &CB);
156161

157-
SMEAttrs &caller() { return Caller; }
158-
SMEAttrs &callee() { return Callee; }
162+
SMEAttrs &caller() { return CallerFn; }
163+
SMEAttrs &callee() {
164+
if (IsIndirect)
165+
return Callsite;
166+
return CalledFn;
167+
}
159168
SMEAttrs &callsite() { return Callsite; }
160-
SMEAttrs const &caller() const { return Caller; }
161-
SMEAttrs const &callee() const { return Callee; }
169+
SMEAttrs const &caller() const { return CallerFn; }
170+
SMEAttrs const &callee() const {
171+
return const_cast<SMECallAttrs *>(this)->callee();
172+
}
162173
SMEAttrs const &callsite() const { return Callsite; }
163-
SMEAttrs calleeOrCallsite() const { return Callsite | Callee; }
164174

165175
/// \return true if a call from Caller -> Callee requires a change in
166176
/// streaming mode.
167177
bool requiresSMChange() const;
168178

169179
bool requiresLazySave() const {
170-
return Caller.hasZAState() && (Callsite | Callee).hasPrivateZAInterface() &&
171-
!Callee.isSMEABIRoutine();
180+
return caller().hasZAState() && callee().hasPrivateZAInterface() &&
181+
!callee().isSMEABIRoutine();
172182
}
173183

174184
bool requiresPreservingZT0() const {
175-
return Caller.hasZT0State() && !Callsite.hasUndefZT0() &&
176-
!(Callsite | Callee).sharesZT0() &&
177-
!(Callsite | Callee).hasAgnosticZAInterface();
185+
return caller().hasZT0State() && !callsite().hasUndefZT0() &&
186+
!callee().sharesZT0() && !callee().hasAgnosticZAInterface();
178187
}
179188

180189
bool requiresDisablingZABeforeCall() const {
181-
return Caller.hasZT0State() && !Caller.hasZAState() &&
182-
(Callsite | Callee).hasPrivateZAInterface() &&
183-
!Callee.isSMEABIRoutine();
190+
return caller().hasZT0State() && !caller().hasZAState() &&
191+
callee().hasPrivateZAInterface() && !callee().isSMEABIRoutine();
184192
}
185193

186194
bool requiresEnablingZAAfterCall() const {
187195
return requiresLazySave() || requiresDisablingZABeforeCall();
188196
}
189197

190198
bool requiresPreservingAllZAState() const {
191-
return Caller.hasAgnosticZAInterface() &&
192-
!(Callsite | Callee).hasAgnosticZAInterface() &&
193-
!Callee.isSMEABIRoutine();
199+
return caller().hasAgnosticZAInterface() &&
200+
!callee().hasAgnosticZAInterface() && !callee().isSMEABIRoutine();
194201
}
195202
};
196203

llvm/test/CodeGen/AArch64/sme-peephole-opts.ll

+12-11
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
; RUN: llc -mtriple=aarch64-linux-gnu -aarch64-streaming-hazard-size=0 -mattr=+sve,+sme2 < %s | FileCheck %s
33

44
declare void @callee()
5+
declare void @callee_sm() "aarch64_pstate_sm_enabled"
56
declare void @callee_farg(float)
67
declare float @callee_farg_fret(float)
78

89
; normal caller -> streaming callees
9-
define void @test0() nounwind {
10+
define void @test0(ptr %callee) nounwind {
1011
; CHECK-LABEL: test0:
1112
; CHECK: // %bb.0:
1213
; CHECK-NEXT: stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
@@ -16,17 +17,17 @@ define void @test0() nounwind {
1617
; CHECK-NEXT: stp d9, d8, [sp, #48] // 16-byte Folded Spill
1718
; CHECK-NEXT: stp x30, x9, [sp, #64] // 16-byte Folded Spill
1819
; CHECK-NEXT: smstart sm
19-
; CHECK-NEXT: bl callee
20-
; CHECK-NEXT: bl callee
20+
; CHECK-NEXT: bl callee_sm
21+
; CHECK-NEXT: bl callee_sm
2122
; CHECK-NEXT: smstop sm
2223
; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
2324
; CHECK-NEXT: ldr x30, [sp, #64] // 8-byte Folded Reload
2425
; CHECK-NEXT: ldp d11, d10, [sp, #32] // 16-byte Folded Reload
2526
; CHECK-NEXT: ldp d13, d12, [sp, #16] // 16-byte Folded Reload
2627
; CHECK-NEXT: ldp d15, d14, [sp], #80 // 16-byte Folded Reload
2728
; CHECK-NEXT: ret
28-
call void @callee() "aarch64_pstate_sm_enabled"
29-
call void @callee() "aarch64_pstate_sm_enabled"
29+
call void @callee_sm()
30+
call void @callee_sm()
3031
ret void
3132
}
3233

@@ -118,7 +119,7 @@ define void @test3() nounwind "aarch64_pstate_sm_compatible" {
118119
; CHECK-NEXT: // %bb.1:
119120
; CHECK-NEXT: smstart sm
120121
; CHECK-NEXT: .LBB3_2:
121-
; CHECK-NEXT: bl callee
122+
; CHECK-NEXT: bl callee_sm
122123
; CHECK-NEXT: tbnz w19, #0, .LBB3_4
123124
; CHECK-NEXT: // %bb.3:
124125
; CHECK-NEXT: smstop sm
@@ -140,7 +141,7 @@ define void @test3() nounwind "aarch64_pstate_sm_compatible" {
140141
; CHECK-NEXT: // %bb.9:
141142
; CHECK-NEXT: smstart sm
142143
; CHECK-NEXT: .LBB3_10:
143-
; CHECK-NEXT: bl callee
144+
; CHECK-NEXT: bl callee_sm
144145
; CHECK-NEXT: tbnz w19, #0, .LBB3_12
145146
; CHECK-NEXT: // %bb.11:
146147
; CHECK-NEXT: smstop sm
@@ -152,9 +153,9 @@ define void @test3() nounwind "aarch64_pstate_sm_compatible" {
152153
; CHECK-NEXT: ldp d13, d12, [sp, #16] // 16-byte Folded Reload
153154
; CHECK-NEXT: ldp d15, d14, [sp], #96 // 16-byte Folded Reload
154155
; CHECK-NEXT: ret
155-
call void @callee() "aarch64_pstate_sm_enabled"
156+
call void @callee_sm()
156157
call void @callee()
157-
call void @callee() "aarch64_pstate_sm_enabled"
158+
call void @callee_sm()
158159
ret void
159160
}
160161

@@ -342,7 +343,7 @@ define void @test10() "aarch64_pstate_sm_body" {
342343
; CHECK-NEXT: bl callee
343344
; CHECK-NEXT: smstart sm
344345
; CHECK-NEXT: .cfi_restore vg
345-
; CHECK-NEXT: bl callee
346+
; CHECK-NEXT: bl callee_sm
346347
; CHECK-NEXT: .cfi_offset vg, -24
347348
; CHECK-NEXT: smstop sm
348349
; CHECK-NEXT: bl callee
@@ -363,7 +364,7 @@ define void @test10() "aarch64_pstate_sm_body" {
363364
; CHECK-NEXT: .cfi_restore b15
364365
; CHECK-NEXT: ret
365366
call void @callee()
366-
call void @callee() "aarch64_pstate_sm_enabled"
367+
call void @callee_sm()
367368
call void @callee()
368369
ret void
369370
}

llvm/test/CodeGen/AArch64/sme-vg-to-stack.ll

+2-2
Original file line numberDiff line numberDiff line change
@@ -1098,11 +1098,11 @@ define void @test_rdsvl_right_after_prologue(i64 %x0) nounwind {
10981098
; NO-SVE-CHECK-NEXT: ret
10991099
%some_alloc = alloca i64, align 8
11001100
%rdsvl = tail call i64 @llvm.aarch64.sme.cntsd()
1101-
call void @bar(i64 %rdsvl, i64 %x0) "aarch64_pstate_sm_enabled"
1101+
call void @bar(i64 %rdsvl, i64 %x0)
11021102
ret void
11031103
}
11041104

1105-
declare void @bar(i64, i64)
1105+
declare void @bar(i64, i64) "aarch64_pstate_sm_enabled"
11061106

11071107
; Ensure we still emit async unwind information with -fno-asynchronous-unwind-tables
11081108
; if the function contains a streaming-mode change.

0 commit comments

Comments
 (0)