
Commit dd73666

[SME] Stop RA from coalescing COPY instructions that cross smstart/smstop. (llvm#78294)
This patch introduces a 'COALESCER_BARRIER' pseudo node that expands to a
'nop', but which stops the register allocator from coalescing a COPY node
when its use/def crosses an SMSTART or SMSTOP instruction. For example:

    %0:fpr64 = COPY killed $d0
    undef %2.dsub:zpr = COPY %0     // <- Do not coalesce this COPY
    ADJCALLSTACKDOWN 0, 0
    MSRpstatesvcrImm1 1, 0, csr_aarch64_smstartstop, implicit-def dead $d0
    $d0 = COPY killed %0
    BL @use_f64, csr_aarch64_aapcs

If the COPY were coalesced, that would lead to:

    $d0 = COPY killed %0

being replaced by:

    $d0 = COPY killed %2.dsub

which means the whole ZPR register would be live up to the call, causing the
MSRpstatesvcrImm1 (smstop) to spill/reload the ZPR register:

    str q0, [sp] // 16-byte Folded Spill
    smstop sm
    ldr z0, [sp] // 16-byte Folded Reload
    bl use_f64

This would be incorrect for two reasons:

1. The program may load more data than it has allocated.
2. If there are other SVE objects on the stack, the compiler might use the
   'mul vl' addressing modes to access the spill location.

By disabling the coalescing, we get the desired result:

    str d0, [sp, #8] // 8-byte Folded Spill
    smstop sm
    ldr d0, [sp, #8] // 8-byte Folded Reload
    bl use_f64
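For context, a minimal IR-level reproducer in the spirit of the tests updated below (a sketch, not part of the patch; the function mirrors the existing frem_call_sm test in llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll, where frem lowers to a call to fmodf): the streaming caller has to wrap the non-streaming libcall in an smstop/smstart pair, so the float arguments are live across the streaming-mode change and must be spilled/reloaded without touching the Z-registers.

    ; Sketch: a streaming function whose frem lowers to a call to the
    ; non-streaming libcall fmodf, forcing smstop/smstart around the call
    ; while %a and %b are live.
    define float @frem_call_sm(float %a, float %b) "aarch64_pstate_sm_enabled" nounwind {
      %res = frem float %a, %b
      ret float %res
    }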
1 parent d439f36 commit dd73666

11 files changed, +1769 −29 lines changed

llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp (+6)

@@ -1544,6 +1544,12 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
     NextMBBI = MBB.end(); // The NextMBBI iterator is invalidated.
     return true;
   }
+  case AArch64::COALESCER_BARRIER_FPR16:
+  case AArch64::COALESCER_BARRIER_FPR32:
+  case AArch64::COALESCER_BARRIER_FPR64:
+  case AArch64::COALESCER_BARRIER_FPR128:
+    MI.eraseFromParent();
+    return true;
   case AArch64::LD1B_2Z_IMM_PSEUDO:
     return expandMultiVecPseudo(
         MBB, MBBI, AArch64::ZPR2RegClass, AArch64::ZPR2StridedRegClass,

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+20 −4)

@@ -2375,6 +2375,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
   switch ((AArch64ISD::NodeType)Opcode) {
   case AArch64ISD::FIRST_NUMBER:
     break;
+    MAKE_CASE(AArch64ISD::COALESCER_BARRIER)
     MAKE_CASE(AArch64ISD::SMSTART)
     MAKE_CASE(AArch64ISD::SMSTOP)
     MAKE_CASE(AArch64ISD::RESTORE_ZA)
@@ -7154,13 +7155,18 @@ void AArch64TargetLowering::saveVarArgRegisters(CCState &CCInfo,
   }
 }
 
+static bool isPassedInFPR(EVT VT) {
+  return VT.isFixedLengthVector() ||
+         (VT.isFloatingPoint() && !VT.isScalableVector());
+}
+
 /// LowerCallResult - Lower the result values of a call into the
 /// appropriate copies out of appropriate physical registers.
 SDValue AArch64TargetLowering::LowerCallResult(
     SDValue Chain, SDValue InGlue, CallingConv::ID CallConv, bool isVarArg,
     const SmallVectorImpl<CCValAssign> &RVLocs, const SDLoc &DL,
     SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals, bool isThisReturn,
-    SDValue ThisVal) const {
+    SDValue ThisVal, bool RequiresSMChange) const {
   DenseMap<unsigned, SDValue> CopiedRegs;
   // Copy all of the result registers out of their specified physreg.
   for (unsigned i = 0; i != RVLocs.size(); ++i) {
@@ -7205,6 +7211,10 @@ SDValue AArch64TargetLowering::LowerCallResult(
       break;
     }
 
+    if (RequiresSMChange && isPassedInFPR(VA.getValVT()))
+      Val = DAG.getNode(AArch64ISD::COALESCER_BARRIER, DL, Val.getValueType(),
+                        Val);
+
     InVals.push_back(Val);
   }
 
@@ -7915,6 +7925,12 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
         return ArgReg.Reg == VA.getLocReg();
       });
     } else {
+      // Add an extra level of indirection for streaming mode changes by
+      // using a pseudo copy node that cannot be rematerialised between a
+      // smstart/smstop and the call by the simple register coalescer.
+      if (RequiresSMChange && isPassedInFPR(Arg.getValueType()))
+        Arg = DAG.getNode(AArch64ISD::COALESCER_BARRIER, DL,
+                          Arg.getValueType(), Arg);
       RegsToPass.emplace_back(VA.getLocReg(), Arg);
       RegsUsed.insert(VA.getLocReg());
       const TargetOptions &Options = DAG.getTarget().Options;
@@ -8159,9 +8175,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
 
   // Handle result values, copying them out of physregs into vregs that we
   // return.
-  SDValue Result = LowerCallResult(Chain, InGlue, CallConv, IsVarArg, RVLocs,
-                                   DL, DAG, InVals, IsThisReturn,
-                                   IsThisReturn ? OutVals[0] : SDValue());
+  SDValue Result = LowerCallResult(
+      Chain, InGlue, CallConv, IsVarArg, RVLocs, DL, DAG, InVals, IsThisReturn,
+      IsThisReturn ? OutVals[0] : SDValue(), RequiresSMChange);
 
   if (!Ins.empty())
     InGlue = Result.getValue(Result->getNumValues() - 1);

llvm/lib/Target/AArch64/AArch64ISelLowering.h (+3 −1)

@@ -58,6 +58,8 @@ enum NodeType : unsigned {
 
   CALL_BTI, // Function call followed by a BTI instruction.
 
+  COALESCER_BARRIER,
+
   SMSTART,
   SMSTOP,
   RESTORE_ZA,
@@ -1026,7 +1028,7 @@ class AArch64TargetLowering : public TargetLowering {
                           const SmallVectorImpl<CCValAssign> &RVLocs,
                           const SDLoc &DL, SelectionDAG &DAG,
                           SmallVectorImpl<SDValue> &InVals, bool isThisReturn,
-                          SDValue ThisVal) const;
+                          SDValue ThisVal, bool RequiresSMChange) const;
 
   SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const;

llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp (+35)

@@ -1015,6 +1015,8 @@ bool AArch64RegisterInfo::shouldCoalesce(
     MachineInstr *MI, const TargetRegisterClass *SrcRC, unsigned SubReg,
     const TargetRegisterClass *DstRC, unsigned DstSubReg,
     const TargetRegisterClass *NewRC, LiveIntervals &LIS) const {
+  MachineRegisterInfo &MRI = MI->getMF()->getRegInfo();
+
   if (MI->isCopy() &&
       ((DstRC->getID() == AArch64::GPR64RegClassID) ||
        (DstRC->getID() == AArch64::GPR64commonRegClassID)) &&
@@ -1023,5 +1025,38 @@ bool AArch64RegisterInfo::shouldCoalesce(
     // which implements a 32 to 64 bit zero extension
     // which relies on the upper 32 bits being zeroed.
     return false;
+
+  auto IsCoalescerBarrier = [](const MachineInstr &MI) {
+    switch (MI.getOpcode()) {
+    case AArch64::COALESCER_BARRIER_FPR16:
+    case AArch64::COALESCER_BARRIER_FPR32:
+    case AArch64::COALESCER_BARRIER_FPR64:
+    case AArch64::COALESCER_BARRIER_FPR128:
+      return true;
+    default:
+      return false;
+    }
+  };
+
+  // For calls that temporarily have to toggle streaming mode as part of the
+  // call-sequence, we need to be more careful when coalescing copy instructions
+  // so that we don't end up coalescing the NEON/FP result or argument register
+  // with a whole Z-register, such that after coalescing the register allocator
+  // will try to spill/reload the entire Z register.
+  //
+  // We do this by checking if the node has any defs/uses that are
+  // COALESCER_BARRIER pseudos. These are 'nops' in practice, but they exist to
+  // instruct the coalescer to avoid coalescing the copy.
+  if (MI->isCopy() && SubReg != DstSubReg &&
+      (AArch64::ZPRRegClass.hasSubClassEq(DstRC) ||
+       AArch64::ZPRRegClass.hasSubClassEq(SrcRC))) {
+    unsigned SrcReg = MI->getOperand(1).getReg();
+    if (any_of(MRI.def_instructions(SrcReg), IsCoalescerBarrier))
+      return false;
+    unsigned DstReg = MI->getOperand(0).getReg();
+    if (any_of(MRI.use_nodbg_instructions(DstReg), IsCoalescerBarrier))
+      return false;
+  }
+
   return true;
 }
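To illustrate when the new shouldCoalesce() bail-out fires, here is a simplified MIR sketch (virtual register numbers and the exact placement of the barrier are illustrative, adapted from the example in the commit message): the source of the subregister COPY into the ZPR vreg is defined by a COALESCER_BARRIER pseudo, so the hook returns false, the COPY is left alone, and only the 64-bit d-register value stays live across the smstop.

    %0:fpr64 = COPY $d0
    %1:fpr64 = COALESCER_BARRIER_FPR64 %0
    undef %2.dsub:zpr = COPY %1      // shouldCoalesce() refuses to merge %1 into %2
    ADJCALLSTACKDOWN 0, 0
    MSRpstatesvcrImm1 1, 0, csr_aarch64_smstartstop, implicit-def dead $d0
    $d0 = COPY %1
    BL @use_f64, csr_aarch64_aapcs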

llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td (+22)

@@ -28,6 +28,8 @@ def AArch64_restore_zt : SDNode<"AArch64ISD::RESTORE_ZT", SDTypeProfile<0, 2,
 def AArch64_save_zt : SDNode<"AArch64ISD::SAVE_ZT", SDTypeProfile<0, 2,
                                                     [SDTCisInt<0>, SDTCisPtrTy<1>]>,
                              [SDNPHasChain, SDNPSideEffect, SDNPMayStore]>;
+def AArch64CoalescerBarrier
+    : SDNode<"AArch64ISD::COALESCER_BARRIER", SDTypeProfile<1, 1, []>, []>;
 
 //===----------------------------------------------------------------------===//
 // Instruction naming conventions.
@@ -189,6 +191,26 @@ def : Pat<(int_aarch64_sme_set_tpidr2 i64:$val),
           (MSR 0xde85, GPR64:$val)>;
 def : Pat<(i64 (int_aarch64_sme_get_tpidr2)),
           (MRS 0xde85)>;
+
+multiclass CoalescerBarrierPseudo<RegisterClass rc, list<ValueType> vts> {
+  def NAME : Pseudo<(outs rc:$dst), (ins rc:$src), []>, Sched<[]> {
+    let Constraints = "$dst = $src";
+  }
+  foreach vt = vts in {
+    def : Pat<(vt (AArch64CoalescerBarrier (vt rc:$src))),
+              (!cast<Instruction>(NAME) rc:$src)>;
+  }
+}
+
+multiclass CoalescerBarriers {
+  defm _FPR16  : CoalescerBarrierPseudo<FPR16, [bf16, f16]>;
+  defm _FPR32  : CoalescerBarrierPseudo<FPR32, [f32]>;
+  defm _FPR64  : CoalescerBarrierPseudo<FPR64, [f64, v8i8, v4i16, v2i32, v1i64, v4f16, v2f32, v1f64, v4bf16]>;
+  defm _FPR128 : CoalescerBarrierPseudo<FPR128, [f128, v16i8, v8i16, v4i32, v2i64, v8f16, v4f32, v2f64, v8bf16]>;
+}
+
+defm COALESCER_BARRIER : CoalescerBarriers;
+
 } // End let Predicates = [HasSME]
 
 // Pseudo to match to smstart/smstop. This expands:

llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll (+12 −8)

@@ -23,9 +23,9 @@ define double @nonstreaming_caller_streaming_callee(double %x) nounwind noinline
 ; CHECK-FISEL-NEXT:    bl streaming_callee
 ; CHECK-FISEL-NEXT:    str d0, [sp, #8] // 8-byte Folded Spill
 ; CHECK-FISEL-NEXT:    smstop sm
+; CHECK-FISEL-NEXT:    ldr d1, [sp, #8] // 8-byte Folded Reload
 ; CHECK-FISEL-NEXT:    adrp x8, .LCPI0_0
 ; CHECK-FISEL-NEXT:    ldr d0, [x8, :lo12:.LCPI0_0]
-; CHECK-FISEL-NEXT:    ldr d1, [sp, #8] // 8-byte Folded Reload
 ; CHECK-FISEL-NEXT:    fadd d0, d1, d0
 ; CHECK-FISEL-NEXT:    ldr x30, [sp, #80] // 8-byte Folded Reload
 ; CHECK-FISEL-NEXT:    ldp d9, d8, [sp, #64] // 16-byte Folded Reload
@@ -49,9 +49,9 @@ define double @nonstreaming_caller_streaming_callee(double %x) nounwind noinline
 ; CHECK-GISEL-NEXT:    bl streaming_callee
 ; CHECK-GISEL-NEXT:    str d0, [sp, #8] // 8-byte Folded Spill
 ; CHECK-GISEL-NEXT:    smstop sm
+; CHECK-GISEL-NEXT:    ldr d1, [sp, #8] // 8-byte Folded Reload
 ; CHECK-GISEL-NEXT:    mov x8, #4631107791820423168 // =0x4045000000000000
 ; CHECK-GISEL-NEXT:    fmov d0, x8
-; CHECK-GISEL-NEXT:    ldr d1, [sp, #8] // 8-byte Folded Reload
 ; CHECK-GISEL-NEXT:    fadd d0, d1, d0
 ; CHECK-GISEL-NEXT:    ldr x30, [sp, #80] // 8-byte Folded Reload
 ; CHECK-GISEL-NEXT:    ldp d9, d8, [sp, #64] // 16-byte Folded Reload
@@ -82,9 +82,9 @@ define double @streaming_caller_nonstreaming_callee(double %x) nounwind noinline
 ; CHECK-COMMON-NEXT:    bl normal_callee
 ; CHECK-COMMON-NEXT:    str d0, [sp, #8] // 8-byte Folded Spill
 ; CHECK-COMMON-NEXT:    smstart sm
+; CHECK-COMMON-NEXT:    ldr d1, [sp, #8] // 8-byte Folded Reload
 ; CHECK-COMMON-NEXT:    mov x8, #4631107791820423168 // =0x4045000000000000
 ; CHECK-COMMON-NEXT:    fmov d0, x8
-; CHECK-COMMON-NEXT:    ldr d1, [sp, #8] // 8-byte Folded Reload
 ; CHECK-COMMON-NEXT:    fadd d0, d1, d0
 ; CHECK-COMMON-NEXT:    ldr x30, [sp, #80] // 8-byte Folded Reload
 ; CHECK-COMMON-NEXT:    ldp d9, d8, [sp, #64] // 16-byte Folded Reload
@@ -110,14 +110,16 @@ define double @locally_streaming_caller_normal_callee(double %x) nounwind noinli
 ; CHECK-COMMON-NEXT:    str x30, [sp, #96] // 8-byte Folded Spill
 ; CHECK-COMMON-NEXT:    str d0, [sp, #24] // 8-byte Folded Spill
 ; CHECK-COMMON-NEXT:    smstart sm
+; CHECK-COMMON-NEXT:    ldr d0, [sp, #24] // 8-byte Folded Reload
+; CHECK-COMMON-NEXT:    str d0, [sp, #24] // 8-byte Folded Spill
 ; CHECK-COMMON-NEXT:    smstop sm
 ; CHECK-COMMON-NEXT:    ldr d0, [sp, #24] // 8-byte Folded Reload
 ; CHECK-COMMON-NEXT:    bl normal_callee
 ; CHECK-COMMON-NEXT:    str d0, [sp, #16] // 8-byte Folded Spill
 ; CHECK-COMMON-NEXT:    smstart sm
+; CHECK-COMMON-NEXT:    ldr d1, [sp, #16] // 8-byte Folded Reload
 ; CHECK-COMMON-NEXT:    mov x8, #4631107791820423168 // =0x4045000000000000
 ; CHECK-COMMON-NEXT:    fmov d0, x8
-; CHECK-COMMON-NEXT:    ldr d1, [sp, #16] // 8-byte Folded Reload
 ; CHECK-COMMON-NEXT:    fadd d0, d1, d0
 ; CHECK-COMMON-NEXT:    str d0, [sp, #8] // 8-byte Folded Spill
 ; CHECK-COMMON-NEXT:    smstop sm
@@ -329,9 +331,9 @@ define fp128 @f128_call_sm(fp128 %a, fp128 %b) "aarch64_pstate_sm_enabled" nounw
 ; CHECK-COMMON-NEXT:    stp d11, d10, [sp, #64] // 16-byte Folded Spill
 ; CHECK-COMMON-NEXT:    stp d9, d8, [sp, #80] // 16-byte Folded Spill
 ; CHECK-COMMON-NEXT:    str x30, [sp, #96] // 8-byte Folded Spill
-; CHECK-COMMON-NEXT:    stp q0, q1, [sp] // 32-byte Folded Spill
+; CHECK-COMMON-NEXT:    stp q1, q0, [sp] // 32-byte Folded Spill
 ; CHECK-COMMON-NEXT:    smstop sm
-; CHECK-COMMON-NEXT:    ldp q0, q1, [sp] // 32-byte Folded Reload
+; CHECK-COMMON-NEXT:    ldp q1, q0, [sp] // 32-byte Folded Reload
 ; CHECK-COMMON-NEXT:    bl __addtf3
 ; CHECK-COMMON-NEXT:    str q0, [sp, #16] // 16-byte Folded Spill
 ; CHECK-COMMON-NEXT:    smstart sm
@@ -390,9 +392,9 @@ define float @frem_call_sm(float %a, float %b) "aarch64_pstate_sm_enabled" nounw
 ; CHECK-COMMON-NEXT:    stp d11, d10, [sp, #48] // 16-byte Folded Spill
 ; CHECK-COMMON-NEXT:    stp d9, d8, [sp, #64] // 16-byte Folded Spill
 ; CHECK-COMMON-NEXT:    str x30, [sp, #80] // 8-byte Folded Spill
-; CHECK-COMMON-NEXT:    stp s0, s1, [sp, #8] // 8-byte Folded Spill
+; CHECK-COMMON-NEXT:    stp s1, s0, [sp, #8] // 8-byte Folded Spill
 ; CHECK-COMMON-NEXT:    smstop sm
-; CHECK-COMMON-NEXT:    ldp s0, s1, [sp, #8] // 8-byte Folded Reload
+; CHECK-COMMON-NEXT:    ldp s1, s0, [sp, #8] // 8-byte Folded Reload
 ; CHECK-COMMON-NEXT:    bl fmodf
 ; CHECK-COMMON-NEXT:    str s0, [sp, #12] // 4-byte Folded Spill
 ; CHECK-COMMON-NEXT:    smstart sm
@@ -420,7 +422,9 @@ define float @frem_call_sm_compat(float %a, float %b) "aarch64_pstate_sm_compati
 ; CHECK-COMMON-NEXT:    stp x30, x19, [sp, #80] // 16-byte Folded Spill
 ; CHECK-COMMON-NEXT:    stp s0, s1, [sp, #8] // 8-byte Folded Spill
 ; CHECK-COMMON-NEXT:    bl __arm_sme_state
+; CHECK-COMMON-NEXT:    ldp s2, s0, [sp, #8] // 8-byte Folded Reload
 ; CHECK-COMMON-NEXT:    and x19, x0, #0x1
+; CHECK-COMMON-NEXT:    stp s2, s0, [sp, #8] // 8-byte Folded Spill
 ; CHECK-COMMON-NEXT:    tbz w19, #0, .LBB12_2
 ; CHECK-COMMON-NEXT:  // %bb.1:
 ; CHECK-COMMON-NEXT:    smstop sm
