Skip to content

Commit c854484

Browse files
committed
Simplify using saturated add/sub
1 parent 99422c7 commit c854484

File tree

1 file changed

+50
-162
lines changed

1 file changed

+50
-162
lines changed

llvm/lib/Support/KnownBits.cpp

Lines changed: 50 additions & 162 deletions
Original file line numberDiff line numberDiff line change
@@ -57,181 +57,69 @@ KnownBits KnownBits::computeForAddCarry(
5757
KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, bool NUW,
5858
const KnownBits &LHS,
5959
const KnownBits &RHS) {
60-
// This can be a relatively expensive helper, so optimistically save some
61-
// work.
62-
if (LHS.isUnknown() && RHS.isUnknown())
63-
return LHS;
64-
KnownBits KnownOut;
65-
if (Add) {
66-
// Sum = LHS + RHS + 0
67-
KnownOut =
68-
::computeForAddCarry(LHS, RHS, /*CarryZero*/ true, /*CarryOne*/ false);
69-
} else {
70-
// Sum = LHS + ~RHS + 1
71-
KnownBits NotRHS = RHS;
72-
std::swap(NotRHS.Zero, NotRHS.One);
73-
KnownOut = ::computeForAddCarry(LHS, NotRHS, /*CarryZero*/ false,
74-
/*CarryOne*/ true);
75-
}
76-
if (!NSW && !NUW)
77-
return KnownOut;
78-
79-
auto GetMinMaxVal = [Add](bool ForNSW, bool ForMax, const KnownBits &L,
80-
const KnownBits &R, bool &OV) {
81-
APInt LVal = ForMax ? L.getMaxValue() : L.getMinValue();
82-
APInt RVal = Add == ForMax ? R.getMaxValue() : R.getMinValue();
83-
84-
if (ForNSW) {
85-
LVal.clearSignBit();
86-
RVal.clearSignBit();
87-
}
88-
APInt Res = Add ? LVal.uadd_ov(RVal, OV) : LVal.usub_ov(RVal, OV);
89-
if (ForNSW) {
90-
OV = Res.isSignBitSet();
91-
Res.clearSignBit();
92-
if (Res.getBitWidth() > 1 && Res[Res.getBitWidth() - 2])
93-
Res.setSignBit();
60+
unsigned BitWidth = LHS.getBitWidth();
61+
KnownBits KnownOut(LHS.getBitWidth());
62+
if (!LHS.isUnknown() && !RHS.isUnknown()) {
63+
if (Add) {
64+
// Sum = LHS + RHS + 0
65+
KnownOut = ::computeForAddCarry(LHS, RHS, /*CarryZero=*/true,
66+
/*CarryOne=*/false);
67+
} else {
68+
// Sum = LHS + ~RHS + 1
69+
KnownBits NotRHS = RHS;
70+
std::swap(NotRHS.Zero, NotRHS.One);
71+
KnownOut = ::computeForAddCarry(LHS, NotRHS, /*CarryZero=*/false,
72+
/*CarryOne=*/true);
9473
}
95-
return Res;
96-
};
97-
98-
auto GetMaxVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
99-
const KnownBits &R, bool &OV) {
100-
return GetMinMaxVal(ForNSW, /*ForMax=*/true, L, R, OV);
101-
};
102-
103-
auto GetMinVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
104-
const KnownBits &R, bool &OV) {
105-
return GetMinMaxVal(ForNSW, /*ForMax=*/false, L, R, OV);
106-
};
107-
108-
auto ForceNegative = [](KnownBits &Known) {
109-
Known.Zero.clearSignBit();
110-
Known.One.setSignBit();
111-
};
112-
113-
auto ForcePositive = [](KnownBits &Known) {
114-
Known.One.clearSignBit();
115-
Known.Zero.setSignBit();
116-
};
74+
}
11775

118-
// Handle add/sub given nsw and/or nuw.
119-
//
120-
// Possible TODO: Add/Sub implementations mirror one another in many ways.
121-
// They could probably be compressed into a single implementation of roughly
122-
// half the total LOC. Leaving seperate for now to increase clarity.
123-
// NB: We handle NSW by essentially treating as nuw of bitwidth - 1 then
124-
// deducing bits based on the known sign result.
125-
if (Add) {
126-
if (NUW || (LHS.isNonNegative() && RHS.isNonNegative())) {
127-
bool OverflowMin;
128-
APInt MinVal;
76+
if (NUW) {
77+
if (Add) {
78+
// (add nuw X, Y)
79+
APInt MinVal = LHS.getMinValue().uadd_sat(RHS.getMinValue());
12980
if (NSW) {
130-
MinVal = GetMinVal(/*ForNSW=*/true, LHS, RHS, OverflowMin);
131-
// (add nsw nuw) or (add nsw PosX, PosY)
132-
133-
// None of the adds can end up overflowing, so min consecutive
134-
// highbits in minimum possible of X + Y must all remain set.
135-
KnownOut.One.setHighBits(MinVal.countLeadingOnes());
136-
137-
// NSW and Positive arguments leads to positive result.
138-
if (LHS.isNonNegative() && RHS.isNonNegative())
139-
ForcePositive(KnownOut);
81+
unsigned NumBits = MinVal.trunc(BitWidth - 1).countl_one();
82+
KnownOut.One.setBits(BitWidth - 1 - NumBits, BitWidth - 1);
14083
}
141-
if (NUW) {
142-
KnownOut.One.clearSignBit();
143-
// (add nuw X, Y)
144-
MinVal = GetMinVal(/*ForNSW=*/false, LHS, RHS, OverflowMin);
145-
// Same as (add nsw PosX, PosY), basically since we can't overflow,
146-
// the high bits of minimum possible X + Y must remain set.
147-
KnownOut.One.setHighBits(MinVal.countLeadingOnes());
84+
KnownOut.One.setHighBits(MinVal.countLeadingOnes());
85+
} else {
86+
// (sub nuw X, Y)
87+
APInt MaxVal = LHS.getMaxValue().usub_sat(RHS.getMinValue());
88+
if (NSW) {
89+
unsigned NumBits = MaxVal.trunc(BitWidth - 1).countl_zero();
90+
KnownOut.Zero.setBits(BitWidth - 1 - NumBits, BitWidth - 1);
14891
}
149-
} else if (LHS.isNegative() && RHS.isNegative()) {
150-
bool OverflowMax;
151-
APInt MaxVal = GetMaxVal(/*ForNSW=*/true, LHS, RHS, OverflowMax);
152-
// (add nsw NegX, NegY)
153-
154-
// We need to re-overflow the signbit, so we are looking for sequence
155-
// of 0s from consecutive overflows.
15692
KnownOut.Zero.setHighBits(MaxVal.countLeadingZeros());
157-
ForceNegative(KnownOut);
158-
} else if (!KnownOut.isSignUnknown()) {
159-
// Pass, avoid extra work if we already know the sign bit.
160-
} else if (LHS.isNonNegative() || RHS.isNonNegative()) {
161-
bool OverflowMin;
162-
(void)GetMinVal(/*ForNSW=*/true, LHS, RHS, OverflowMin);
163-
// (add nsw PosX, ?Y)
164-
165-
// If the minimal possible of X + Y overflows the signbit, then Y must
166-
// have been signed (which will cause unsigned overflow otherwise nsw
167-
// will be violated) leading to unsigned result.
168-
if (OverflowMin)
169-
KnownOut.makeNonNegative();
170-
} else if (LHS.isNegative() || RHS.isNegative()) {
171-
bool OverflowMax;
172-
(void)GetMaxVal(/*ForNSW=*/true, LHS, RHS, OverflowMax);
173-
// (add nsw NegX, ?Y)
174-
175-
// If the maximum possible of X + Y doesn't overflows the signbit,
176-
// then Y must have been unsigned (otherwise nsw violated) so NegX +
177-
// PosY w.o overflowing the signbit results in Negative.
178-
if (!OverflowMax)
179-
KnownOut.makeNegative();
18093
}
181-
} else {
182-
if (NUW || (LHS.isNegative() && RHS.isNonNegative())) {
183-
bool OverflowMax;
184-
APInt MaxVal;
185-
if (NSW) {
186-
MaxVal = GetMaxVal(/*ForNSW=*/true, LHS, RHS, OverflowMax);
187-
// (sub nsw nuw) or (sub nsw NegX, PosY)
188-
189-
// None of the subs can overflow at any point, so any common high bits
190-
// will subtract away and result in zeros.
191-
KnownOut.Zero.setHighBits(MaxVal.countLeadingZeros());
192-
if (LHS.isNegative() && RHS.isNonNegative())
193-
ForceNegative(KnownOut);
194-
}
195-
if (NUW) {
196-
KnownOut.Zero.clearSignBit();
197-
// (sub nuw X, Y)
198-
MaxVal = GetMaxVal(/*ForNSW=*/false, LHS, RHS, OverflowMax);
199-
200-
// Basically all common high bits between X/Y will cancel out as
201-
// leading zeros.
202-
KnownOut.Zero.setHighBits(MaxVal.countLeadingZeros());
203-
}
204-
} else if (LHS.isNonNegative() && RHS.isNegative()) {
205-
bool OverflowMin;
206-
APInt MinVal = GetMinVal(/*ForNSW=*/true, LHS, RHS, OverflowMin);
207-
// (sub nsw PosX, NegY)
94+
}
20895

209-
// Opposite case of above, we must "re-overflow" the signbit, so
210-
// minimal set of high bits will be fixed.
211-
KnownOut.One.setHighBits(MinVal.countLeadingOnes());
212-
ForcePositive(KnownOut);
213-
} else if (!KnownOut.isSignUnknown()) {
214-
// Pass, avoid extra work if we already know the sign bit.
215-
} else if (LHS.isNegative() || RHS.isNonNegative()) {
216-
bool OverflowMax;
217-
(void)GetMaxVal(/*ForNSW=*/true, LHS, RHS, OverflowMax);
218-
// (sub nsw NegX/?X, ?Y/PosY)
219-
if (OverflowMax)
220-
KnownOut.makeNegative();
221-
} else if (LHS.isNonNegative() || RHS.isNegative()) {
222-
bool OverflowMin;
223-
(void)GetMinVal(/*ForNSW=*/true, LHS, RHS, OverflowMin);
224-
// (sub nsw PosX/?X, ?Y/NegY)
225-
if (!OverflowMin)
226-
KnownOut.makeNonNegative();
96+
if (NSW) {
97+
APInt MinVal;
98+
APInt MaxVal;
99+
if (Add) {
100+
// (add nsw X, Y)
101+
MinVal = LHS.getSignedMinValue().sadd_sat(RHS.getSignedMinValue());
102+
MaxVal = LHS.getSignedMaxValue().sadd_sat(RHS.getSignedMaxValue());
103+
} else {
104+
// (sub nsw X, Y)
105+
MinVal = LHS.getSignedMinValue().ssub_sat(RHS.getSignedMaxValue());
106+
MaxVal = LHS.getSignedMaxValue().ssub_sat(RHS.getSignedMinValue());
107+
}
108+
if (MinVal.isNonNegative()) {
109+
unsigned NumBits = MinVal.trunc(BitWidth - 1).countl_one();
110+
KnownOut.One.setBits(BitWidth - 1 - NumBits, BitWidth - 1);
111+
KnownOut.Zero.setSignBit();
112+
}
113+
if (MaxVal.isNegative()) {
114+
unsigned NumBits = MaxVal.trunc(BitWidth - 1).countl_zero();
115+
KnownOut.Zero.setBits(BitWidth - 1 - NumBits, BitWidth - 1);
116+
KnownOut.One.setSignBit();
227117
}
228118
}
229119

230120
// Just return 0 if the nsw/nuw is violated and we have poison.
231-
if (KnownOut.hasConflict()) {
121+
if (KnownOut.hasConflict())
232122
KnownOut.setAllZero();
233-
return KnownOut;
234-
}
235123

236124
return KnownOut;
237125
}

0 commit comments

Comments
 (0)