Skip to content

Commit f5bb6c3

Browse files
committed
[X86] Improve KnownBits for X86ISD::PSADBW nodes
Don't just return the known zero upperbits, compute the absdiff Knownbits and perform the horizontal sum
1 parent 5dc9e87 commit f5bb6c3

File tree

2 files changed

+22
-8
lines changed

2 files changed

+22
-8
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36836,12 +36836,23 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
3683636836
break;
3683736837
}
3683836838
case X86ISD::PSADBW: {
36839+
SDValue LHS = Op.getOperand(0);
36840+
SDValue RHS = Op.getOperand(1);
3683936841
assert(VT.getScalarType() == MVT::i64 &&
36840-
Op.getOperand(0).getValueType().getScalarType() == MVT::i8 &&
36842+
LHS.getValueType() == RHS.getValueType() &&
36843+
LHS.getValueType().getScalarType() == MVT::i8 &&
3684136844
"Unexpected PSADBW types");
3684236845

36843-
// PSADBW - fills low 16 bits and zeros upper 48 bits of each i64 result.
36844-
Known.Zero.setBitsFrom(16);
36846+
KnownBits Known2;
36847+
unsigned NumSrcElts = LHS.getValueType().getVectorNumElements();
36848+
APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
36849+
Known = DAG.computeKnownBits(RHS, DemandedSrcElts, Depth + 1);
36850+
Known2 = DAG.computeKnownBits(LHS, DemandedSrcElts, Depth + 1);
36851+
Known = KnownBits::absdiff(Known, Known2).zext(16);
36852+
Known = KnownBits::computeForAddSub(true, true, Known, Known);
36853+
Known = KnownBits::computeForAddSub(true, true, Known, Known);
36854+
Known = KnownBits::computeForAddSub(true, true, Known, Known);
36855+
Known = Known.zext(64);
3684536856
break;
3684636857
}
3684736858
case X86ISD::PCMPGT:
@@ -54853,6 +54864,7 @@ static SDValue combineSub(SDNode *N, SelectionDAG &DAG,
5485354864
}
5485454865

5485554866
static SDValue combineVectorCompare(SDNode *N, SelectionDAG &DAG,
54867+
TargetLowering::DAGCombinerInfo &DCI,
5485654868
const X86Subtarget &Subtarget) {
5485754869
MVT VT = N->getSimpleValueType(0);
5485854870
SDLoc DL(N);
@@ -54864,6 +54876,11 @@ static SDValue combineVectorCompare(SDNode *N, SelectionDAG &DAG,
5486454876
return DAG.getConstant(0, DL, VT);
5486554877
}
5486654878

54879+
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
54880+
if (TLI.SimplifyDemandedBits(
54881+
SDValue(N, 0), APInt::getAllOnes(VT.getScalarSizeInBits()), DCI))
54882+
return SDValue(N, 0);
54883+
5486754884
return SDValue();
5486854885
}
5486954886

@@ -56587,7 +56604,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
5658756604
case ISD::MGATHER:
5658856605
case ISD::MSCATTER: return combineGatherScatter(N, DAG, DCI);
5658956606
case X86ISD::PCMPEQ:
56590-
case X86ISD::PCMPGT: return combineVectorCompare(N, DAG, Subtarget);
56607+
case X86ISD::PCMPGT: return combineVectorCompare(N, DAG, DCI, Subtarget);
5659156608
case X86ISD::PMULDQ:
5659256609
case X86ISD::PMULUDQ: return combinePMULDQ(N, DAG, DCI, Subtarget);
5659356610
case X86ISD::VPMADDUBSW:

llvm/test/CodeGen/X86/psadbw.ll

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,7 @@ define <2 x i64> @combine_psadbw_cmp_knownbits(<16 x i8> %a0) nounwind {
7070
;
7171
; AVX2-LABEL: combine_psadbw_cmp_knownbits:
7272
; AVX2: # %bb.0:
73-
; AVX2-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
74-
; AVX2-NEXT: vpxor %xmm1, %xmm1, %xmm1
75-
; AVX2-NEXT: vpsadbw %xmm1, %xmm0, %xmm0
76-
; AVX2-NEXT: vpcmpgtq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
73+
; AVX2-NEXT: vxorps %xmm0, %xmm0, %xmm0
7774
; AVX2-NEXT: retq
7875
%mask = and <16 x i8> %a0, <i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3>
7976
%sad = tail call <2 x i64> @llvm.x86.sse2.psad.bw(<16 x i8> %mask, <16 x i8> zeroinitializer)

0 commit comments

Comments
 (0)