Skip to content

[KnownBits] Make nuw and nsw support in computeForAddSub optimal #83382

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions llvm/include/llvm/Support/KnownBits.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ struct KnownBits {
/// Returns true if no bits are known, i.e. neither the Zero mask nor the One
/// mask has any bit set.
bool isUnknown() const {
  if (!Zero.isZero())
    return false;
  return One.isZero();
}

/// Returns true if the sign bit is set in neither the Zero mask nor the One
/// mask, i.e. the sign bit is not known.
bool isSignUnknown() const {
  return !(Zero.isSignBitSet() || One.isSignBitSet());
}

/// Resets the known state of all bits.
void resetAll() {
Zero.clearAllBits();
Expand Down Expand Up @@ -329,8 +334,8 @@ struct KnownBits {
const KnownBits &LHS, const KnownBits &RHS, const KnownBits &Carry);

/// Compute known bits resulting from adding LHS and RHS (when Add is true)
/// or from subtracting RHS from LHS (when Add is false).
static KnownBits computeForAddSub(bool Add, bool NSW, const KnownBits &LHS,
KnownBits RHS);
static KnownBits computeForAddSub(bool Add, bool NSW, bool NUW,
const KnownBits &LHS, const KnownBits &RHS);

/// Compute known bits results from subtracting RHS from LHS with 1-bit
/// Borrow.
Expand Down
46 changes: 25 additions & 21 deletions llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,18 +350,19 @@ unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL,
}

static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
bool NSW, const APInt &DemandedElts,
bool NSW, bool NUW,
const APInt &DemandedElts,
KnownBits &KnownOut, KnownBits &Known2,
unsigned Depth, const SimplifyQuery &Q) {
computeKnownBits(Op1, DemandedElts, KnownOut, Depth + 1, Q);

// If one operand is unknown and we have no nowrap information,
// the result will be unknown independently of the second operand.
if (KnownOut.isUnknown() && !NSW)
if (KnownOut.isUnknown() && !NSW && !NUW)
return;

computeKnownBits(Op0, DemandedElts, Known2, Depth + 1, Q);
KnownOut = KnownBits::computeForAddSub(Add, NSW, Known2, KnownOut);
KnownOut = KnownBits::computeForAddSub(Add, NSW, NUW, Known2, KnownOut);
}

static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
Expand Down Expand Up @@ -1145,13 +1146,15 @@ static void computeKnownBitsFromOperator(const Operator *I,
}
case Instruction::Sub: {
bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
computeKnownBitsAddSub(false, I->getOperand(0), I->getOperand(1), NSW,
bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
computeKnownBitsAddSub(false, I->getOperand(0), I->getOperand(1), NSW, NUW,
DemandedElts, Known, Known2, Depth, Q);
break;
}
case Instruction::Add: {
bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
computeKnownBitsAddSub(true, I->getOperand(0), I->getOperand(1), NSW,
bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
computeKnownBitsAddSub(true, I->getOperand(0), I->getOperand(1), NSW, NUW,
DemandedElts, Known, Known2, Depth, Q);
break;
}
Expand Down Expand Up @@ -1245,12 +1248,12 @@ static void computeKnownBitsFromOperator(const Operator *I,
// Note that inbounds does *not* guarantee nsw for the addition, as only
// the offset is signed, while the base address is unsigned.
Known = KnownBits::computeForAddSub(
/*Add=*/true, /*NSW=*/false, Known, IndexBits);
/*Add=*/true, /*NSW=*/false, /* NUW=*/false, Known, IndexBits);
}
if (!Known.isUnknown() && !AccConstIndices.isZero()) {
KnownBits Index = KnownBits::makeConstant(AccConstIndices);
Known = KnownBits::computeForAddSub(
/*Add=*/true, /*NSW=*/false, Known, Index);
/*Add=*/true, /*NSW=*/false, /* NUW=*/false, Known, Index);
}
break;
}
Expand Down Expand Up @@ -1689,15 +1692,15 @@ static void computeKnownBitsFromOperator(const Operator *I,
default: break;
case Intrinsic::uadd_with_overflow:
case Intrinsic::sadd_with_overflow:
computeKnownBitsAddSub(true, II->getArgOperand(0),
II->getArgOperand(1), false, DemandedElts,
Known, Known2, Depth, Q);
computeKnownBitsAddSub(
true, II->getArgOperand(0), II->getArgOperand(1), /*NSW=*/false,
/* NUW=*/false, DemandedElts, Known, Known2, Depth, Q);
break;
case Intrinsic::usub_with_overflow:
case Intrinsic::ssub_with_overflow:
computeKnownBitsAddSub(false, II->getArgOperand(0),
II->getArgOperand(1), false, DemandedElts,
Known, Known2, Depth, Q);
computeKnownBitsAddSub(
false, II->getArgOperand(0), II->getArgOperand(1), /*NSW=*/false,
/* NUW=*/false, DemandedElts, Known, Known2, Depth, Q);
break;
case Intrinsic::umul_with_overflow:
case Intrinsic::smul_with_overflow:
Expand Down Expand Up @@ -2318,7 +2321,11 @@ static bool isNonZeroRecurrence(const PHINode *PN) {

static bool isNonZeroAdd(const APInt &DemandedElts, unsigned Depth,
const SimplifyQuery &Q, unsigned BitWidth, Value *X,
Value *Y, bool NSW) {
Value *Y, bool NSW, bool NUW) {
if (NUW)
return isKnownNonZero(Y, DemandedElts, Depth, Q) ||
isKnownNonZero(X, DemandedElts, Depth, Q);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Separate refactor patch?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is pretty much NFC. The call sites were all checking this manually beforehand. Since we want NUW in isNonZeroAdd anyway, it seemed to just make sense to move the check there. But no strong feelings.


KnownBits XKnown = computeKnownBits(X, DemandedElts, Depth, Q);
KnownBits YKnown = computeKnownBits(Y, DemandedElts, Depth, Q);

Expand Down Expand Up @@ -2351,7 +2358,7 @@ static bool isNonZeroAdd(const APInt &DemandedElts, unsigned Depth,
isKnownToBeAPowerOfTwo(X, /*OrZero*/ false, Depth, Q))
return true;

return KnownBits::computeForAddSub(/*Add*/ true, NSW, XKnown, YKnown)
return KnownBits::computeForAddSub(/*Add=*/true, NSW, NUW, XKnown, YKnown)
.isNonZero();
}

Expand Down Expand Up @@ -2556,12 +2563,9 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
// If Add has nuw wrap flag, then if either X or Y is non-zero the result is
// non-zero.
auto *BO = cast<OverflowingBinaryOperator>(I);
if (Q.IIQ.hasNoUnsignedWrap(BO))
return isKnownNonZero(I->getOperand(1), DemandedElts, Depth, Q) ||
isKnownNonZero(I->getOperand(0), DemandedElts, Depth, Q);

return isNonZeroAdd(DemandedElts, Depth, Q, BitWidth, I->getOperand(0),
I->getOperand(1), Q.IIQ.hasNoSignedWrap(BO));
I->getOperand(1), Q.IIQ.hasNoSignedWrap(BO),
Q.IIQ.hasNoUnsignedWrap(BO));
}
case Instruction::Mul: {
// If X and Y are non-zero then so is X * Y as long as the multiplication
Expand Down Expand Up @@ -2716,7 +2720,7 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
case Intrinsic::sadd_sat:
return isNonZeroAdd(DemandedElts, Depth, Q, BitWidth,
II->getArgOperand(0), II->getArgOperand(1),
/*NSW*/ true);
/*NSW=*/true, /* NUW=*/false);
case Intrinsic::umax:
case Intrinsic::uadd_sat:
return isKnownNonZero(II->getArgOperand(1), DemandedElts, Depth, Q) ||
Expand Down
10 changes: 5 additions & 5 deletions llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,8 @@ void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
Depth + 1);
computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, DemandedElts,
Depth + 1);
Known = KnownBits::computeForAddSub(/*Add*/ false, /*NSW*/ false, Known,
Known2);
Known = KnownBits::computeForAddSub(/*Add=*/false, /*NSW=*/false,
/* NUW=*/false, Known, Known2);
break;
}
case TargetOpcode::G_XOR: {
Expand All @@ -296,8 +296,8 @@ void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
Depth + 1);
computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, DemandedElts,
Depth + 1);
Known =
KnownBits::computeForAddSub(/*Add*/ true, /*NSW*/ false, Known, Known2);
Known = KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/false,
/* NUW=*/false, Known, Known2);
break;
}
case TargetOpcode::G_AND: {
Expand Down Expand Up @@ -564,7 +564,7 @@ void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
// right.
KnownBits ExtKnown = KnownBits::makeConstant(APInt(BitWidth, BitWidth));
KnownBits ShiftKnown = KnownBits::computeForAddSub(
/*Add*/ false, /*NSW*/ false, ExtKnown, WidthKnown);
/*Add=*/false, /*NSW=*/false, /* NUW=*/false, ExtKnown, WidthKnown);
Known = KnownBits::ashr(KnownBits::shl(Known, ShiftKnown), ShiftKnown);
break;
}
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3763,8 +3763,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
SDNodeFlags Flags = Op.getNode()->getFlags();
Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
Known = KnownBits::computeForAddSub(Op.getOpcode() == ISD::ADD,
Flags.hasNoSignedWrap(), Known, Known2);
Known = KnownBits::computeForAddSub(
Op.getOpcode() == ISD::ADD, Flags.hasNoSignedWrap(),
Flags.hasNoUnsignedWrap(), Known, Known2);
break;
}
case ISD::USUBO:
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2876,9 +2876,9 @@ bool TargetLowering::SimplifyDemandedBits(
if (Op.getOpcode() == ISD::MUL) {
Known = KnownBits::mul(KnownOp0, KnownOp1);
} else { // Op.getOpcode() is either ISD::ADD or ISD::SUB.
Known = KnownBits::computeForAddSub(Op.getOpcode() == ISD::ADD,
Flags.hasNoSignedWrap(), KnownOp0,
KnownOp1);
Known = KnownBits::computeForAddSub(
Op.getOpcode() == ISD::ADD, Flags.hasNoSignedWrap(),
Flags.hasNoUnsignedWrap(), KnownOp0, KnownOp1);
}
break;
}
Expand Down
115 changes: 87 additions & 28 deletions llvm/lib/Support/KnownBits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,34 +54,89 @@ KnownBits KnownBits::computeForAddCarry(
LHS, RHS, Carry.Zero.getBoolValue(), Carry.One.getBoolValue());
}

KnownBits KnownBits::computeForAddSub(bool Add, bool NSW,
const KnownBits &LHS, KnownBits RHS) {
KnownBits KnownOut;
if (Add) {
// Sum = LHS + RHS + 0
KnownOut = ::computeForAddCarry(
LHS, RHS, /*CarryZero*/true, /*CarryOne*/false);
} else {
// Sum = LHS + ~RHS + 1
std::swap(RHS.Zero, RHS.One);
KnownOut = ::computeForAddCarry(
LHS, RHS, /*CarryZero*/false, /*CarryOne*/true);
KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, bool NUW,
const KnownBits &LHS,
const KnownBits &RHS) {
unsigned BitWidth = LHS.getBitWidth();
KnownBits KnownOut(BitWidth);
// This can be a relatively expensive helper, so optimistically save some
// work.
if (LHS.isUnknown() && RHS.isUnknown())
return KnownOut;

if (!LHS.isUnknown() && !RHS.isUnknown()) {
if (Add) {
// Sum = LHS + RHS + 0
KnownOut = ::computeForAddCarry(LHS, RHS, /*CarryZero=*/true,
/*CarryOne=*/false);
} else {
// Sum = LHS + ~RHS + 1
KnownBits NotRHS = RHS;
std::swap(NotRHS.Zero, NotRHS.One);
KnownOut = ::computeForAddCarry(LHS, NotRHS, /*CarryZero=*/false,
/*CarryOne=*/true);
}
}

// Are we still trying to solve for the sign bit?
if (!KnownOut.isNegative() && !KnownOut.isNonNegative()) {
if (NSW) {
// Adding two non-negative numbers, or subtracting a negative number from
// a non-negative one, can't wrap into negative.
if (LHS.isNonNegative() && RHS.isNonNegative())
KnownOut.makeNonNegative();
// Adding two negative numbers, or subtracting a non-negative number from
// a negative one, can't wrap into non-negative.
else if (LHS.isNegative() && RHS.isNegative())
KnownOut.makeNegative();
// Handle add/sub given nsw and/or nuw.
if (NUW) {
if (Add) {
// (add nuw X, Y)
APInt MinVal = LHS.getMinValue().uadd_sat(RHS.getMinValue());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From testing, it seems that this can be a normal (non-saturating) + and the usub_sat below can be -. But in the NSW code, the sadd_sat/ssub_sat are required. I don't understand why.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's because this is MinVal. Does replacing the saturating op in the NSW code work for MinVal only?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, but it's the signed min. You can replace `MaxVal = LHS.getSignedMaxValue() + RHS.getSignedMaxValue();` below.

// None of the adds can end up overflowing, so min consecutive highbits
// in minimum possible of X + Y must all remain set.
if (NSW) {
unsigned NumBits = MinVal.trunc(BitWidth - 1).countl_one();
// If we have NSW as well, we also know we can't overflow the signbit so
// can start counting from 1 bit back.
KnownOut.One.setBits(BitWidth - 1 - NumBits, BitWidth - 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the part of my refactoring that I was least happy about: for the nsw+nuw case we run this line AND line 94 AND all the nsw logic below, and they all seem to be necessary. I'm sure there's a simpler way of handling the nsw+nuw case but I couldn't find it.

}
KnownOut.One.setHighBits(MinVal.countl_one());
} else {
// (sub nuw X, Y)
APInt MaxVal = LHS.getMaxValue().usub_sat(RHS.getMinValue());
// None of the subs can overflow at any point, so any common high bits
// will subtract away and result in zeros.
if (NSW) {
// If we have NSW as well, we also know we can't overflow the signbit so
// can start counting from 1 bit back.
unsigned NumBits = MaxVal.trunc(BitWidth - 1).countl_zero();
KnownOut.Zero.setBits(BitWidth - 1 - NumBits, BitWidth - 1);
}
KnownOut.Zero.setHighBits(MaxVal.countl_zero());
}
}

if (NSW) {
APInt MinVal;
APInt MaxVal;
if (Add) {
// (add nsw X, Y)
MinVal = LHS.getSignedMinValue().sadd_sat(RHS.getSignedMinValue());
MaxVal = LHS.getSignedMaxValue().sadd_sat(RHS.getSignedMaxValue());
} else {
// (sub nsw X, Y)
MinVal = LHS.getSignedMinValue().ssub_sat(RHS.getSignedMaxValue());
MaxVal = LHS.getSignedMaxValue().ssub_sat(RHS.getSignedMinValue());
}
if (MinVal.isNonNegative()) {
// If min is non-negative, result will always be non-neg (can't overflow
// around).
unsigned NumBits = MinVal.trunc(BitWidth - 1).countl_one();
KnownOut.One.setBits(BitWidth - 1 - NumBits, BitWidth - 1);
KnownOut.Zero.setSignBit();
}
if (MaxVal.isNegative()) {
// If max is negative, result will always be neg (can't overflow around).
unsigned NumBits = MaxVal.trunc(BitWidth - 1).countl_zero();
KnownOut.Zero.setBits(BitWidth - 1 - NumBits, BitWidth - 1);
KnownOut.One.setSignBit();
}
}

// Just return 0 if the nsw/nuw is violated and we have poison.
if (KnownOut.hasConflict())
KnownOut.setAllZero();
return KnownOut;
}

Expand Down Expand Up @@ -180,11 +235,14 @@ KnownBits KnownBits::absdiff(const KnownBits &LHS, const KnownBits &RHS) {
// absdiff(LHS,RHS) = sub(umax(LHS,RHS), umin(LHS,RHS)).
KnownBits UMaxValue = umax(LHS, RHS);
KnownBits UMinValue = umin(LHS, RHS);
KnownBits MinMaxDiff = computeForAddSub(false, false, UMaxValue, UMinValue);
KnownBits MinMaxDiff = computeForAddSub(/*Add=*/false, /*NSW=*/false,
/*NUW=*/true, UMaxValue, UMinValue);

// find the common bits between sub(LHS,RHS) and sub(RHS,LHS).
KnownBits Diff0 = computeForAddSub(false, false, LHS, RHS);
KnownBits Diff1 = computeForAddSub(false, false, RHS, LHS);
KnownBits Diff0 =
computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS, RHS);
KnownBits Diff1 =
computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, RHS, LHS);
KnownBits SubDiff = Diff0.intersectWith(Diff1);

KnownBits KnownAbsDiff = MinMaxDiff.unionWith(SubDiff);
Expand Down Expand Up @@ -459,7 +517,7 @@ KnownBits KnownBits::abs(bool IntMinIsPoison) const {
Tmp.One.setBit(countMinTrailingZeros());

KnownAbs = computeForAddSub(
/*Add*/ false, IntMinIsPoison,
/*Add*/ false, IntMinIsPoison, /*NUW=*/false,
KnownBits::makeConstant(APInt(getBitWidth(), 0)), Tmp);

// One more special case for IntMinIsPoison. If we don't know any ones other
Expand Down Expand Up @@ -505,7 +563,8 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
assert(!LHS.hasConflict() && !RHS.hasConflict() && "Bad inputs");
// We don't see NSW even for sadd/ssub as we want to check if the result has
// signed overflow.
KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW*/ false, LHS, RHS);
KnownBits Res =
KnownBits::computeForAddSub(Add, /*NSW=*/false, /*NUW=*/false, LHS, RHS);
unsigned BitWidth = Res.getBitWidth();
auto SignBitKnown = [&](const KnownBits &K) {
return K.Zero[BitWidth - 1] || K.One[BitWidth - 1];
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1903,7 +1903,8 @@ bool AMDGPUDAGToDAGISel::checkFlatScratchSVSSwizzleBug(
// voffset to (soffset + inst_offset).
KnownBits VKnown = CurDAG->computeKnownBits(VAddr);
KnownBits SKnown = KnownBits::computeForAddSub(
true, false, CurDAG->computeKnownBits(SAddr),
/*Add=*/true, /*NSW=*/false, /*NUW=*/false,
CurDAG->computeKnownBits(SAddr),
KnownBits::makeConstant(APInt(32, ImmOffset)));
uint64_t VMax = VKnown.getMaxValue().getZExtValue();
uint64_t SMax = SKnown.getMaxValue().getZExtValue();
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4573,7 +4573,7 @@ bool AMDGPUInstructionSelector::checkFlatScratchSVSSwizzleBug(
// voffset to (soffset + inst_offset).
auto VKnown = KB->getKnownBits(VAddr);
auto SKnown = KnownBits::computeForAddSub(
true, false, KB->getKnownBits(SAddr),
/*Add=*/true, /*NSW=*/false, /*NUW=*/false, KB->getKnownBits(SAddr),
KnownBits::makeConstant(APInt(32, ImmOffset)));
uint64_t VMax = VKnown.getMaxValue().getZExtValue();
uint64_t SMax = SKnown.getMaxValue().getZExtValue();
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/ARM/ARMISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20154,7 +20154,8 @@ void ARMTargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
// CSNEG: KnownOp0 or KnownOp1 * -1
if (Op.getOpcode() == ARMISD::CSINC)
KnownOp1 = KnownBits::computeForAddSub(
true, false, KnownOp1, KnownBits::makeConstant(APInt(32, 1)));
/*Add=*/true, /*NSW=*/false, /*NUW=*/false, KnownOp1,
KnownBits::makeConstant(APInt(32, 1)));
else if (Op.getOpcode() == ARMISD::CSINV)
std::swap(KnownOp1.Zero, KnownOp1.One);
else if (Op.getOpcode() == ARMISD::CSNEG)
Expand Down
Loading