Skip to content

Commit 13d04fa

Browse files
committed
[DAG] Add legalization handling for ABDS/ABDU (#92576) (REAPPLIED)
Always match ABD patterns pre-legalization, and use TargetLowering::expandABD to expand them again during legalization.

abdu(lhs, rhs) -> sub(xor(sub(lhs, rhs), usub_overflow(lhs, rhs)), usub_overflow(lhs, rhs))

Alive2: https://alive2.llvm.org/ce/z/dVdMyv

REAPPLIED: Fix a regression where the "abs(ext(x) - ext(y)) -> zext(abd(x, y))" fold failed after type legalization.
1 parent c4e7728 commit 13d04fa

29 files changed

+3206
-4084
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4091,13 +4091,13 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
40914091
}
40924092

40934093
// smax(a,b) - smin(a,b) --> abds(a,b)
4094-
if (hasOperation(ISD::ABDS, VT) &&
4094+
if ((!LegalOperations || hasOperation(ISD::ABDS, VT)) &&
40954095
sd_match(N0, m_SMax(m_Value(A), m_Value(B))) &&
40964096
sd_match(N1, m_SMin(m_Specific(A), m_Specific(B))))
40974097
return DAG.getNode(ISD::ABDS, DL, VT, A, B);
40984098

40994099
// umax(a,b) - umin(a,b) --> abdu(a,b)
4100-
if (hasOperation(ISD::ABDU, VT) &&
4100+
if ((!LegalOperations || hasOperation(ISD::ABDU, VT)) &&
41014101
sd_match(N0, m_UMax(m_Value(A), m_Value(B))) &&
41024102
sd_match(N1, m_UMin(m_Specific(A), m_Specific(B))))
41034103
return DAG.getNode(ISD::ABDU, DL, VT, A, B);
@@ -5263,6 +5263,10 @@ SDValue DAGCombiner::visitABD(SDNode *N) {
52635263
if (N0.isUndef() || N1.isUndef())
52645264
return DAG.getConstant(0, DL, VT);
52655265

5266+
// fold (abd x, x) -> 0
5267+
if (N0 == N1)
5268+
return DAG.getConstant(0, DL, VT);
5269+
52665270
SDValue X;
52675271

52685272
// fold (abds x, 0) -> abs x
@@ -10924,6 +10928,7 @@ SDValue DAGCombiner::foldABSToABD(SDNode *N, const SDLoc &DL) {
1092410928
(Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND &&
1092510929
Opc0 != ISD::SIGN_EXTEND_INREG)) {
1092610930
// fold (abs (sub nsw x, y)) -> abds(x, y)
10931+
// Don't fold this for unsupported types as we lose the NSW handling.
1092710932
if (AbsOp1->getFlags().hasNoSignedWrap() && hasOperation(ISD::ABDS, VT) &&
1092810933
TLI.preferABDSToABSWithNSW(VT)) {
1092910934
SDValue ABD = DAG.getNode(ISD::ABDS, DL, VT, Op0, Op1);
@@ -10946,7 +10951,8 @@ SDValue DAGCombiner::foldABSToABD(SDNode *N, const SDLoc &DL) {
1094610951
// fold abs(zext(x) - zext(y)) -> zext(abdu(x, y))
1094710952
EVT MaxVT = VT0.bitsGT(VT1) ? VT0 : VT1;
1094810953
if ((VT0 == MaxVT || Op0->hasOneUse()) &&
10949-
(VT1 == MaxVT || Op1->hasOneUse()) && hasOperation(ABDOpcode, MaxVT)) {
10954+
(VT1 == MaxVT || Op1->hasOneUse()) &&
10955+
(!LegalTypes || hasOperation(ABDOpcode, MaxVT))) {
1095010956
SDValue ABD = DAG.getNode(ABDOpcode, DL, MaxVT,
1095110957
DAG.getNode(ISD::TRUNCATE, DL, MaxVT, Op0),
1095210958
DAG.getNode(ISD::TRUNCATE, DL, MaxVT, Op1));
@@ -10956,7 +10962,7 @@ SDValue DAGCombiner::foldABSToABD(SDNode *N, const SDLoc &DL) {
1095610962

1095710963
// fold abs(sext(x) - sext(y)) -> abds(sext(x), sext(y))
1095810964
// fold abs(zext(x) - zext(y)) -> abdu(zext(x), zext(y))
10959-
if (hasOperation(ABDOpcode, VT)) {
10965+
if (!LegalOperations || hasOperation(ABDOpcode, VT)) {
1096010966
SDValue ABD = DAG.getNode(ABDOpcode, DL, VT, Op0, Op1);
1096110967
return DAG.getZExtOrTrunc(ABD, DL, SrcVT);
1096210968
}
@@ -11580,7 +11586,7 @@ SDValue DAGCombiner::foldSelectToABD(SDValue LHS, SDValue RHS, SDValue True,
1158011586
unsigned ABDOpc = IsSigned ? ISD::ABDS : ISD::ABDU;
1158111587
EVT VT = LHS.getValueType();
1158211588

11583-
if (!hasOperation(ABDOpc, VT))
11589+
if (LegalOperations && !hasOperation(ABDOpc, VT))
1158411590
return SDValue();
1158511591

1158611592
switch (CC) {

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
192192
case ISD::VP_SUB:
193193
case ISD::VP_MUL: Res = PromoteIntRes_SimpleIntBinOp(N); break;
194194

195+
case ISD::ABDS:
195196
case ISD::AVGCEILS:
196197
case ISD::AVGFLOORS:
197198
case ISD::VP_SMIN:
@@ -201,6 +202,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
201202
case ISD::VP_SDIV:
202203
case ISD::VP_SREM: Res = PromoteIntRes_SExtIntBinOp(N); break;
203204

205+
case ISD::ABDU:
204206
case ISD::AVGCEILU:
205207
case ISD::AVGFLOORU:
206208
case ISD::VP_UMIN:
@@ -2791,6 +2793,8 @@ void DAGTypeLegalizer::ExpandIntegerResult(SDNode *N, unsigned ResNo) {
27912793
case ISD::PARITY: ExpandIntRes_PARITY(N, Lo, Hi); break;
27922794
case ISD::Constant: ExpandIntRes_Constant(N, Lo, Hi); break;
27932795
case ISD::ABS: ExpandIntRes_ABS(N, Lo, Hi); break;
2796+
case ISD::ABDS:
2797+
case ISD::ABDU: ExpandIntRes_ABD(N, Lo, Hi); break;
27942798
case ISD::CTLZ_ZERO_UNDEF:
27952799
case ISD::CTLZ: ExpandIntRes_CTLZ(N, Lo, Hi); break;
27962800
case ISD::CTPOP: ExpandIntRes_CTPOP(N, Lo, Hi); break;
@@ -3850,6 +3854,11 @@ void DAGTypeLegalizer::ExpandIntRes_CTLZ(SDNode *N,
38503854
Hi = DAG.getConstant(0, dl, NVT);
38513855
}
38523856

3857+
// Integer-expansion handler for ISD::ABDS / ISD::ABDU results (dispatched from
// ExpandIntegerResult). The node is first rewritten by
// TargetLowering::expandABD, and the value that returns is then divided into
// the Lo and Hi output halves via SplitInteger.
void DAGTypeLegalizer::ExpandIntRes_ABD(SDNode *N, SDValue &Lo, SDValue &Hi) {
3858+
SDValue Result = TLI.expandABD(N, DAG);
3859+
SplitInteger(Result, Lo, Hi);
3860+
}
3861+
38533862
void DAGTypeLegalizer::ExpandIntRes_CTPOP(SDNode *N,
38543863
SDValue &Lo, SDValue &Hi) {
38553864
SDLoc dl(N);

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
448448
void ExpandIntRes_AssertZext (SDNode *N, SDValue &Lo, SDValue &Hi);
449449
void ExpandIntRes_Constant (SDNode *N, SDValue &Lo, SDValue &Hi);
450450
void ExpandIntRes_ABS (SDNode *N, SDValue &Lo, SDValue &Hi);
451+
void ExpandIntRes_ABD (SDNode *N, SDValue &Lo, SDValue &Hi);
451452
void ExpandIntRes_CTLZ (SDNode *N, SDValue &Lo, SDValue &Hi);
452453
void ExpandIntRes_CTPOP (SDNode *N, SDValue &Lo, SDValue &Hi);
453454
void ExpandIntRes_CTTZ (SDNode *N, SDValue &Lo, SDValue &Hi);

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) {
147147
case ISD::FMINIMUM:
148148
case ISD::FMAXIMUM:
149149
case ISD::FLDEXP:
150+
case ISD::ABDS:
151+
case ISD::ABDU:
150152
case ISD::SMIN:
151153
case ISD::SMAX:
152154
case ISD::UMIN:
@@ -1233,6 +1235,8 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
12331235
case ISD::MUL: case ISD::VP_MUL:
12341236
case ISD::MULHS:
12351237
case ISD::MULHU:
1238+
case ISD::ABDS:
1239+
case ISD::ABDU:
12361240
case ISD::AVGCEILS:
12371241
case ISD::AVGCEILU:
12381242
case ISD::AVGFLOORS:
@@ -4368,6 +4372,8 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
43684372
case ISD::MUL: case ISD::VP_MUL:
43694373
case ISD::MULHS:
43704374
case ISD::MULHU:
4375+
case ISD::ABDS:
4376+
case ISD::ABDU:
43714377
case ISD::OR: case ISD::VP_OR:
43724378
case ISD::SUB: case ISD::VP_SUB:
43734379
case ISD::XOR: case ISD::VP_XOR:

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7024,6 +7024,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
70247024
assert(VT.isInteger() && "This operator does not apply to FP types!");
70257025
assert(N1.getValueType() == N2.getValueType() &&
70267026
N1.getValueType() == VT && "Binary operator types must match!");
7027+
if (VT.isVector() && VT.getVectorElementType() == MVT::i1)
7028+
return getNode(ISD::XOR, DL, VT, N1, N2);
70277029
break;
70287030
case ISD::SMIN:
70297031
case ISD::UMAX:

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9311,6 +9311,21 @@ SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const {
93119311
DAG.getNode(ISD::USUBSAT, dl, VT, LHS, RHS),
93129312
DAG.getNode(ISD::USUBSAT, dl, VT, RHS, LHS));
93139313

9314+
// If the subtract doesn't overflow then just use abs(sub())
9315+
// NOTE: don't use frozen operands for value tracking.
9316+
bool IsNonNegative = DAG.SignBitIsZero(N->getOperand(1)) &&
9317+
DAG.SignBitIsZero(N->getOperand(0));
9318+
9319+
if (DAG.willNotOverflowSub(IsSigned || IsNonNegative, N->getOperand(0),
9320+
N->getOperand(1)))
9321+
return DAG.getNode(ISD::ABS, dl, VT,
9322+
DAG.getNode(ISD::SUB, dl, VT, LHS, RHS));
9323+
9324+
if (DAG.willNotOverflowSub(IsSigned || IsNonNegative, N->getOperand(1),
9325+
N->getOperand(0)))
9326+
return DAG.getNode(ISD::ABS, dl, VT,
9327+
DAG.getNode(ISD::SUB, dl, VT, RHS, LHS));
9328+
93149329
EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
93159330
ISD::CondCode CC = IsSigned ? ISD::CondCode::SETGT : ISD::CondCode::SETUGT;
93169331
SDValue Cmp = DAG.getSetCC(dl, CCVT, LHS, RHS, CC);
@@ -9324,6 +9339,23 @@ SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const {
93249339
return DAG.getNode(ISD::SUB, dl, VT, Cmp, Xor);
93259340
}
93269341

9342+
// Similar to the branchless expansion, use the (sign-extended) usubo overflow
9343+
// flag if the (scalar) type is illegal as this is more likely to legalize
9344+
// cleanly:
9345+
// abdu(lhs, rhs) -> sub(xor(sub(lhs, rhs), uof(lhs, rhs)), uof(lhs, rhs))
9346+
if (!IsSigned && VT.isScalarInteger() && !isTypeLegal(VT)) {
9347+
SDValue USubO =
9348+
DAG.getNode(ISD::USUBO, dl, DAG.getVTList(VT, MVT::i1), {LHS, RHS});
9349+
SDValue Cmp = DAG.getNode(ISD::SIGN_EXTEND, dl, VT, USubO.getValue(1));
9350+
SDValue Xor = DAG.getNode(ISD::XOR, dl, VT, USubO.getValue(0), Cmp);
9351+
return DAG.getNode(ISD::SUB, dl, VT, Xor, Cmp);
9352+
}
9353+
9354+
// FIXME: Should really try to split the vector in case it's legal on a
9355+
// subvector.
9356+
if (VT.isVector() && !isOperationLegalOrCustom(ISD::VSELECT, VT))
9357+
return DAG.UnrollVectorOp(N);
9358+
93279359
// abds(lhs, rhs) -> select(sgt(lhs,rhs), sub(lhs,rhs), sub(rhs,lhs))
93289360
// abdu(lhs, rhs) -> select(ugt(lhs,rhs), sub(lhs,rhs), sub(rhs,lhs))
93299361
return DAG.getSelect(dl, VT, Cmp, DAG.getNode(ISD::SUB, dl, VT, LHS, RHS),

0 commit comments

Comments
 (0)