-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[KnownBits] Make nuw and nsw support in computeForAddSub optimal #83382
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-llvm-selectiondag @llvm/pr-subscribers-llvm-globalisel Author: None (goldsteinn) Changes
Patch is 35.55 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/83382.diff 17 Files Affected:
diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index fb034e0b9e3baf..4e9eb0c10a5628 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -329,8 +329,8 @@ struct KnownBits {
const KnownBits &LHS, const KnownBits &RHS, const KnownBits &Carry);
/// Compute known bits resulting from adding LHS and RHS.
- static KnownBits computeForAddSub(bool Add, bool NSW, const KnownBits &LHS,
- KnownBits RHS);
+ static KnownBits computeForAddSub(bool Add, bool NSW, bool NUW,
+ const KnownBits &LHS, KnownBits RHS);
/// Compute known bits results from subtracting RHS from LHS with 1-bit
/// Borrow.
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index e591ac504e9f05..c220674c5f21d2 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -350,18 +350,19 @@ unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL,
}
static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
- bool NSW, const APInt &DemandedElts,
+ bool NSW, bool NUW,
+ const APInt &DemandedElts,
KnownBits &KnownOut, KnownBits &Known2,
unsigned Depth, const SimplifyQuery &Q) {
computeKnownBits(Op1, DemandedElts, KnownOut, Depth + 1, Q);
// If one operand is unknown and we have no nowrap information,
// the result will be unknown independently of the second operand.
- if (KnownOut.isUnknown() && !NSW)
+ if (KnownOut.isUnknown() && !NSW && !NUW)
return;
computeKnownBits(Op0, DemandedElts, Known2, Depth + 1, Q);
- KnownOut = KnownBits::computeForAddSub(Add, NSW, Known2, KnownOut);
+ KnownOut = KnownBits::computeForAddSub(Add, NSW, NUW, Known2, KnownOut);
}
static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
@@ -1145,13 +1146,15 @@ static void computeKnownBitsFromOperator(const Operator *I,
}
case Instruction::Sub: {
bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
- computeKnownBitsAddSub(false, I->getOperand(0), I->getOperand(1), NSW,
+ bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
+ computeKnownBitsAddSub(false, I->getOperand(0), I->getOperand(1), NSW, NUW,
DemandedElts, Known, Known2, Depth, Q);
break;
}
case Instruction::Add: {
bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
- computeKnownBitsAddSub(true, I->getOperand(0), I->getOperand(1), NSW,
+ bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
+ computeKnownBitsAddSub(true, I->getOperand(0), I->getOperand(1), NSW, NUW,
DemandedElts, Known, Known2, Depth, Q);
break;
}
@@ -1245,12 +1248,12 @@ static void computeKnownBitsFromOperator(const Operator *I,
// Note that inbounds does *not* guarantee nsw for the addition, as only
// the offset is signed, while the base address is unsigned.
Known = KnownBits::computeForAddSub(
- /*Add=*/true, /*NSW=*/false, Known, IndexBits);
+ /*Add=*/true, /*NSW=*/false, /* NUW=*/false, Known, IndexBits);
}
if (!Known.isUnknown() && !AccConstIndices.isZero()) {
KnownBits Index = KnownBits::makeConstant(AccConstIndices);
Known = KnownBits::computeForAddSub(
- /*Add=*/true, /*NSW=*/false, Known, Index);
+ /*Add=*/true, /*NSW=*/false, /* NUW=*/false, Known, Index);
}
break;
}
@@ -1689,15 +1692,15 @@ static void computeKnownBitsFromOperator(const Operator *I,
default: break;
case Intrinsic::uadd_with_overflow:
case Intrinsic::sadd_with_overflow:
- computeKnownBitsAddSub(true, II->getArgOperand(0),
- II->getArgOperand(1), false, DemandedElts,
- Known, Known2, Depth, Q);
+ computeKnownBitsAddSub(
+ true, II->getArgOperand(0), II->getArgOperand(1), /*NSW*/ false,
+ /* NUW*/ false, DemandedElts, Known, Known2, Depth, Q);
break;
case Intrinsic::usub_with_overflow:
case Intrinsic::ssub_with_overflow:
- computeKnownBitsAddSub(false, II->getArgOperand(0),
- II->getArgOperand(1), false, DemandedElts,
- Known, Known2, Depth, Q);
+ computeKnownBitsAddSub(
+ false, II->getArgOperand(0), II->getArgOperand(1), /*NSW*/ false,
+ /* NUW*/ false, DemandedElts, Known, Known2, Depth, Q);
break;
case Intrinsic::umul_with_overflow:
case Intrinsic::smul_with_overflow:
@@ -2318,7 +2321,11 @@ static bool isNonZeroRecurrence(const PHINode *PN) {
static bool isNonZeroAdd(const APInt &DemandedElts, unsigned Depth,
const SimplifyQuery &Q, unsigned BitWidth, Value *X,
- Value *Y, bool NSW) {
+ Value *Y, bool NSW, bool NUW) {
+ if (NUW)
+ return isKnownNonZero(Y, DemandedElts, Depth, Q) ||
+ isKnownNonZero(X, DemandedElts, Depth, Q);
+
KnownBits XKnown = computeKnownBits(X, DemandedElts, Depth, Q);
KnownBits YKnown = computeKnownBits(Y, DemandedElts, Depth, Q);
@@ -2351,7 +2358,7 @@ static bool isNonZeroAdd(const APInt &DemandedElts, unsigned Depth,
isKnownToBeAPowerOfTwo(X, /*OrZero*/ false, Depth, Q))
return true;
- return KnownBits::computeForAddSub(/*Add*/ true, NSW, XKnown, YKnown)
+ return KnownBits::computeForAddSub(/*Add*/ true, NSW, NUW, XKnown, YKnown)
.isNonZero();
}
@@ -2556,12 +2563,9 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
// If Add has nuw wrap flag, then if either X or Y is non-zero the result is
// non-zero.
auto *BO = cast<OverflowingBinaryOperator>(I);
- if (Q.IIQ.hasNoUnsignedWrap(BO))
- return isKnownNonZero(I->getOperand(1), DemandedElts, Depth, Q) ||
- isKnownNonZero(I->getOperand(0), DemandedElts, Depth, Q);
-
return isNonZeroAdd(DemandedElts, Depth, Q, BitWidth, I->getOperand(0),
- I->getOperand(1), Q.IIQ.hasNoSignedWrap(BO));
+ I->getOperand(1), Q.IIQ.hasNoSignedWrap(BO),
+ Q.IIQ.hasNoUnsignedWrap(BO));
}
case Instruction::Mul: {
// If X and Y are non-zero then so is X * Y as long as the multiplication
@@ -2716,7 +2720,7 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
case Intrinsic::sadd_sat:
return isNonZeroAdd(DemandedElts, Depth, Q, BitWidth,
II->getArgOperand(0), II->getArgOperand(1),
- /*NSW*/ true);
+ /*NSW*/ true, /* NUW*/ false);
case Intrinsic::umax:
case Intrinsic::uadd_sat:
return isKnownNonZero(II->getArgOperand(1), DemandedElts, Depth, Q) ||
diff --git a/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp b/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
index ea8c20cdcd45d6..83c04612d2d43e 100644
--- a/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
@@ -269,8 +269,8 @@ void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
Depth + 1);
computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, DemandedElts,
Depth + 1);
- Known = KnownBits::computeForAddSub(/*Add*/ false, /*NSW*/ false, Known,
- Known2);
+ Known = KnownBits::computeForAddSub(/*Add*/ false, /*NSW*/ false,
+ /* NUW*/ false, Known, Known2);
break;
}
case TargetOpcode::G_XOR: {
@@ -296,8 +296,8 @@ void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
Depth + 1);
computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, DemandedElts,
Depth + 1);
- Known =
- KnownBits::computeForAddSub(/*Add*/ true, /*NSW*/ false, Known, Known2);
+ Known = KnownBits::computeForAddSub(/*Add*/ true, /*NSW*/ false,
+ /* NUW*/ false, Known, Known2);
break;
}
case TargetOpcode::G_AND: {
@@ -564,7 +564,7 @@ void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
// right.
KnownBits ExtKnown = KnownBits::makeConstant(APInt(BitWidth, BitWidth));
KnownBits ShiftKnown = KnownBits::computeForAddSub(
- /*Add*/ false, /*NSW*/ false, ExtKnown, WidthKnown);
+ /*Add*/ false, /*NSW*/ false, /* NUW*/ false, ExtKnown, WidthKnown);
Known = KnownBits::ashr(KnownBits::shl(Known, ShiftKnown), ShiftKnown);
break;
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index e150f27240d7f0..dbcdb722b741a7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3753,8 +3753,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
SDNodeFlags Flags = Op.getNode()->getFlags();
Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
- Known = KnownBits::computeForAddSub(Op.getOpcode() == ISD::ADD,
- Flags.hasNoSignedWrap(), Known, Known2);
+ Known = KnownBits::computeForAddSub(
+ Op.getOpcode() == ISD::ADD, Flags.hasNoSignedWrap(),
+ Flags.hasNoUnsignedWrap(), Known, Known2);
break;
}
case ISD::USUBO:
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 6970b230837fb9..a639cba5e35a80 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -2876,9 +2876,9 @@ bool TargetLowering::SimplifyDemandedBits(
if (Op.getOpcode() == ISD::MUL) {
Known = KnownBits::mul(KnownOp0, KnownOp1);
} else { // Op.getOpcode() is either ISD::ADD or ISD::SUB.
- Known = KnownBits::computeForAddSub(Op.getOpcode() == ISD::ADD,
- Flags.hasNoSignedWrap(), KnownOp0,
- KnownOp1);
+ Known = KnownBits::computeForAddSub(
+ Op.getOpcode() == ISD::ADD, Flags.hasNoSignedWrap(),
+ Flags.hasNoUnsignedWrap(), KnownOp0, KnownOp1);
}
break;
}
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 770e4051ca3ffa..b575f97094891f 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -54,7 +54,7 @@ KnownBits KnownBits::computeForAddCarry(
LHS, RHS, Carry.Zero.getBoolValue(), Carry.One.getBoolValue());
}
-KnownBits KnownBits::computeForAddSub(bool Add, bool NSW,
+KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, bool NUW,
const KnownBits &LHS, KnownBits RHS) {
KnownBits KnownOut;
if (Add) {
@@ -63,23 +63,173 @@ KnownBits KnownBits::computeForAddSub(bool Add, bool NSW,
LHS, RHS, /*CarryZero*/true, /*CarryOne*/false);
} else {
// Sum = LHS + ~RHS + 1
- std::swap(RHS.Zero, RHS.One);
- KnownOut = ::computeForAddCarry(
- LHS, RHS, /*CarryZero*/false, /*CarryOne*/true);
+ KnownBits NotRHS = RHS;
+ std::swap(NotRHS.Zero, NotRHS.One);
+ KnownOut = ::computeForAddCarry(LHS, NotRHS, /*CarryZero*/ false,
+ /*CarryOne*/ true);
+ }
+ if (!NSW && !NUW)
+ return KnownOut;
+
+ // We truncate out the signbit during nsw handling so just handle this special
+ // case to avoid dealing with it later.
+ if (LHS.getBitWidth() == 1) {
+ return LHS | RHS;
}
- // Are we still trying to solve for the sign bit?
- if (!KnownOut.isNegative() && !KnownOut.isNonNegative()) {
+ auto GetMinMaxVal = [Add](bool ForNSW, bool ForMax, const KnownBits &L,
+ const KnownBits &R, bool &OV) {
+ APInt LVal = ForMax ? L.getMaxValue() : L.getMinValue();
+ APInt RVal = Add == ForMax ? R.getMaxValue() : R.getMinValue();
+
+ if (ForNSW) {
+ LVal = LVal.trunc(LVal.getBitWidth() - 1);
+ RVal = RVal.trunc(RVal.getBitWidth() - 1);
+ }
+ APInt Res = Add ? LVal.uadd_ov(RVal, OV) : LVal.usub_ov(RVal, OV);
+ if (ForNSW)
+ Res = Res.sext(Res.getBitWidth() + 1);
+ return Res;
+ };
+
+ auto GetMaxVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
+ const KnownBits &R, bool &OV) {
+ return GetMinMaxVal(ForNSW, /*ForMax*/ true, L, R, OV);
+ };
+
+ auto GetMinVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
+ const KnownBits &R, bool &OV) {
+ return GetMinMaxVal(ForNSW, /*ForMax*/ false, L, R, OV);
+ };
+
+ std::optional<bool> Negative;
+ bool Poison = false;
+ // Handle add/sub given nsw and/or nuw.
+ //
+ // Possible TODO: Add/Sub implementations mirror one another in many ways.
+ // They could probably be compressed into a single implementation of roughly
+ // half the total LOC. Leaving seperate for now to increase clarity.
+ // NB: We handle NSW by truncating sign bits then deducing bits based on
+ // the known sign result.
+ if (Add) {
if (NSW) {
- // Adding two non-negative numbers, or subtracting a negative number from
- // a non-negative one, can't wrap into negative.
- if (LHS.isNonNegative() && RHS.isNonNegative())
- KnownOut.makeNonNegative();
- // Adding two negative numbers, or subtracting a non-negative number from
- // a negative one, can't wrap into non-negative.
- else if (LHS.isNegative() && RHS.isNegative())
- KnownOut.makeNegative();
+ bool OverflowMax, OverflowMin;
+ APInt MaxVal = GetMaxVal(/*ForNSW*/ true, LHS, RHS, OverflowMax);
+ APInt MinVal = GetMinVal(/*ForNSW*/ true, LHS, RHS, OverflowMin);
+
+ if (NUW || (LHS.isNonNegative() && RHS.isNonNegative())) {
+ // (add nuw) or (add nsw PosX, PosY)
+
+ // None of the adds can end up overflowing, so min consecutive highbits
+ // in minimum possible of X + Y must all remain set.
+ KnownOut.One.setHighBits(MinVal.countLeadingOnes());
+
+ // NSW and Positive arguments leads to positive result.
+ if (LHS.isNonNegative() && RHS.isNonNegative())
+ Negative = false;
+ else
+ KnownOut.One.clearSignBit();
+
+ Poison = OverflowMin;
+ } else if (LHS.isNegative() && RHS.isNegative()) {
+ // (add nsw NegX, NegY)
+
+ // We need to re-overflow the signbit, so we are looking for sequence of
+ // 0s from consecutive overflows.
+ KnownOut.Zero.setHighBits(MaxVal.countLeadingZeros());
+ Negative = true;
+ Poison = !OverflowMax;
+ } else if (LHS.isNonNegative() || RHS.isNonNegative()) {
+ // (add nsw PosX, ?Y)
+
+ // If the minimal possible of X + Y overflows the signbit, then Y must
+ // have been signed (which will cause unsigned overflow otherwise nsw
+ // will be violated) leading to unsigned result.
+ if (OverflowMin)
+ Negative = false;
+ } else if (LHS.isNegative() || RHS.isNegative()) {
+ // (add nsw NegX, ?Y)
+
+ // If the maximum possible of X + Y doesn't overflows the signbit, then
+ // Y must have been unsigned (otherwise nsw violated) so NegX + PosY w.o
+ // overflowing the signbit results in Negative.
+ if (!OverflowMax)
+ Negative = true;
+ }
}
+ if (NUW) {
+ // (add nuw X, Y)
+ bool OverflowMax, OverflowMin;
+ APInt MaxVal = GetMaxVal(/*ForNSW*/ false, LHS, RHS, OverflowMax);
+ APInt MinVal = GetMinVal(/*ForNSW*/ false, LHS, RHS, OverflowMin);
+ // Same as (add nsw PosX, PosY), basically since we can't overflow, the
+ // high bits of minimum possible X + Y must remain set.
+ KnownOut.One.setHighBits(MinVal.countLeadingOnes());
+ Poison = OverflowMin;
+ }
+ } else {
+ if (NSW) {
+ bool OverflowMax, OverflowMin;
+ APInt MaxVal = GetMaxVal(/*ForNSW*/ true, LHS, RHS, OverflowMax);
+ APInt MinVal = GetMinVal(/*ForNSW*/ true, LHS, RHS, OverflowMin);
+ if (NUW || (LHS.isNegative() && RHS.isNonNegative())) {
+ // (sub nuw) or (sub nsw NegX, PosY)
+
+ // None of the subs can overflow at any point, so any common high bits
+ // will subtract away and result in zeros.
+ KnownOut.Zero.setHighBits(MaxVal.countLeadingZeros());
+ if (LHS.isNegative() && RHS.isNonNegative())
+ Negative = true;
+ else
+ KnownOut.Zero.clearSignBit();
+
+ Poison = OverflowMax;
+ } else if (LHS.isNonNegative() && RHS.isNegative()) {
+ // (sub nsw PosX, NegY)
+ Negative = false;
+
+ // Opposite case of above, we must "re-overflow" the signbit, so minimal
+ // set of high bits will be fixed.
+ KnownOut.One.setHighBits(MinVal.countLeadingOnes());
+ Poison = !OverflowMin;
+ } else if (LHS.isNegative() || RHS.isNonNegative()) {
+ // (sub nsw NegX/?X, ?Y/PosY)
+ if (OverflowMax)
+ Negative = true;
+ } else if (LHS.isNonNegative() || RHS.isNegative()) {
+ // (sub nsw PosX/?X, ?Y/NegY)
+ if (!OverflowMin)
+ Negative = false;
+ }
+ }
+ if (NUW) {
+ // (sub nuw X, Y)
+ bool OverflowMax, OverflowMin;
+ APInt MaxVal = GetMaxVal(/*ForNSW*/ false, LHS, RHS, OverflowMax);
+ APInt MinVal = GetMinVal(/*ForNSW*/ false, LHS, RHS, OverflowMin);
+
+ // Basically all common high bits between X/Y will cancel out as leading
+ // zeros.
+ KnownOut.Zero.setHighBits(MaxVal.countLeadingZeros());
+ Poison = OverflowMax;
+ }
+ }
+
+ // Handle any proven sign bit.
+ if (Negative.has_value()) {
+ KnownOut.One.clearSignBit();
+ KnownOut.Zero.clearSignBit();
+
+ if (*Negative)
+ KnownOut.makeNegative();
+ else
+ KnownOut.makeNonNegative();
+ }
+
+ // Just return 0 if the nsw/nuw is violated and we have poison.
+ if (Poison || KnownOut.hasConflict()) {
+ KnownOut.setAllZero();
+ return KnownOut;
}
return KnownOut;
@@ -443,7 +593,7 @@ KnownBits KnownBits::abs(bool IntMinIsPoison) const {
Tmp.One.setBit(countMinTrailingZeros());
KnownAbs = computeForAddSub(
- /*Add*/ false, IntMinIsPoison,
+ /*Add*/ false, IntMinIsPoison, /*NUW*/ false,
KnownBits::makeConstant(APInt(getBitWidth(), 0)), Tmp);
// One more special case for IntMinIsPoison. If we don't know any ones other
@@ -489,7 +639,8 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
assert(!LHS.hasConflict() && !RHS.hasConflict() && "Bad inputs");
// We don't see NSW even for sadd/ssub as we want to check if the result has
// signed overflow.
- KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW*/ false, LHS, RHS);
+ KnownBits Res =
+ KnownBits::computeForAddSub(Add, /*NSW*/ false, /*NUW*/ false, LHS, RHS);
unsigned BitWidth = Res.getBitWidth();
auto SignBitKnown = [&](const KnownBits &K) {
return K.Zero[BitWidth - 1] || K.One[BitWidth - 1];
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
index 4896ae8bad9ef3..1e7cd2bab04123 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
@@ -1903,7 +1903,8 @@ bool AMDGPUDAGToDAGISel::checkFlatScratchSVSSwizzleBug(
// voffset to (soffset + inst_offset).
KnownBits VKnown = CurDAG->computeKnownBits(VAddr);
KnownBits SKnown = KnownBits::computeForAddSub(
- true, false, CurDAG->computeKnownBits(SAddr),
+ /*Add*/ true, /*NSW*/ false, /*NUW*/ false,
+ CurDAG->computeKnownBits(SAddr),
KnownBits::makeConstant(APInt(32, ImmOffset)));
uint64_t VMax = VKnown.getMaxValue().getZExtValue();
uint64_t SMax = SKnown.getMaxValue().getZExtValue();
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUIn...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
EXPECT_TRUE(KnownNSWComputed.Zero.isSubsetOf(KnownNSW.Zero)); | ||
EXPECT_TRUE(KnownNSWComputed.One.isSubsetOf(KnownNSW.One)); | ||
IsAdd, /*NSW*/ true, /*NUW*/ false, Known1, Known2); | ||
if (!KnownNSW.hasConflict()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do you need to add conflict tests? And why not keep the isSubsetOf checks?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do you need to add conflict tests?
Because there are some inputs that will always violate `nsw`/`nuw` and yield poison, i.e.:
Value of: isOptimal(KnownNSW, KnownNSWComputed, {Known1, Known2})
Actual: false (Inputs = 1???00, 100001, Computed = 000000, Exact = !!!!!!)
Expected: true
We could also return a conflict in the poison cases, but because we `assert(!Known.hasConflict())` in a lot of places, that could potentially lead to crashes.
And why not keep the isSubsetOf checks?
`isOptimal`/`isCorrect` are helpers for the correctness check and have the added benefit of printing the input/output on failure to make debugging easier.
52e6be7
to
837ca86
Compare
; CHECK-NEXT: fcmeq p1.s, p0/z, z0.s, z1.s | ||
; CHECK-NEXT: ptest p0, p1.b | ||
; CHECK-NEXT: cset w0, lo | ||
; CHECK-NEXT: mov x8, #-1 // =0xffffffffffffffff |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't speak AArch64 but this looks like an unfortunate regression, at least in terms of number of instructions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed, ping @sdesmalen-arm, just to make you guys aware.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the following IR:
define i64 @foo_last() {
%vscale = call i64 @llvm.vscale.i64()
%shl2 = shl nuw nsw i64 %vscale, 2
%idx = add nuw nsw i64 %shl2, -1
ret i64 %idx
}
declare i64 @llvm.vscale.i64()
With the current patch this is incorrectly simplified to:
t9: ch,glue = CopyToReg t0, Register:i64 $x0, Constant:i64<-1>
I don't think there's anything AArch64-specific going on there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That simplification looks correct to me. The only way `add -1` can satisfy the nuw flag is if the other operand is zero, so the result is -1. This probably wasn't supposed to have a nuw flag?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doh :) Yes, I see it now. It should have been `sub nuw nsw i64 %shl2, 1` to avoid any wrapping (or indeed remove the nuw flag). I'll have a look to see if we generate this pattern with nuw flag anywhere. Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`vscale` can never be zero though (it should have a minimum value of 1). Should it be returning `poison` in that case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM overall. I have ideas for simplifying it but that can wait.
This seems to add a good bit of compile-time overhead: http://llvm-compile-time-tracker.com/compare.php?from=2a67c28abe8cfde47c5058abbeb4b5ff9a393192&to=e383a7e50bf0b4303931ef119c856dcf53f3dd9e&stat=instructions%3Au |
Wow, is the implementation just that inefficient, or do you think it's changing control flow decisions in some key places? |
Pretty sure this is either the implementation being slow, or the change at https://github.com/llvm/llvm-project/pull/83382/files#diff-4cc32f30c79b4c8161eac82916c70c1d56e75b0dd7d6e56bbb76f8b16e20b32dR361 (which might result in more KnownBits calculations now -- dunno whether that case is common or not). |
Okay, I'll investigate and see if I can make the impact a bit more reasonable. |
837ca86
to
235f1f4
Compare
@nikic think compile time issue has been resolved: |
Value *Y, bool NSW, bool NUW) { | ||
if (NUW) | ||
return isKnownNonZero(Y, DemandedElts, Depth, Q) || | ||
isKnownNonZero(X, DemandedElts, Depth, Q); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Separate refactor patch?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is pretty much NFC. The call sites were all checking this manually beforehand. Since we want NUW in `isNonZeroAdd`, it seemed to just make sense. But no strong feelings.
04e24f0
to
99422c7
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
Here's an idea for simplifying the code. I don't know if it's measurably faster. I don't really have time to finish it off at the moment.
jayfoad@c854484
Thanks, going to push this; I'll look into your impl later. |
Your code is actually so much better, updating and reposting... |
99422c7
to
3da7dac
Compare
It is measurably faster :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
unsigned NumBits = MinVal.trunc(BitWidth - 1).countl_one(); | ||
// If we have NSW as well, we also no we can't overflow the signbit so | ||
// can start counting from 1 bit back. | ||
KnownOut.One.setBits(BitWidth - 1 - NumBits, BitWidth - 1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was the part of my refactoring that I was least happy about: for the nsw+nuw case we run this line AND line 94 AND all the nsw logic below, and they all seem to be necessary. I'm sure there's a simpler way of handling the nsw+nuw case but I couldn't find it.
Nice! How does that compare to main? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. But please file an issue to track the SVE regression.
llvm/lib/Support/KnownBits.cpp
Outdated
// in minimum possible of X + Y must all remain set. | ||
if (NSW) { | ||
unsigned NumBits = MinVal.trunc(BitWidth - 1).countl_one(); | ||
// If we have NSW as well, we also no we can't overflow the signbit so |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// If we have NSW as well, we also no we can't overflow the signbit so | |
// If we have NSW as well, we also know we can't overflow the signbit so |
llvm/lib/Support/KnownBits.cpp
Outdated
// None of the subs can overflow at any point, so any common high bits | ||
// will subtract away and result in zeros. | ||
if (NSW) { | ||
// If we have NSW as well, we also no we can't overflow the signbit so |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// If we have NSW as well, we also no we can't overflow the signbit so | |
// If we have NSW as well, we also know we can't overflow the signbit so |
llvm/lib/Support/KnownBits.cpp
Outdated
// can start counting from 1 bit back. | ||
KnownOut.One.setBits(BitWidth - 1 - NumBits, BitWidth - 1); | ||
} | ||
KnownOut.One.setHighBits(MinVal.countLeadingOnes()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's an odd mix of countl_one and countLeadingOnes directly next to each other here.
EXPECT_TRUE(KnownNSWComputed.Zero.isSubsetOf(KnownNSW.Zero)); | ||
EXPECT_TRUE(KnownNSWComputed.One.isSubsetOf(KnownNSW.One)); | ||
IsAdd, /*NSW=*/true, /*NUW=*/false, Known1, Known2); | ||
if (!KnownNSW.hasConflict()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this still happen? It looks like computeForAddSub now resets on conflict
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, KnownNSW.hasConflict() will be true in the case where all admissible LHS/RHS values violated the nsw constraint. You're right that in this case, ComputedNSW will have been reset.
I bet it has to do with the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No more comments, cheers - LGTM
Just some improvements that should hopefully strengthen analysis. Closes llvm#83580
3da7dac
to
d5a3bd6
Compare
Done, See: #84046 |
Closed with: 17162b6 (messed up closed tag). |
if (NUW) { | ||
if (Add) { | ||
// (add nuw X, Y) | ||
APInt MinVal = LHS.getMinValue().uadd_sat(RHS.getMinValue()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From testing, it seems that this can be a normal (non-saturating) `+` and the `usub_sat` below can be `-`. But in the NSW code, the `sadd_sat`/`ssub_sat` are required. I don't understand why.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's because this is minval; does replacing the `sat` in the NSW code for minval only work?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, but it's signed min. You can replace `MaxVal = LHS.getSignedMaxValue() + RHS.getSignedMaxValue();` below.
`nuw` flag in `computeForAddSub`; NFC
`nuw` and `nsw` support in `computeForAddSub` optimal