-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[KnownBits] Make nuw and nsw support in computeForAddSub optimal #83382
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -350,18 +350,19 @@ unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL, | |
} | ||
|
||
static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1, | ||
bool NSW, const APInt &DemandedElts, | ||
bool NSW, bool NUW, | ||
const APInt &DemandedElts, | ||
KnownBits &KnownOut, KnownBits &Known2, | ||
unsigned Depth, const SimplifyQuery &Q) { | ||
computeKnownBits(Op1, DemandedElts, KnownOut, Depth + 1, Q); | ||
|
||
// If one operand is unknown and we have no nowrap information, | ||
// the result will be unknown independently of the second operand. | ||
if (KnownOut.isUnknown() && !NSW) | ||
if (KnownOut.isUnknown() && !NSW && !NUW) | ||
return; | ||
|
||
computeKnownBits(Op0, DemandedElts, Known2, Depth + 1, Q); | ||
KnownOut = KnownBits::computeForAddSub(Add, NSW, Known2, KnownOut); | ||
KnownOut = KnownBits::computeForAddSub(Add, NSW, NUW, Known2, KnownOut); | ||
} | ||
|
||
static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW, | ||
|
@@ -1145,13 +1146,15 @@ static void computeKnownBitsFromOperator(const Operator *I, | |
} | ||
case Instruction::Sub: { | ||
bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I)); | ||
computeKnownBitsAddSub(false, I->getOperand(0), I->getOperand(1), NSW, | ||
bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I)); | ||
computeKnownBitsAddSub(false, I->getOperand(0), I->getOperand(1), NSW, NUW, | ||
DemandedElts, Known, Known2, Depth, Q); | ||
break; | ||
} | ||
case Instruction::Add: { | ||
bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I)); | ||
computeKnownBitsAddSub(true, I->getOperand(0), I->getOperand(1), NSW, | ||
bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I)); | ||
computeKnownBitsAddSub(true, I->getOperand(0), I->getOperand(1), NSW, NUW, | ||
DemandedElts, Known, Known2, Depth, Q); | ||
break; | ||
} | ||
|
@@ -1245,12 +1248,12 @@ static void computeKnownBitsFromOperator(const Operator *I, | |
// Note that inbounds does *not* guarantee nsw for the addition, as only | ||
// the offset is signed, while the base address is unsigned. | ||
Known = KnownBits::computeForAddSub( | ||
/*Add=*/true, /*NSW=*/false, Known, IndexBits); | ||
/*Add=*/true, /*NSW=*/false, /* NUW=*/false, Known, IndexBits); | ||
} | ||
if (!Known.isUnknown() && !AccConstIndices.isZero()) { | ||
KnownBits Index = KnownBits::makeConstant(AccConstIndices); | ||
Known = KnownBits::computeForAddSub( | ||
/*Add=*/true, /*NSW=*/false, Known, Index); | ||
/*Add=*/true, /*NSW=*/false, /* NUW=*/false, Known, Index); | ||
} | ||
break; | ||
} | ||
|
@@ -1689,15 +1692,15 @@ static void computeKnownBitsFromOperator(const Operator *I, | |
default: break; | ||
case Intrinsic::uadd_with_overflow: | ||
case Intrinsic::sadd_with_overflow: | ||
computeKnownBitsAddSub(true, II->getArgOperand(0), | ||
II->getArgOperand(1), false, DemandedElts, | ||
Known, Known2, Depth, Q); | ||
computeKnownBitsAddSub( | ||
true, II->getArgOperand(0), II->getArgOperand(1), /*NSW=*/false, | ||
/* NUW=*/false, DemandedElts, Known, Known2, Depth, Q); | ||
break; | ||
case Intrinsic::usub_with_overflow: | ||
case Intrinsic::ssub_with_overflow: | ||
computeKnownBitsAddSub(false, II->getArgOperand(0), | ||
II->getArgOperand(1), false, DemandedElts, | ||
Known, Known2, Depth, Q); | ||
computeKnownBitsAddSub( | ||
false, II->getArgOperand(0), II->getArgOperand(1), /*NSW=*/false, | ||
/* NUW=*/false, DemandedElts, Known, Known2, Depth, Q); | ||
break; | ||
case Intrinsic::umul_with_overflow: | ||
case Intrinsic::smul_with_overflow: | ||
|
@@ -2318,7 +2321,11 @@ static bool isNonZeroRecurrence(const PHINode *PN) { | |
|
||
static bool isNonZeroAdd(const APInt &DemandedElts, unsigned Depth, | ||
const SimplifyQuery &Q, unsigned BitWidth, Value *X, | ||
Value *Y, bool NSW) { | ||
Value *Y, bool NSW, bool NUW) { | ||
if (NUW) | ||
return isKnownNonZero(Y, DemandedElts, Depth, Q) || | ||
isKnownNonZero(X, DemandedElts, Depth, Q); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Separate refactor patch? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is pretty NFC. The callsites where all checking this manually before handle. Since we want NUW in
dtcxzyw marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
KnownBits XKnown = computeKnownBits(X, DemandedElts, Depth, Q); | ||
KnownBits YKnown = computeKnownBits(Y, DemandedElts, Depth, Q); | ||
|
||
|
@@ -2351,7 +2358,7 @@ static bool isNonZeroAdd(const APInt &DemandedElts, unsigned Depth, | |
isKnownToBeAPowerOfTwo(X, /*OrZero*/ false, Depth, Q)) | ||
return true; | ||
|
||
return KnownBits::computeForAddSub(/*Add*/ true, NSW, XKnown, YKnown) | ||
return KnownBits::computeForAddSub(/*Add=*/true, NSW, NUW, XKnown, YKnown) | ||
.isNonZero(); | ||
} | ||
|
||
|
@@ -2556,12 +2563,9 @@ static bool isKnownNonZeroFromOperator(const Operator *I, | |
// If Add has nuw wrap flag, then if either X or Y is non-zero the result is | ||
// non-zero. | ||
auto *BO = cast<OverflowingBinaryOperator>(I); | ||
if (Q.IIQ.hasNoUnsignedWrap(BO)) | ||
return isKnownNonZero(I->getOperand(1), DemandedElts, Depth, Q) || | ||
isKnownNonZero(I->getOperand(0), DemandedElts, Depth, Q); | ||
|
||
return isNonZeroAdd(DemandedElts, Depth, Q, BitWidth, I->getOperand(0), | ||
I->getOperand(1), Q.IIQ.hasNoSignedWrap(BO)); | ||
I->getOperand(1), Q.IIQ.hasNoSignedWrap(BO), | ||
Q.IIQ.hasNoUnsignedWrap(BO)); | ||
} | ||
case Instruction::Mul: { | ||
// If X and Y are non-zero then so is X * Y as long as the multiplication | ||
|
@@ -2716,7 +2720,7 @@ static bool isKnownNonZeroFromOperator(const Operator *I, | |
case Intrinsic::sadd_sat: | ||
return isNonZeroAdd(DemandedElts, Depth, Q, BitWidth, | ||
II->getArgOperand(0), II->getArgOperand(1), | ||
/*NSW*/ true); | ||
/*NSW=*/true, /* NUW=*/false); | ||
case Intrinsic::umax: | ||
case Intrinsic::uadd_sat: | ||
return isKnownNonZero(II->getArgOperand(1), DemandedElts, Depth, Q) || | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -54,34 +54,89 @@ KnownBits KnownBits::computeForAddCarry( | |
LHS, RHS, Carry.Zero.getBoolValue(), Carry.One.getBoolValue()); | ||
} | ||
|
||
KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, | ||
const KnownBits &LHS, KnownBits RHS) { | ||
KnownBits KnownOut; | ||
if (Add) { | ||
// Sum = LHS + RHS + 0 | ||
KnownOut = ::computeForAddCarry( | ||
LHS, RHS, /*CarryZero*/true, /*CarryOne*/false); | ||
} else { | ||
// Sum = LHS + ~RHS + 1 | ||
std::swap(RHS.Zero, RHS.One); | ||
KnownOut = ::computeForAddCarry( | ||
LHS, RHS, /*CarryZero*/false, /*CarryOne*/true); | ||
KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, bool NUW, | ||
const KnownBits &LHS, | ||
const KnownBits &RHS) { | ||
unsigned BitWidth = LHS.getBitWidth(); | ||
KnownBits KnownOut(BitWidth); | ||
// This can be a relatively expensive helper, so optimistically save some | ||
// work. | ||
if (LHS.isUnknown() && RHS.isUnknown()) | ||
return KnownOut; | ||
|
||
if (!LHS.isUnknown() && !RHS.isUnknown()) { | ||
if (Add) { | ||
// Sum = LHS + RHS + 0 | ||
KnownOut = ::computeForAddCarry(LHS, RHS, /*CarryZero=*/true, | ||
/*CarryOne=*/false); | ||
} else { | ||
// Sum = LHS + ~RHS + 1 | ||
KnownBits NotRHS = RHS; | ||
std::swap(NotRHS.Zero, NotRHS.One); | ||
KnownOut = ::computeForAddCarry(LHS, NotRHS, /*CarryZero=*/false, | ||
/*CarryOne=*/true); | ||
} | ||
} | ||
|
||
// Are we still trying to solve for the sign bit? | ||
if (!KnownOut.isNegative() && !KnownOut.isNonNegative()) { | ||
if (NSW) { | ||
// Adding two non-negative numbers, or subtracting a negative number from | ||
// a non-negative one, can't wrap into negative. | ||
if (LHS.isNonNegative() && RHS.isNonNegative()) | ||
KnownOut.makeNonNegative(); | ||
// Adding two negative numbers, or subtracting a non-negative number from | ||
// a negative one, can't wrap into non-negative. | ||
else if (LHS.isNegative() && RHS.isNegative()) | ||
KnownOut.makeNegative(); | ||
// Handle add/sub given nsw and/or nuw. | ||
if (NUW) { | ||
if (Add) { | ||
// (add nuw X, Y) | ||
APInt MinVal = LHS.getMinValue().uadd_sat(RHS.getMinValue()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From testing, it seems that this can be a normal (non-saturating) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Think its b.c this is minval, does replacing the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh but its signed min. You can replace |
||
// None of the adds can end up overflowing, so min consecutive highbits | ||
// in minimum possible of X + Y must all remain set. | ||
if (NSW) { | ||
unsigned NumBits = MinVal.trunc(BitWidth - 1).countl_one(); | ||
// If we have NSW as well, we also know we can't overflow the signbit so | ||
// can start counting from 1 bit back. | ||
KnownOut.One.setBits(BitWidth - 1 - NumBits, BitWidth - 1); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was the part of my refactoring that I was least happy about: for the nsw+nuw case we run this line AND line 94 AND all the nsw logic below, and they all seem to be necessary. I'm sure there's a simpler way of handling the nsw+nuw case but I couldn't find it. |
||
} | ||
KnownOut.One.setHighBits(MinVal.countl_one()); | ||
} else { | ||
// (sub nuw X, Y) | ||
APInt MaxVal = LHS.getMaxValue().usub_sat(RHS.getMinValue()); | ||
// None of the subs can overflow at any point, so any common high bits | ||
// will subtract away and result in zeros. | ||
if (NSW) { | ||
// If we have NSW as well, we also know we can't overflow the signbit so | ||
// can start counting from 1 bit back. | ||
unsigned NumBits = MaxVal.trunc(BitWidth - 1).countl_zero(); | ||
KnownOut.Zero.setBits(BitWidth - 1 - NumBits, BitWidth - 1); | ||
} | ||
KnownOut.Zero.setHighBits(MaxVal.countl_zero()); | ||
} | ||
} | ||
|
||
if (NSW) { | ||
APInt MinVal; | ||
APInt MaxVal; | ||
if (Add) { | ||
// (add nsw X, Y) | ||
MinVal = LHS.getSignedMinValue().sadd_sat(RHS.getSignedMinValue()); | ||
MaxVal = LHS.getSignedMaxValue().sadd_sat(RHS.getSignedMaxValue()); | ||
} else { | ||
// (sub nsw X, Y) | ||
MinVal = LHS.getSignedMinValue().ssub_sat(RHS.getSignedMaxValue()); | ||
MaxVal = LHS.getSignedMaxValue().ssub_sat(RHS.getSignedMinValue()); | ||
} | ||
if (MinVal.isNonNegative()) { | ||
// If min is non-negative, result will always be non-neg (can't overflow | ||
// around). | ||
unsigned NumBits = MinVal.trunc(BitWidth - 1).countl_one(); | ||
KnownOut.One.setBits(BitWidth - 1 - NumBits, BitWidth - 1); | ||
KnownOut.Zero.setSignBit(); | ||
} | ||
if (MaxVal.isNegative()) { | ||
// If max is negative, result will always be neg (can't overflow around). | ||
unsigned NumBits = MaxVal.trunc(BitWidth - 1).countl_zero(); | ||
KnownOut.Zero.setBits(BitWidth - 1 - NumBits, BitWidth - 1); | ||
KnownOut.One.setSignBit(); | ||
} | ||
} | ||
|
||
// Just return 0 if the nsw/nuw is violated and we have poison. | ||
if (KnownOut.hasConflict()) | ||
KnownOut.setAllZero(); | ||
return KnownOut; | ||
} | ||
|
||
|
@@ -180,11 +235,14 @@ KnownBits KnownBits::absdiff(const KnownBits &LHS, const KnownBits &RHS) { | |
// absdiff(LHS,RHS) = sub(umax(LHS,RHS), umin(LHS,RHS)). | ||
KnownBits UMaxValue = umax(LHS, RHS); | ||
KnownBits UMinValue = umin(LHS, RHS); | ||
KnownBits MinMaxDiff = computeForAddSub(false, false, UMaxValue, UMinValue); | ||
KnownBits MinMaxDiff = computeForAddSub(/*Add=*/false, /*NSW=*/false, | ||
/*NUW=*/true, UMaxValue, UMinValue); | ||
|
||
// find the common bits between sub(LHS,RHS) and sub(RHS,LHS). | ||
KnownBits Diff0 = computeForAddSub(false, false, LHS, RHS); | ||
KnownBits Diff1 = computeForAddSub(false, false, RHS, LHS); | ||
KnownBits Diff0 = | ||
computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS, RHS); | ||
KnownBits Diff1 = | ||
computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, RHS, LHS); | ||
KnownBits SubDiff = Diff0.intersectWith(Diff1); | ||
|
||
KnownBits KnownAbsDiff = MinMaxDiff.unionWith(SubDiff); | ||
|
@@ -459,7 +517,7 @@ KnownBits KnownBits::abs(bool IntMinIsPoison) const { | |
Tmp.One.setBit(countMinTrailingZeros()); | ||
|
||
KnownAbs = computeForAddSub( | ||
/*Add*/ false, IntMinIsPoison, | ||
/*Add*/ false, IntMinIsPoison, /*NUW=*/false, | ||
KnownBits::makeConstant(APInt(getBitWidth(), 0)), Tmp); | ||
|
||
// One more special case for IntMinIsPoison. If we don't know any ones other | ||
|
@@ -505,7 +563,8 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed, | |
assert(!LHS.hasConflict() && !RHS.hasConflict() && "Bad inputs"); | ||
// We don't see NSW even for sadd/ssub as we want to check if the result has | ||
// signed overflow. | ||
KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW*/ false, LHS, RHS); | ||
KnownBits Res = | ||
KnownBits::computeForAddSub(Add, /*NSW=*/false, /*NUW=*/false, LHS, RHS); | ||
unsigned BitWidth = Res.getBitWidth(); | ||
auto SignBitKnown = [&](const KnownBits &K) { | ||
return K.Zero[BitWidth - 1] || K.One[BitWidth - 1]; | ||
|
Uh oh!
There was an error while loading. Please reload this page.