@@ -36836,12 +36836,24 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
36836
36836
break;
36837
36837
}
36838
36838
case X86ISD::PSADBW: {
36839
+ SDValue LHS = Op.getOperand(0);
36840
+ SDValue RHS = Op.getOperand(1);
36839
36841
assert(VT.getScalarType() == MVT::i64 &&
36840
- Op.getOperand(0).getValueType().getScalarType() == MVT::i8 &&
36842
+ LHS.getValueType() == RHS.getValueType() &&
36843
+ LHS.getValueType().getScalarType() == MVT::i8 &&
36841
36844
"Unexpected PSADBW types");
36842
36845
36843
- // PSADBW - fills low 16 bits and zeros upper 48 bits of each i64 result.
36844
- Known.Zero.setBitsFrom(16);
36846
+ KnownBits Known2;
36847
+ unsigned NumSrcElts = LHS.getValueType().getVectorNumElements();
36848
+ APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
36849
+ Known = DAG.computeKnownBits(RHS, DemandedSrcElts, Depth + 1);
36850
+ Known2 = DAG.computeKnownBits(LHS, DemandedSrcElts, Depth + 1);
36851
+ Known = KnownBits::absdiff(Known, Known2).zext(16);
36852
+ // Known = (((D0 + D1) + (D2 + D3)) + ((D4 + D5) + (D6 + D7)))
36853
+ Known = KnownBits::computeForAddSub(true, true, Known, Known);
36854
+ Known = KnownBits::computeForAddSub(true, true, Known, Known);
36855
+ Known = KnownBits::computeForAddSub(true, true, Known, Known);
36856
+ Known = Known.zext(64);
36845
36857
break;
36846
36858
}
36847
36859
case X86ISD::PCMPGT:
@@ -54853,6 +54865,7 @@ static SDValue combineSub(SDNode *N, SelectionDAG &DAG,
54853
54865
}
54854
54866
54855
54867
static SDValue combineVectorCompare(SDNode *N, SelectionDAG &DAG,
54868
+ TargetLowering::DAGCombinerInfo &DCI,
54856
54869
const X86Subtarget &Subtarget) {
54857
54870
MVT VT = N->getSimpleValueType(0);
54858
54871
SDLoc DL(N);
@@ -54864,6 +54877,11 @@ static SDValue combineVectorCompare(SDNode *N, SelectionDAG &DAG,
54864
54877
return DAG.getConstant(0, DL, VT);
54865
54878
}
54866
54879
54880
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
54881
+ if (TLI.SimplifyDemandedBits(
54882
+ SDValue(N, 0), APInt::getAllOnes(VT.getScalarSizeInBits()), DCI))
54883
+ return SDValue(N, 0);
54884
+
54867
54885
return SDValue();
54868
54886
}
54869
54887
@@ -56587,7 +56605,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
56587
56605
case ISD::MGATHER:
56588
56606
case ISD::MSCATTER: return combineGatherScatter(N, DAG, DCI);
56589
56607
case X86ISD::PCMPEQ:
56590
- case X86ISD::PCMPGT: return combineVectorCompare(N, DAG, Subtarget);
56608
+ case X86ISD::PCMPGT: return combineVectorCompare(N, DAG, DCI, Subtarget);
56591
56609
case X86ISD::PMULDQ:
56592
56610
case X86ISD::PMULUDQ: return combinePMULDQ(N, DAG, DCI, Subtarget);
56593
56611
case X86ISD::VPMADDUBSW:
0 commit comments