diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 8afe360d088bc..6060ab3f76d50 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -2859,6 +2859,9 @@ void Verifier::visitFunction(const Function &F) {
   Check(!Attrs.hasAttrSomewhere(Attribute::ElementType),
         "Attribute 'elementtype' can only be applied to a callsite.", &F);
 
+  Check(!Attrs.hasFnAttr("aarch64_zt0_undef"),
+        "Attribute 'aarch64_zt0_undef' can only be applied to a callsite.", &F);
+
   if (Attrs.hasFnAttr(Attribute::Naked))
     for (const Argument &Arg : F.args())
       Check(Arg.use_empty(), "cannot use argument of naked function", &Arg);
diff --git a/llvm/lib/Target/AArch64/SMEABIPass.cpp b/llvm/lib/Target/AArch64/SMEABIPass.cpp
index bb885d86392fe..b6685497e1fd1 100644
--- a/llvm/lib/Target/AArch64/SMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/SMEABIPass.cpp
@@ -54,14 +54,22 @@ FunctionPass *llvm::createSMEABIPass() { return new SMEABI(); }
 //===----------------------------------------------------------------------===//
 
 // Utility function to emit a call to __arm_tpidr2_save and clear TPIDR2_EL0.
-void emitTPIDR2Save(Module *M, IRBuilder<> &Builder) {
+void emitTPIDR2Save(Module *M, IRBuilder<> &Builder, bool ZT0IsUndef = false) {
+  auto &Ctx = M->getContext();
   auto *TPIDR2SaveTy =
       FunctionType::get(Builder.getVoidTy(), {}, /*IsVarArgs=*/false);
-  auto Attrs = AttributeList().addFnAttribute(M->getContext(),
-                                              "aarch64_pstate_sm_compatible");
+  auto Attrs =
+      AttributeList().addFnAttribute(Ctx, "aarch64_pstate_sm_compatible");
   FunctionCallee Callee =
       M->getOrInsertFunction("__arm_tpidr2_save", TPIDR2SaveTy, Attrs);
   CallInst *Call = Builder.CreateCall(Callee);
+
+  // If ZT0 is undefined (i.e. we're at the entry of a "new_zt0" function), mark
+  // that on the __arm_tpidr2_save call. This prevents an unnecessary spill of
+  // ZT0 that can occur before ZA is enabled.
+  if (ZT0IsUndef)
+    Call->addFnAttr(Attribute::get(Ctx, "aarch64_zt0_undef"));
+
   Call->setCallingConv(
       CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0);
 
@@ -119,7 +127,7 @@ bool SMEABI::updateNewStateFunctions(Module *M, Function *F,
 
     // Create a call __arm_tpidr2_save, which commits the lazy save.
     Builder.SetInsertPoint(&SaveBB->back());
-    emitTPIDR2Save(M, Builder);
+    emitTPIDR2Save(M, Builder, /*ZT0IsUndef=*/FnAttrs.isNewZT0());
 
     // Enable pstate.za at the start of the function.
     Builder.SetInsertPoint(&OrigBB->front());
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
index bf16acd7f8f7e..76d2ac6a601e5 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -75,6 +75,8 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
     Bitmask |= SM_Body;
   if (Attrs.hasFnAttr("aarch64_za_state_agnostic"))
     Bitmask |= ZA_State_Agnostic;
+  if (Attrs.hasFnAttr("aarch64_zt0_undef"))
+    Bitmask |= ZT0_Undef;
   if (Attrs.hasFnAttr("aarch64_in_za"))
     Bitmask |= encodeZAState(StateValue::In);
   if (Attrs.hasFnAttr("aarch64_out_za"))
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index a3ebf764a6e0c..1691d4fec8b68 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -43,9 +43,10 @@ class SMEAttrs {
     SM_Body = 1 << 2,         // aarch64_pstate_sm_body
     SME_ABI_Routine = 1 << 3, // Used for SME ABI routines to avoid lazy saves
     ZA_State_Agnostic = 1 << 4,
-    ZA_Shift = 5,
+    ZT0_Undef = 1 << 5,       // Used to mark ZT0 as undef to avoid spills
+    ZA_Shift = 6,
     ZA_Mask = 0b111 << ZA_Shift,
-    ZT0_Shift = 8,
+    ZT0_Shift = 9,
     ZT0_Mask = 0b111 << ZT0_Shift
   };
 
@@ -125,6 +126,7 @@ class SMEAttrs {
   bool isPreservesZT0() const {
     return decodeZT0State(Bitmask) == StateValue::Preserved;
   }
+  bool isUndefZT0() const { return Bitmask & ZT0_Undef; }
   bool sharesZT0() const {
     StateValue State = decodeZT0State(Bitmask);
     return State == StateValue::In || State == StateValue::Out ||
@@ -132,9 +134,8 @@ class SMEAttrs {
   }
   bool hasZT0State() const { return isNewZT0() || sharesZT0(); }
   bool requiresPreservingZT0(const SMEAttrs &Callee) const {
-    return hasZT0State() && !Callee.sharesZT0() &&
-           !Callee.hasAgnosticZAInterface() &&
-           !(Callee.Bitmask & SME_ABI_Routine);
+    return hasZT0State() && !Callee.isUndefZT0() && !Callee.sharesZT0() &&
+           !Callee.hasAgnosticZAInterface();
   }
   bool requiresDisablingZABeforeCall(const SMEAttrs &Callee) const {
     return hasZT0State() && !hasZAState() && Callee.hasPrivateZAInterface() &&
diff --git a/llvm/test/CodeGen/AArch64/sme-new-zt0-function.ll b/llvm/test/CodeGen/AArch64/sme-new-zt0-function.ll
new file mode 100644
index 0000000000000..94968ab4fd9ac
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-new-zt0-function.ll
@@ -0,0 +1,14 @@
+; RUN: opt -S -mtriple=aarch64-linux-gnu -aarch64-sme-abi %s | FileCheck %s
+
+declare void @callee();
+
+define void @private_za() "aarch64_new_zt0" {
+  call void @callee()
+  ret void
+}
+
+; CHECK: call aarch64_sme_preservemost_from_x0 void @__arm_tpidr2_save() #[[TPIDR2_SAVE_CALL_ATTR:[0-9]+]]
+; CHECK: declare void @__arm_tpidr2_save() #[[TPIDR2_SAVE_DECL_ATTR:[0-9]+]]
+
+; CHECK: attributes #[[TPIDR2_SAVE_DECL_ATTR]] = { "aarch64_pstate_sm_compatible" }
+; CHECK: attributes #[[TPIDR2_SAVE_CALL_ATTR]] = { "aarch64_zt0_undef" }
diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
index 500fff4eb20db..7361e850d713e 100644
--- a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
+++ b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
@@ -167,6 +167,39 @@ define void @zt0_new_caller_zt0_new_callee() "aarch64_new_zt0" nounwind {
   ret void;
 }
 
+; Expect commit of lazy-save if ZA is dormant
+; Expect smstart ZA & clear ZT0
+; No spill & fill of ZT0 around __arm_tpidr2_save
+; Expect spill & fill of ZT0 around __arm_sme_state call
+; Before return, expect smstop ZA
+define i64 @zt0_new_caller_abi_routine_callee() "aarch64_new_zt0" nounwind {
+; CHECK-LABEL: zt0_new_caller_abi_routine_callee:
+; CHECK:       // %bb.0: // %prelude
+; CHECK-NEXT:    sub sp, sp, #80
+; CHECK-NEXT:    stp x30, x19, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    mrs x8, TPIDR2_EL0
+; CHECK-NEXT:    cbz x8, .LBB7_2
+; CHECK-NEXT:  // %bb.1: // %save.za
+; CHECK-NEXT:    bl __arm_tpidr2_save
+; CHECK-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEXT:  .LBB7_2:
+; CHECK-NEXT:    smstart za
+; CHECK-NEXT:    zero { zt0 }
+; CHECK-NEXT:    mov x19, sp
+; CHECK-NEXT:    str zt0, [x19]
+; CHECK-NEXT:    bl __arm_sme_state
+; CHECK-NEXT:    ldr zt0, [x19]
+; CHECK-NEXT:    smstop za
+; CHECK-NEXT:    ldp x30, x19, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT:    add sp, sp, #80
+; CHECK-NEXT:    ret
+  %res = call {i64, i64} @__arm_sme_state()
+  %res.0 = extractvalue {i64, i64} %res, 0
+  ret i64 %res.0
+}
+
+declare {i64, i64} @__arm_sme_state()
+
 ;
 ; New-ZA Caller
 ;
@@ -179,11 +212,11 @@ define void @zt0_new_caller() "aarch64_new_zt0" nounwind {
 ; CHECK:       // %bb.0: // %prelude
 ; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
 ; CHECK-NEXT:    mrs x8, TPIDR2_EL0
-; CHECK-NEXT:    cbz x8, .LBB7_2
+; CHECK-NEXT:    cbz x8, .LBB8_2
 ; CHECK-NEXT:  // %bb.1: // %save.za
 ; CHECK-NEXT:    bl __arm_tpidr2_save
 ; CHECK-NEXT:    msr TPIDR2_EL0, xzr
-; CHECK-NEXT:  .LBB7_2:
+; CHECK-NEXT:  .LBB8_2:
 ; CHECK-NEXT:    smstart za
 ; CHECK-NEXT:    zero { zt0 }
 ; CHECK-NEXT:    bl callee
@@ -202,11 +235,11 @@ define void @new_za_zt0_caller() "aarch64_new_za" "aarch64_new_zt0" nounwind {
 ; CHECK:       // %bb.0: // %prelude
 ; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
 ; CHECK-NEXT:    mrs x8, TPIDR2_EL0
-; CHECK-NEXT:    cbz x8, .LBB8_2
+; CHECK-NEXT:    cbz x8, .LBB9_2
 ; CHECK-NEXT:  // %bb.1: // %save.za
 ; CHECK-NEXT:    bl __arm_tpidr2_save
 ; CHECK-NEXT:    msr TPIDR2_EL0, xzr
-; CHECK-NEXT:  .LBB8_2:
+; CHECK-NEXT:  .LBB9_2:
 ; CHECK-NEXT:    smstart za
 ; CHECK-NEXT:    zero {za}
 ; CHECK-NEXT:    zero { zt0 }
diff --git a/llvm/test/Verifier/sme-attributes.ll b/llvm/test/Verifier/sme-attributes.ll
index 4bf5e813daf2f..0ae2b9fd91f52 100644
--- a/llvm/test/Verifier/sme-attributes.ll
+++ b/llvm/test/Verifier/sme-attributes.ll
@@ -68,3 +68,6 @@ declare void @zt0_inout_out() "aarch64_inout_zt0" "aarch64_out_zt0";
 
 declare void @zt0_inout_agnostic() "aarch64_inout_zt0" "aarch64_za_state_agnostic";
 ; CHECK: Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', 'aarch64_inout_zt0', 'aarch64_preserves_zt0' and 'aarch64_za_state_agnostic' are mutually exclusive
+
+declare void @zt0_undef_function() "aarch64_zt0_undef";
+; CHECK: Attribute 'aarch64_zt0_undef' can only be applied to a callsite.
diff --git a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
index 3af5e24168c8c..f8c77fcba19cf 100644
--- a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
+++ b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
@@ -1,6 +1,7 @@
 #include "Utils/AArch64SMEAttributes.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/Function.h"
+#include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/SourceMgr.h"
 
@@ -69,6 +70,15 @@ TEST(SMEAttributes, Constructors) {
   ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_new_zt0\"")
                       ->getFunction("foo"))
                   .isNewZT0());
+  ASSERT_TRUE(
+      SA(cast<CallBase>((parseIR("declare void @callee()\n"
+                                 "define void @foo() {"
+                                 "call void @callee() \"aarch64_zt0_undef\"\n"
+                                 "ret void\n}")
+                             ->getFunction("foo")
+                             ->begin()
+                             ->front())))
+          .isUndefZT0());
 
   // Invalid combinations.
   EXPECT_DEBUG_DEATH(SA(SA::SM_Enabled | SA::SM_Compatible),
@@ -215,6 +225,18 @@ TEST(SMEAttributes, Basics) {
   ASSERT_FALSE(ZT0_New.hasSharedZAInterface());
   ASSERT_TRUE(ZT0_New.hasPrivateZAInterface());
 
+  SA ZT0_Undef = SA(SA::ZT0_Undef | SA::encodeZT0State(SA::StateValue::New));
+  ASSERT_TRUE(ZT0_Undef.isNewZT0());
+  ASSERT_FALSE(ZT0_Undef.isInZT0());
+  ASSERT_FALSE(ZT0_Undef.isOutZT0());
+  ASSERT_FALSE(ZT0_Undef.isInOutZT0());
+  ASSERT_FALSE(ZT0_Undef.isPreservesZT0());
+  ASSERT_FALSE(ZT0_Undef.sharesZT0());
+  ASSERT_TRUE(ZT0_Undef.hasZT0State());
+  ASSERT_FALSE(ZT0_Undef.hasSharedZAInterface());
+  ASSERT_TRUE(ZT0_Undef.hasPrivateZAInterface());
+  ASSERT_TRUE(ZT0_Undef.isUndefZT0());
+
   ASSERT_FALSE(SA(SA::Normal).isInZT0());
   ASSERT_FALSE(SA(SA::Normal).isOutZT0());
   ASSERT_FALSE(SA(SA::Normal).isInOutZT0());
@@ -285,6 +307,7 @@ TEST(SMEAttributes, Transitions) {
   SA ZT0_Shared = SA(SA::encodeZT0State(SA::StateValue::In));
   SA ZA_ZT0_Shared = SA(SA::encodeZAState(SA::StateValue::In) |
                         SA::encodeZT0State(SA::StateValue::In));
+  SA Undef_ZT0 = SA(SA::ZT0_Undef);
 
   // Shared ZA -> Private ZA Interface
   ASSERT_FALSE(ZA_Shared.requiresDisablingZABeforeCall(Private_ZA));
@@ -295,6 +318,13 @@ TEST(SMEAttributes, Transitions) {
   ASSERT_TRUE(ZT0_Shared.requiresPreservingZT0(Private_ZA));
   ASSERT_TRUE(ZT0_Shared.requiresEnablingZAAfterCall(Private_ZA));
 
+  // Shared Undef ZT0 -> Private ZA Interface
+  // Note: "Undef ZT0" is a callsite attribute that means ZT0 is undefined at
+  // the point of the call.
+  ASSERT_TRUE(ZT0_Shared.requiresDisablingZABeforeCall(Undef_ZT0));
+  ASSERT_FALSE(ZT0_Shared.requiresPreservingZT0(Undef_ZT0));
+  ASSERT_TRUE(ZT0_Shared.requiresEnablingZAAfterCall(Undef_ZT0));
+
   // Shared ZA & ZT0 -> Private ZA Interface
   ASSERT_FALSE(ZA_ZT0_Shared.requiresDisablingZABeforeCall(Private_ZA));
   ASSERT_TRUE(ZA_ZT0_Shared.requiresPreservingZT0(Private_ZA));