
[X86] Improve KnownBits for X86ISD::PSADBW nodes #83830


Merged: 1 commit merged into llvm:main on Mar 6, 2024

Conversation

@RKSimon (Collaborator) commented on Mar 4, 2024

Don't just return the known-zero upper bits; compute the absdiff KnownBits and perform the horizontal sum.

Add implementations that handle both the X86ISD::PSADBW nodes and the INTRINSIC_WO_CHAIN intrinsics (pre-legalization).
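For readers new to the instruction, here is a minimal scalar model of what PSADBW computes per 64-bit lane: eight byte-wise absolute differences followed by a horizontal sum, which is exactly the structure the KnownBits code below mirrors. This is an illustrative sketch, not part of the patch; the name psadbw_lane is hypothetical.

    #include <cstdint>
    #include <cstdlib>

    // Reference model of one PSADBW lane: the sum of eight byte-wise
    // absolute differences. The result is at most 8 * 255 = 2040, so only
    // the low 16 bits of each i64 lane can ever be set.
    uint64_t psadbw_lane(const uint8_t a[8], const uint8_t b[8]) {
      uint64_t sum = 0;
      for (int i = 0; i < 8; ++i)
        sum += (uint64_t)std::abs((int)a[i] - (int)b[i]);
      return sum;
    }

    int main() {
      const uint8_t a[8] = {3, 1, 4, 1, 5, 9, 2, 6};
      const uint8_t z[8] = {0};
      return psadbw_lane(a, z) == 31 ? 0 : 1; // 3+1+4+1+5+9+2+6 = 31
    }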

@llvmbot (Member) commented on Mar 4, 2024

@llvm/pr-subscribers-backend-x86

Author: Simon Pilgrim (RKSimon)

Changes

Don't just return the known-zero upper bits; compute the absdiff KnownBits and perform the horizontal sum.


Full diff: https://github.com/llvm/llvm-project/pull/83830.diff

2 Files Affected:

  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+21-4)
  • (modified) llvm/test/CodeGen/X86/psadbw.ll (+1-4)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index b87e3121838dcc..5076ac5e347e9f 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -36836,12 +36836,23 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
     break;
   }
   case X86ISD::PSADBW: {
+    SDValue LHS = Op.getOperand(0);
+    SDValue RHS = Op.getOperand(1);
     assert(VT.getScalarType() == MVT::i64 &&
-           Op.getOperand(0).getValueType().getScalarType() == MVT::i8 &&
+           LHS.getValueType() == RHS.getValueType() &&
+           LHS.getValueType().getScalarType() == MVT::i8 &&
            "Unexpected PSADBW types");
 
-    // PSADBW - fills low 16 bits and zeros upper 48 bits of each i64 result.
-    Known.Zero.setBitsFrom(16);
+    KnownBits Known2;
+    unsigned NumSrcElts = LHS.getValueType().getVectorNumElements();
+    APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
+    Known = DAG.computeKnownBits(RHS, DemandedSrcElts, Depth + 1);
+    Known2 = DAG.computeKnownBits(LHS, DemandedSrcElts, Depth + 1);
+    Known = KnownBits::absdiff(Known, Known2).zext(16);
+    Known = KnownBits::computeForAddSub(true, true, Known, Known);
+    Known = KnownBits::computeForAddSub(true, true, Known, Known);
+    Known = KnownBits::computeForAddSub(true, true, Known, Known);
+    Known = Known.zext(64);
     break;
   }
   case X86ISD::PCMPGT:
@@ -54853,6 +54864,7 @@ static SDValue combineSub(SDNode *N, SelectionDAG &DAG,
 }
 
 static SDValue combineVectorCompare(SDNode *N, SelectionDAG &DAG,
+                                    TargetLowering::DAGCombinerInfo &DCI,
                                     const X86Subtarget &Subtarget) {
   MVT VT = N->getSimpleValueType(0);
   SDLoc DL(N);
@@ -54864,6 +54876,11 @@ static SDValue combineVectorCompare(SDNode *N, SelectionDAG &DAG,
       return DAG.getConstant(0, DL, VT);
   }
 
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  if (TLI.SimplifyDemandedBits(
+          SDValue(N, 0), APInt::getAllOnes(VT.getScalarSizeInBits()), DCI))
+    return SDValue(N, 0);
+
   return SDValue();
 }
 
@@ -56587,7 +56604,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::MGATHER:
   case ISD::MSCATTER:       return combineGatherScatter(N, DAG, DCI);
   case X86ISD::PCMPEQ:
-  case X86ISD::PCMPGT:      return combineVectorCompare(N, DAG, Subtarget);
+  case X86ISD::PCMPGT:      return combineVectorCompare(N, DAG, DCI, Subtarget);
   case X86ISD::PMULDQ:
   case X86ISD::PMULUDQ:     return combinePMULDQ(N, DAG, DCI, Subtarget);
   case X86ISD::VPMADDUBSW:
diff --git a/llvm/test/CodeGen/X86/psadbw.ll b/llvm/test/CodeGen/X86/psadbw.ll
index 8141b22d321f4d..8044472b13e3a8 100644
--- a/llvm/test/CodeGen/X86/psadbw.ll
+++ b/llvm/test/CodeGen/X86/psadbw.ll
@@ -70,10 +70,7 @@ define <2 x i64> @combine_psadbw_cmp_knownbits(<16 x i8> %a0) nounwind {
 ;
 ; AVX2-LABEL: combine_psadbw_cmp_knownbits:
 ; AVX2:       # %bb.0:
-; AVX2-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; AVX2-NEXT:    vpxor %xmm1, %xmm1, %xmm1
-; AVX2-NEXT:    vpsadbw %xmm1, %xmm0, %xmm0
-; AVX2-NEXT:    vpcmpgtq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; AVX2-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; AVX2-NEXT:    retq
   %mask = and <16 x i8> %a0, <i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3>
   %sad = tail call <2 x i64> @llvm.x86.sse2.psad.bw(<16 x i8> %mask, <16 x i8> zeroinitializer)
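
As a sanity check on why the AVX2 test above now folds to an all-zeros vector (an illustrative sketch, not from the PR; the compare constant is truncated above): with every byte masked to at most 3, each absolute difference against zero is at most 3, so the per-lane sum of eight bytes is at most 24 and bits 5..63 of each i64 lane are known zero, which lets the new SimplifyDemandedBits call in combineVectorCompare constant-fold the signed compare.

    #include <cassert>
    #include <cstdint>

    int main() {
      // Enumerate all 8 masked bytes of one lane: 2 unknown bits each.
      for (uint32_t bits = 0; bits < (1u << 16); ++bits) {
        uint64_t sum = 0;
        for (int i = 0; i < 8; ++i)
          sum += (bits >> (2 * i)) & 3; // absdiff of a masked byte vs. 0
        assert(sum <= 24 && (sum >> 5) == 0); // bits 5+ provably zero
      }
      return 0;
    }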

@RKSimon force-pushed the knownbits-psadbw branch from f5bb6c3 to cf8f19b on March 4, 2024 14:41
@RKSimon force-pushed the knownbits-psadbw branch from 2bb14ec to 31c9bde on March 6, 2024 12:24 (commit: Don't just return the known zero upperbits, compute the absdiff Knownbits and perform the horizontal sum)
@RKSimon force-pushed the knownbits-psadbw branch from 31c9bde to 8073a10 on March 6, 2024 12:42
Contributor commented on the repeated computeForAddSub chain:

    Known = KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/true, /*NUW=*/true,
                                        Known, Known);
    Known = KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/true, /*NUW=*/true,
                                        Known, Known);
Could this be

    Known = KnownBits::shl(Known, KnownBits::makeConstant(APInt(Known.getBitWidth(), 8)), /*NSW=*/true, /*NUW=*/true);

instead? Do the 3x adds do a better job or something? I think our shl and add impls are both optimal.

Contributor commented:

Other than this, it all looks good.

@RKSimon (Collaborator, Author) commented:

No, I was just trying to match the PSADBW expansion as closely as possible, but I can just replace it with a shl by 3 (not 8).

Fun fact: KnownBits::shl doesn't need the shift amount to be the same bitwidth as the shift value :)

Collaborator commented:

Would KnownBits::shl make the lower 3 bits of Known.Zero true? That would be wrong.

Contributor commented:

Ah you're right :)

@RKSimon (Collaborator, Author) commented:

Yup, I noticed that when I tried it, so I used the computeForAddSub chain instead.
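
To make the pitfall concrete (an illustrative sketch, not from the PR): x << 3 always has its low three bits zero, but a horizontal sum of eight independent byte values can set those bits freely, so modelling the sum as a shift would wrongly grow Known.Zero. The repeated computeForAddSub(Known, Known) calls stay conservative precisely because they do not assume the two addends are the same value.

    #include <cassert>
    #include <cstdint>

    int main() {
      // Seven bytes of 1 plus one byte of 0: a perfectly legal PSADBW input.
      uint64_t bytes[8] = {1, 1, 1, 1, 1, 1, 1, 0};
      uint64_t sum = 0;
      for (uint64_t b : bytes)
        sum += b;               // horizontal add of independent values
      assert(sum == 7);         // low 3 bits are 0b111, not zero,
      assert((sum & 7) != 0);   // so "shl by 3" would mis-model the sum
      return 0;
    }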

@RKSimon merged commit 0bd9255 into llvm:main on Mar 6, 2024
@RKSimon deleted the knownbits-psadbw branch on March 6, 2024 17:23