Skip to content

[AArch64][SME] Allow spills of ZT0 around SME ABI routines again #136726

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 25, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions llvm/lib/Target/AArch64/SMEABIPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,22 @@ FunctionPass *llvm::createSMEABIPass() { return new SMEABI(); }
//===----------------------------------------------------------------------===//

// Utility function to emit a call to __arm_tpidr2_save and clear TPIDR2_EL0.
void emitTPIDR2Save(Module *M, IRBuilder<> &Builder) {
void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, bool ZT0IsUndef = false) {
auto &Ctx = M->getContext();
auto *TPIDR2SaveTy =
FunctionType::get(Builder.getVoidTy(), {}, /*IsVarArgs=*/false);
auto Attrs = AttributeList().addFnAttribute(M->getContext(),
"aarch64_pstate_sm_compatible");
auto Attrs =
AttributeList().addFnAttribute(Ctx, "aarch64_pstate_sm_compatible");
FunctionCallee Callee =
M->getOrInsertFunction("__arm_tpidr2_save", TPIDR2SaveTy, Attrs);
CallInst *Call = Builder.CreateCall(Callee);

// If ZT0 is undefined (i.e. we're at the entry of a "new_zt0" function), mark
// __arm_tpidr2_save as preserving ZT0. This prevents an unnecessary spill of
// ZT0 that can occur before ZA is enabled.
if (ZT0IsUndef)
Call->addFnAttr(Attribute::get(Ctx, "aarch64_preserves_zt0"));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ACLE attribute __arm_preserves("zt0") maps to LLVM attribute aarch64_preserves_zt0.
__arm_preserves("zt0") means that the function has a "Shared-ZA" interface, which the SME ABI routines do not. I'm worried that we'd be abusing this attribute for a purpose that means something different, so I suggest introducing a new attribute for this instead.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 I've added a new attribute aarch64_zt0_undef, which does not result in a "Shared-ZA" interface. I've also added a few extra tests and limited this attribute to only apply to callsites (as I'm not sure it'd make sense if applied to an entire function).


Call->setCallingConv(
CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0);

Expand Down Expand Up @@ -119,7 +127,7 @@ bool SMEABI::updateNewStateFunctions(Module *M, Function *F,

// Create a call __arm_tpidr2_save, which commits the lazy save.
Builder.SetInsertPoint(&SaveBB->back());
emitTPIDR2Save(M, Builder);
emitTPIDR2Save(M, Builder, /*ZT0IsUndef=*/FnAttrs.isNewZT0());

// Enable pstate.za at the start of the function.
Builder.SetInsertPoint(&OrigBB->front());
Expand Down
3 changes: 1 addition & 2 deletions llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,7 @@ class SMEAttrs {
bool hasZT0State() const { return isNewZT0() || sharesZT0(); }
bool requiresPreservingZT0(const SMEAttrs &Callee) const {
return hasZT0State() && !Callee.sharesZT0() &&
!Callee.hasAgnosticZAInterface() &&
!(Callee.Bitmask & SME_ABI_Routine);
!Callee.hasAgnosticZAInterface();
}
bool requiresDisablingZABeforeCall(const SMEAttrs &Callee) const {
return hasZT0State() && !hasZAState() && Callee.hasPrivateZAInterface() &&
Expand Down
14 changes: 14 additions & 0 deletions llvm/test/CodeGen/AArch64/sme-new-zt0-function.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
; RUN: opt -S -mtriple=aarch64-linux-gnu -aarch64-sme-abi %s | FileCheck %s

declare void @callee();

define void @private_za() "aarch64_new_zt0" {
call void @callee()
ret void
}

; CHECK: call aarch64_sme_preservemost_from_x0 void @__arm_tpidr2_save() #[[TPIDR2_SAVE_CALL_ATTR:[0-9]+]]
; CHECK: declare void @__arm_tpidr2_save() #[[TPIDR2_SAVE_DECL_ATTR:[0-9]+]]

; CHECK: attributes #[[TPIDR2_SAVE_DECL_ATTR]] = { "aarch64_pstate_sm_compatible" }
; CHECK: attributes #[[TPIDR2_SAVE_CALL_ATTR]] = { "aarch64_preserves_zt0" }
41 changes: 37 additions & 4 deletions llvm/test/CodeGen/AArch64/sme-zt0-state.ll
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,39 @@ define void @zt0_new_caller_zt0_new_callee() "aarch64_new_zt0" nounwind {
ret void;
}

; Expect commit of lazy-save if ZA is dormant
; Expect smstart ZA & clear ZT0
; No spill & fill of ZT0 around __arm_tpidr2_save
; Expect spill & fill of ZT0 around __arm_sme_state call
; Before return, expect smstop ZA
define i64 @zt0_new_caller_abi_routine_callee() "aarch64_new_zt0" nounwind {
; CHECK-LABEL: zt0_new_caller_abi_routine_callee:
; CHECK: // %bb.0: // %prelude
; CHECK-NEXT: sub sp, sp, #80
; CHECK-NEXT: stp x30, x19, [sp, #64] // 16-byte Folded Spill
; CHECK-NEXT: mrs x8, TPIDR2_EL0
; CHECK-NEXT: cbz x8, .LBB7_2
; CHECK-NEXT: // %bb.1: // %save.za
; CHECK-NEXT: bl __arm_tpidr2_save
; CHECK-NEXT: msr TPIDR2_EL0, xzr
; CHECK-NEXT: .LBB7_2:
; CHECK-NEXT: smstart za
; CHECK-NEXT: zero { zt0 }
; CHECK-NEXT: mov x19, sp
; CHECK-NEXT: str zt0, [x19]
; CHECK-NEXT: bl __arm_sme_state
; CHECK-NEXT: ldr zt0, [x19]
; CHECK-NEXT: smstop za
; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
; CHECK-NEXT: add sp, sp, #80
; CHECK-NEXT: ret
%res = call {i64, i64} @__arm_sme_state()
%res.0 = extractvalue {i64, i64} %res, 0
ret i64 %res.0
}

declare {i64, i64} @__arm_sme_state()

;
; New-ZA Caller
;
Expand All @@ -179,11 +212,11 @@ define void @zt0_new_caller() "aarch64_new_zt0" nounwind {
; CHECK: // %bb.0: // %prelude
; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: mrs x8, TPIDR2_EL0
; CHECK-NEXT: cbz x8, .LBB7_2
; CHECK-NEXT: cbz x8, .LBB8_2
; CHECK-NEXT: // %bb.1: // %save.za
; CHECK-NEXT: bl __arm_tpidr2_save
; CHECK-NEXT: msr TPIDR2_EL0, xzr
; CHECK-NEXT: .LBB7_2:
; CHECK-NEXT: .LBB8_2:
; CHECK-NEXT: smstart za
; CHECK-NEXT: zero { zt0 }
; CHECK-NEXT: bl callee
Expand All @@ -202,11 +235,11 @@ define void @new_za_zt0_caller() "aarch64_new_za" "aarch64_new_zt0" nounwind {
; CHECK: // %bb.0: // %prelude
; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: mrs x8, TPIDR2_EL0
; CHECK-NEXT: cbz x8, .LBB8_2
; CHECK-NEXT: cbz x8, .LBB9_2
; CHECK-NEXT: // %bb.1: // %save.za
; CHECK-NEXT: bl __arm_tpidr2_save
; CHECK-NEXT: msr TPIDR2_EL0, xzr
; CHECK-NEXT: .LBB8_2:
; CHECK-NEXT: .LBB9_2:
; CHECK-NEXT: smstart za
; CHECK-NEXT: zero {za}
; CHECK-NEXT: zero { zt0 }
Expand Down
Loading