[X86] Improve KnownBits for X86ISD::PSADBW nodes #83830
@llvm/pr-subscribers-backend-x86
Author: Simon Pilgrim (RKSimon)
Changes
Don't just return the known zero upper bits; compute the absdiff KnownBits and perform the horizontal sum.
Full diff: https://github.com/llvm/llvm-project/pull/83830.diff
2 Files Affected:
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index b87e3121838dcc..5076ac5e347e9f 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -36836,12 +36836,23 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
break;
}
case X86ISD::PSADBW: {
+ SDValue LHS = Op.getOperand(0);
+ SDValue RHS = Op.getOperand(1);
assert(VT.getScalarType() == MVT::i64 &&
- Op.getOperand(0).getValueType().getScalarType() == MVT::i8 &&
+ LHS.getValueType() == RHS.getValueType() &&
+ LHS.getValueType().getScalarType() == MVT::i8 &&
"Unexpected PSADBW types");
- // PSADBW - fills low 16 bits and zeros upper 48 bits of each i64 result.
- Known.Zero.setBitsFrom(16);
+ KnownBits Known2;
+ unsigned NumSrcElts = LHS.getValueType().getVectorNumElements();
+ APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
+ Known = DAG.computeKnownBits(RHS, DemandedSrcElts, Depth + 1);
+ Known2 = DAG.computeKnownBits(LHS, DemandedSrcElts, Depth + 1);
+ Known = KnownBits::absdiff(Known, Known2).zext(16);
+ Known = KnownBits::computeForAddSub(true, true, Known, Known);
+ Known = KnownBits::computeForAddSub(true, true, Known, Known);
+ Known = KnownBits::computeForAddSub(true, true, Known, Known);
+ Known = Known.zext(64);
break;
}
case X86ISD::PCMPGT:
@@ -54853,6 +54864,7 @@ static SDValue combineSub(SDNode *N, SelectionDAG &DAG,
}
static SDValue combineVectorCompare(SDNode *N, SelectionDAG &DAG,
+ TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
MVT VT = N->getSimpleValueType(0);
SDLoc DL(N);
@@ -54864,6 +54876,11 @@ static SDValue combineVectorCompare(SDNode *N, SelectionDAG &DAG,
return DAG.getConstant(0, DL, VT);
}
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ if (TLI.SimplifyDemandedBits(
+ SDValue(N, 0), APInt::getAllOnes(VT.getScalarSizeInBits()), DCI))
+ return SDValue(N, 0);
+
return SDValue();
}
@@ -56587,7 +56604,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case ISD::MGATHER:
case ISD::MSCATTER: return combineGatherScatter(N, DAG, DCI);
case X86ISD::PCMPEQ:
- case X86ISD::PCMPGT: return combineVectorCompare(N, DAG, Subtarget);
+ case X86ISD::PCMPGT: return combineVectorCompare(N, DAG, DCI, Subtarget);
case X86ISD::PMULDQ:
case X86ISD::PMULUDQ: return combinePMULDQ(N, DAG, DCI, Subtarget);
case X86ISD::VPMADDUBSW:
diff --git a/llvm/test/CodeGen/X86/psadbw.ll b/llvm/test/CodeGen/X86/psadbw.ll
index 8141b22d321f4d..8044472b13e3a8 100644
--- a/llvm/test/CodeGen/X86/psadbw.ll
+++ b/llvm/test/CodeGen/X86/psadbw.ll
@@ -70,10 +70,7 @@ define <2 x i64> @combine_psadbw_cmp_knownbits(<16 x i8> %a0) nounwind {
;
; AVX2-LABEL: combine_psadbw_cmp_knownbits:
; AVX2: # %bb.0:
-; AVX2-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; AVX2-NEXT: vpxor %xmm1, %xmm1, %xmm1
-; AVX2-NEXT: vpsadbw %xmm1, %xmm0, %xmm0
-; AVX2-NEXT: vpcmpgtq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; AVX2-NEXT: vxorps %xmm0, %xmm0, %xmm0
; AVX2-NEXT: retq
%mask = and <16 x i8> %a0, <i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3>
%sad = tail call <2 x i64> @llvm.x86.sse2.psad.bw(<16 x i8> %mask, <16 x i8> zeroinitializer)
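The hunk above widens the demanded-elements mask with APIntOps::ScaleBitMask before querying the operands: each demanded i64 result lane depends on eight i8 source elements. A minimal standalone sketch of that "scale up" direction follows; `scaleDemandedElts` is a hypothetical helper written for illustration (it is not LLVM API) and it assumes masks of at most 64 elements.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the "scale up" case of APIntOps::ScaleBitMask:
// every demanded result element expands to NumSrc / NumDst consecutive
// source elements. Assumes NumSrc is a multiple of NumDst and NumSrc <= 64.
uint64_t scaleDemandedElts(uint64_t Demanded, unsigned NumDst, unsigned NumSrc) {
  assert(NumDst != 0 && NumSrc <= 64 && NumSrc % NumDst == 0 &&
         "sketch only handles scaling up");
  unsigned Scale = NumSrc / NumDst;
  uint64_t Src = 0;
  for (unsigned I = 0; I < NumDst; ++I)
    if (Demanded & (uint64_t{1} << I))
      for (unsigned J = 0; J < Scale; ++J)
        Src |= uint64_t{1} << (I * Scale + J);
  return Src;
}
```

For a <2 x i64> PSADBW result fed by <16 x i8> operands, demanding only lane 0 maps to source bytes 0..7, and demanding only lane 1 maps to bytes 8..15.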
Known = KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/true, /*NUW=*/true,
                                    Known, Known);
Known = KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/true, /*NUW=*/true,
                                    Known, Known);
Known = KnownBits::shl(Known, KnownBits::makeConstant(APInt(Known.getBitWidth(), 8)), /*NSW=*/true, /*NUW=*/true); ?
Does the 3x adds do a better job or something? Think our shl and add impl are both optimal.
Other than this, it all looks good.
No, I was just trying to make it match the PSADBW expansion as explicitly as possible, but I can just replace it with a shl by 3 (not 8).
Fun fact: KnownBits::shl doesn't need the shift amount to be the same bitwidth as the shift value :)
Would KnownBits::shl make the lower 3 bits of Known.Zero true? That would be wrong.
Ah you're right :)
Yup, noticed that when I tried it - I used the computeForAddSub chain instead
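The objection in this thread can be checked with a concrete counterexample. The eight per-byte absolute differences share the same known bits but are independent values, so their sum need not be a multiple of 8. The sketch below uses two illustrative functions that are not LLVM API: `psadbwLane` is a plain reference model of one PSADBW i64 lane, and `shlModelKnownZero` is a hypothetical model of what a "shl by 3" would claim to be known zero.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Reference semantics of one PSADBW i64 lane: the sum of |A[i] - B[i]|
// over eight unsigned bytes.
uint64_t psadbwLane(const std::array<uint8_t, 8> &A,
                    const std::array<uint8_t, 8> &B) {
  uint64_t Sum = 0;
  for (std::size_t I = 0; I < 8; ++I)
    Sum += A[I] > B[I] ? A[I] - B[I] : B[I] - A[I];
  return Sum;
}

// Hypothetical "shl by 3" model: shifts the per-byte known-zero mask left
// by three and claims the low three result bits are zero.
uint64_t shlModelKnownZero(uint64_t AbsDiffKnownZero) {
  return (AbsDiffKnownZero << 3) | 7;
}
```

With every byte known to be at most 3 (known-zero mask `~3`), the inputs {1,0,0,0,0,0,0,0} and all-zero give a lane value of 1, which sets a bit the shl model claims is zero. The add chain avoids this because adding two independent values with the same known bits does not force the low bit to zero.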
Don't just return the known zero upper bits; compute the absdiff KnownBits and perform the horizontal sum.
Add implementations that handle both the X86ISD::PSADBW nodes and the INTRINSIC_WO_CHAIN intrinsics (pre-legalization).
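As a rough illustration of the horizontal-sum step, here is a standalone sketch that does not depend on LLVM. `Known`, `addKnown` and `psadbwLane` are hypothetical names. `addKnown` follows the carry-propagation idea behind KnownBits::computeForAddSub as I understand it and ignores the NSW/NUW flags. The sketch also works directly at 64 bits instead of 16 bits followed by zext(64), and it takes the per-byte absdiff known bits as an input rather than computing them.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for llvm::KnownBits, fixed at 64 bits.
struct Known {
  uint64_t Zero; // bits known to be 0
  uint64_t One;  // bits known to be 1
};

// Known bits of L + R, treating L and R as independent values.
Known addKnown(Known L, Known R) {
  uint64_t MaxSum = ~L.Zero + ~R.Zero; // both operands at their maximum
  uint64_t MinSum = L.One + R.One;     // both operands at their minimum
  // Carries into each bit position that are known to be 0 or known to be 1.
  uint64_t CarryZero = ~(MaxSum ^ L.Zero ^ R.Zero);
  uint64_t CarryOne = MinSum ^ L.One ^ R.One;
  // A result bit is known only if both operand bits and the carry are known.
  uint64_t Resolved =
      (CarryZero | CarryOne) & (L.Zero | L.One) & (R.Zero | R.One);
  return {~MaxSum & Resolved, MinSum & Resolved};
}

// Known bits of one PSADBW i64 lane, given known bits shared by all eight
// per-byte absolute differences: three doubling adds cover 2, 4, 8 elements.
Known psadbwLane(Known AbsDiff) {
  assert((AbsDiff.Zero & AbsDiff.One) == 0 && "conflicting known bits");
  Known K = AbsDiff;
  K = addKnown(K, K);
  K = addKnown(K, K);
  K = addKnown(K, K);
  return K;
}
```

Under these assumptions, bytes masked with 3 (known-zero mask `~3`) yield a lane with bits 5 and up known zero and the low bits left unknown, which is the kind of information a later compare against a constant could use to fold.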