Skip to content

Commit 99422c7

Browse files
committed
[KnownBits] Make nuw and nsw support in computeForAddSub optimal
Just some improvements that should hopefully strengthen analysis.
1 parent aa76e97 commit 99422c7

File tree

9 files changed

+267
-62
lines changed

9 files changed

+267
-62
lines changed

llvm/include/llvm/Support/KnownBits.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ struct KnownBits {
6262
/// Returns true if we don't know any bits.
6363
bool isUnknown() const { return Zero.isZero() && One.isZero(); }
6464

65+
/// Returns true if we don't know the sign bit.
66+
bool isSignUnknown() const {
67+
return !Zero.isSignBitSet() && !One.isSignBitSet();
68+
}
69+
6570
/// Resets the known state of all bits.
6671
void resetAll() {
6772
Zero.clearAllBits();
@@ -330,7 +335,7 @@ struct KnownBits {
330335

331336
/// Compute known bits resulting from adding LHS and RHS.
332337
static KnownBits computeForAddSub(bool Add, bool NSW, bool NUW,
333-
const KnownBits &LHS, KnownBits RHS);
338+
const KnownBits &LHS, const KnownBits &RHS);
334339

335340
/// Compute known bits results from subtracting RHS from LHS with 1-bit
336341
/// Borrow.

llvm/lib/Support/KnownBits.cpp

Lines changed: 167 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -54,32 +54,183 @@ KnownBits KnownBits::computeForAddCarry(
5454
LHS, RHS, Carry.Zero.getBoolValue(), Carry.One.getBoolValue());
5555
}
5656

57-
KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, bool /*NUW*/,
58-
const KnownBits &LHS, KnownBits RHS) {
57+
KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, bool NUW,
58+
const KnownBits &LHS,
59+
const KnownBits &RHS) {
60+
// This can be a relatively expensive helper, so optimistically save some
61+
// work.
62+
if (LHS.isUnknown() && RHS.isUnknown())
63+
return LHS;
5964
KnownBits KnownOut;
6065
if (Add) {
6166
// Sum = LHS + RHS + 0
62-
KnownOut = ::computeForAddCarry(
63-
LHS, RHS, /*CarryZero*/true, /*CarryOne*/false);
67+
KnownOut =
68+
::computeForAddCarry(LHS, RHS, /*CarryZero*/ true, /*CarryOne*/ false);
6469
} else {
6570
// Sum = LHS + ~RHS + 1
66-
std::swap(RHS.Zero, RHS.One);
67-
KnownOut = ::computeForAddCarry(
68-
LHS, RHS, /*CarryZero*/false, /*CarryOne*/true);
71+
KnownBits NotRHS = RHS;
72+
std::swap(NotRHS.Zero, NotRHS.One);
73+
KnownOut = ::computeForAddCarry(LHS, NotRHS, /*CarryZero*/ false,
74+
/*CarryOne*/ true);
6975
}
76+
if (!NSW && !NUW)
77+
return KnownOut;
7078

71-
// Are we still trying to solve for the sign bit?
72-
if (!KnownOut.isNegative() && !KnownOut.isNonNegative()) {
73-
if (NSW) {
74-
// Adding two non-negative numbers, or subtracting a negative number from
75-
// a non-negative one, can't wrap into negative.
76-
if (LHS.isNonNegative() && RHS.isNonNegative())
79+
auto GetMinMaxVal = [Add](bool ForNSW, bool ForMax, const KnownBits &L,
80+
const KnownBits &R, bool &OV) {
81+
APInt LVal = ForMax ? L.getMaxValue() : L.getMinValue();
82+
APInt RVal = Add == ForMax ? R.getMaxValue() : R.getMinValue();
83+
84+
if (ForNSW) {
85+
LVal.clearSignBit();
86+
RVal.clearSignBit();
87+
}
88+
APInt Res = Add ? LVal.uadd_ov(RVal, OV) : LVal.usub_ov(RVal, OV);
89+
if (ForNSW) {
90+
OV = Res.isSignBitSet();
91+
Res.clearSignBit();
92+
if (Res.getBitWidth() > 1 && Res[Res.getBitWidth() - 2])
93+
Res.setSignBit();
94+
}
95+
return Res;
96+
};
97+
98+
auto GetMaxVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
99+
const KnownBits &R, bool &OV) {
100+
return GetMinMaxVal(ForNSW, /*ForMax=*/true, L, R, OV);
101+
};
102+
103+
auto GetMinVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
104+
const KnownBits &R, bool &OV) {
105+
return GetMinMaxVal(ForNSW, /*ForMax=*/false, L, R, OV);
106+
};
107+
108+
auto ForceNegative = [](KnownBits &Known) {
109+
Known.Zero.clearSignBit();
110+
Known.One.setSignBit();
111+
};
112+
113+
auto ForcePositive = [](KnownBits &Known) {
114+
Known.One.clearSignBit();
115+
Known.Zero.setSignBit();
116+
};
117+
118+
// Handle add/sub given nsw and/or nuw.
119+
//
120+
// Possible TODO: Add/Sub implementations mirror one another in many ways.
121+
// They could probably be compressed into a single implementation of roughly
122+
// half the total LOC. Leaving seperate for now to increase clarity.
123+
// NB: We handle NSW by essentially treating as nuw of bitwidth - 1 then
124+
// deducing bits based on the known sign result.
125+
if (Add) {
126+
if (NUW || (LHS.isNonNegative() && RHS.isNonNegative())) {
127+
bool OverflowMin;
128+
APInt MinVal;
129+
if (NSW) {
130+
MinVal = GetMinVal(/*ForNSW=*/true, LHS, RHS, OverflowMin);
131+
// (add nsw nuw) or (add nsw PosX, PosY)
132+
133+
// None of the adds can end up overflowing, so min consecutive
134+
// highbits in minimum possible of X + Y must all remain set.
135+
KnownOut.One.setHighBits(MinVal.countLeadingOnes());
136+
137+
// NSW and Positive arguments leads to positive result.
138+
if (LHS.isNonNegative() && RHS.isNonNegative())
139+
ForcePositive(KnownOut);
140+
}
141+
if (NUW) {
142+
KnownOut.One.clearSignBit();
143+
// (add nuw X, Y)
144+
MinVal = GetMinVal(/*ForNSW=*/false, LHS, RHS, OverflowMin);
145+
// Same as (add nsw PosX, PosY), basically since we can't overflow,
146+
// the high bits of minimum possible X + Y must remain set.
147+
KnownOut.One.setHighBits(MinVal.countLeadingOnes());
148+
}
149+
} else if (LHS.isNegative() && RHS.isNegative()) {
150+
bool OverflowMax;
151+
APInt MaxVal = GetMaxVal(/*ForNSW=*/true, LHS, RHS, OverflowMax);
152+
// (add nsw NegX, NegY)
153+
154+
// We need to re-overflow the signbit, so we are looking for sequence
155+
// of 0s from consecutive overflows.
156+
KnownOut.Zero.setHighBits(MaxVal.countLeadingZeros());
157+
ForceNegative(KnownOut);
158+
} else if (!KnownOut.isSignUnknown()) {
159+
// Pass, avoid extra work if we already know the sign bit.
160+
} else if (LHS.isNonNegative() || RHS.isNonNegative()) {
161+
bool OverflowMin;
162+
(void)GetMinVal(/*ForNSW=*/true, LHS, RHS, OverflowMin);
163+
// (add nsw PosX, ?Y)
164+
165+
// If the minimal possible of X + Y overflows the signbit, then Y must
166+
// have been signed (which will cause unsigned overflow otherwise nsw
167+
// will be violated) leading to unsigned result.
168+
if (OverflowMin)
77169
KnownOut.makeNonNegative();
78-
// Adding two negative numbers, or subtracting a non-negative number from
79-
// a negative one, can't wrap into non-negative.
80-
else if (LHS.isNegative() && RHS.isNegative())
170+
} else if (LHS.isNegative() || RHS.isNegative()) {
171+
bool OverflowMax;
172+
(void)GetMaxVal(/*ForNSW=*/true, LHS, RHS, OverflowMax);
173+
// (add nsw NegX, ?Y)
174+
175+
// If the maximum possible of X + Y doesn't overflows the signbit,
176+
// then Y must have been unsigned (otherwise nsw violated) so NegX +
177+
// PosY w.o overflowing the signbit results in Negative.
178+
if (!OverflowMax)
81179
KnownOut.makeNegative();
82180
}
181+
} else {
182+
if (NUW || (LHS.isNegative() && RHS.isNonNegative())) {
183+
bool OverflowMax;
184+
APInt MaxVal;
185+
if (NSW) {
186+
MaxVal = GetMaxVal(/*ForNSW=*/true, LHS, RHS, OverflowMax);
187+
// (sub nsw nuw) or (sub nsw NegX, PosY)
188+
189+
// None of the subs can overflow at any point, so any common high bits
190+
// will subtract away and result in zeros.
191+
KnownOut.Zero.setHighBits(MaxVal.countLeadingZeros());
192+
if (LHS.isNegative() && RHS.isNonNegative())
193+
ForceNegative(KnownOut);
194+
}
195+
if (NUW) {
196+
KnownOut.Zero.clearSignBit();
197+
// (sub nuw X, Y)
198+
MaxVal = GetMaxVal(/*ForNSW=*/false, LHS, RHS, OverflowMax);
199+
200+
// Basically all common high bits between X/Y will cancel out as
201+
// leading zeros.
202+
KnownOut.Zero.setHighBits(MaxVal.countLeadingZeros());
203+
}
204+
} else if (LHS.isNonNegative() && RHS.isNegative()) {
205+
bool OverflowMin;
206+
APInt MinVal = GetMinVal(/*ForNSW=*/true, LHS, RHS, OverflowMin);
207+
// (sub nsw PosX, NegY)
208+
209+
// Opposite case of above, we must "re-overflow" the signbit, so
210+
// minimal set of high bits will be fixed.
211+
KnownOut.One.setHighBits(MinVal.countLeadingOnes());
212+
ForcePositive(KnownOut);
213+
} else if (!KnownOut.isSignUnknown()) {
214+
// Pass, avoid extra work if we already know the sign bit.
215+
} else if (LHS.isNegative() || RHS.isNonNegative()) {
216+
bool OverflowMax;
217+
(void)GetMaxVal(/*ForNSW=*/true, LHS, RHS, OverflowMax);
218+
// (sub nsw NegX/?X, ?Y/PosY)
219+
if (OverflowMax)
220+
KnownOut.makeNegative();
221+
} else if (LHS.isNonNegative() || RHS.isNegative()) {
222+
bool OverflowMin;
223+
(void)GetMinVal(/*ForNSW=*/true, LHS, RHS, OverflowMin);
224+
// (sub nsw PosX/?X, ?Y/NegY)
225+
if (!OverflowMin)
226+
KnownOut.makeNonNegative();
227+
}
228+
}
229+
230+
// Just return 0 if the nsw/nuw is violated and we have poison.
231+
if (KnownOut.hasConflict()) {
232+
KnownOut.setAllZero();
233+
return KnownOut;
83234
}
84235

85236
return KnownOut;

llvm/test/CodeGen/AArch64/sve-cmp-folds.ll

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,12 @@ define i1 @foo_last(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
114114
; CHECK-LABEL: foo_last:
115115
; CHECK: // %bb.0:
116116
; CHECK-NEXT: ptrue p0.s
117-
; CHECK-NEXT: fcmeq p1.s, p0/z, z0.s, z1.s
118-
; CHECK-NEXT: ptest p0, p1.b
119-
; CHECK-NEXT: cset w0, lo
117+
; CHECK-NEXT: mov x8, #-1 // =0xffffffffffffffff
118+
; CHECK-NEXT: whilels p1.s, xzr, x8
119+
; CHECK-NEXT: fcmeq p0.s, p0/z, z0.s, z1.s
120+
; CHECK-NEXT: mov z0.s, p0/z, #1 // =0x1
121+
; CHECK-NEXT: lastb w8, p1, z0.s
122+
; CHECK-NEXT: and w0, w8, #0x1
120123
; CHECK-NEXT: ret
121124
%vcond = fcmp oeq <vscale x 4 x float> %a, %b
122125
%vscale = call i64 @llvm.vscale.i64()

llvm/test/CodeGen/AArch64/sve-extract-element.ll

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -614,9 +614,11 @@ define i1 @test_lane9_8xi1(<vscale x 8 x i1> %a) #0 {
614614
define i1 @test_last_8xi1(<vscale x 8 x i1> %a) #0 {
615615
; CHECK-LABEL: test_last_8xi1:
616616
; CHECK: // %bb.0:
617-
; CHECK-NEXT: ptrue p1.h
618-
; CHECK-NEXT: ptest p1, p0.b
619-
; CHECK-NEXT: cset w0, lo
617+
; CHECK-NEXT: mov x8, #-1 // =0xffffffffffffffff
618+
; CHECK-NEXT: mov z0.h, p0/z, #1 // =0x1
619+
; CHECK-NEXT: whilels p1.h, xzr, x8
620+
; CHECK-NEXT: lastb w8, p1, z0.h
621+
; CHECK-NEXT: and w0, w8, #0x1
620622
; CHECK-NEXT: ret
621623
%vscale = call i64 @llvm.vscale.i64()
622624
%shl = shl nuw nsw i64 %vscale, 3

llvm/test/CodeGen/AMDGPU/ds-sub-offset.ll

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -137,49 +137,46 @@ define amdgpu_kernel void @write_ds_sub_max_offset_global_clamp_bit(float %dummy
137137
; CI: ; %bb.0:
138138
; CI-NEXT: s_load_dword s0, s[0:1], 0x0
139139
; CI-NEXT: s_mov_b64 vcc, 0
140-
; CI-NEXT: v_not_b32_e32 v0, v0
141-
; CI-NEXT: v_lshlrev_b32_e32 v0, 2, v0
142-
; CI-NEXT: v_mov_b32_e32 v2, 0x7b
140+
; CI-NEXT: v_mov_b32_e32 v1, 0x7b
141+
; CI-NEXT: v_mov_b32_e32 v2, 0
142+
; CI-NEXT: s_mov_b32 m0, -1
143143
; CI-NEXT: s_waitcnt lgkmcnt(0)
144-
; CI-NEXT: v_mov_b32_e32 v1, s0
145-
; CI-NEXT: v_div_fmas_f32 v1, v1, v1, v1
144+
; CI-NEXT: v_mov_b32_e32 v0, s0
145+
; CI-NEXT: v_div_fmas_f32 v0, v0, v0, v0
146146
; CI-NEXT: s_mov_b32 s0, 0
147-
; CI-NEXT: s_mov_b32 m0, -1
148147
; CI-NEXT: s_mov_b32 s3, 0xf000
149148
; CI-NEXT: s_mov_b32 s2, -1
150149
; CI-NEXT: s_mov_b32 s1, s0
151-
; CI-NEXT: ds_write_b32 v0, v2 offset:65532
152-
; CI-NEXT: buffer_store_dword v1, off, s[0:3], 0
150+
; CI-NEXT: ds_write_b32 v2, v1
151+
; CI-NEXT: buffer_store_dword v0, off, s[0:3], 0
153152
; CI-NEXT: s_waitcnt vmcnt(0)
154153
; CI-NEXT: s_endpgm
155154
;
156155
; GFX9-LABEL: write_ds_sub_max_offset_global_clamp_bit:
157156
; GFX9: ; %bb.0:
158157
; GFX9-NEXT: s_load_dword s0, s[0:1], 0x0
159158
; GFX9-NEXT: s_mov_b64 vcc, 0
160-
; GFX9-NEXT: v_not_b32_e32 v0, v0
161-
; GFX9-NEXT: v_lshlrev_b32_e32 v3, 2, v0
162-
; GFX9-NEXT: v_mov_b32_e32 v4, 0x7b
159+
; GFX9-NEXT: v_mov_b32_e32 v3, 0x7b
160+
; GFX9-NEXT: v_mov_b32_e32 v4, 0
161+
; GFX9-NEXT: ds_write_b32 v4, v3
163162
; GFX9-NEXT: s_waitcnt lgkmcnt(0)
164-
; GFX9-NEXT: v_mov_b32_e32 v1, s0
165-
; GFX9-NEXT: v_div_fmas_f32 v2, v1, v1, v1
163+
; GFX9-NEXT: v_mov_b32_e32 v0, s0
164+
; GFX9-NEXT: v_div_fmas_f32 v2, v0, v0, v0
166165
; GFX9-NEXT: v_mov_b32_e32 v0, 0
167166
; GFX9-NEXT: v_mov_b32_e32 v1, 0
168-
; GFX9-NEXT: ds_write_b32 v3, v4 offset:65532
169167
; GFX9-NEXT: global_store_dword v[0:1], v2, off
170168
; GFX9-NEXT: s_waitcnt vmcnt(0)
171169
; GFX9-NEXT: s_endpgm
172170
;
173171
; GFX10-LABEL: write_ds_sub_max_offset_global_clamp_bit:
174172
; GFX10: ; %bb.0:
175173
; GFX10-NEXT: s_load_dword s0, s[0:1], 0x0
176-
; GFX10-NEXT: v_not_b32_e32 v0, v0
177174
; GFX10-NEXT: s_mov_b32 vcc_lo, 0
178-
; GFX10-NEXT: v_mov_b32_e32 v3, 0x7b
179-
; GFX10-NEXT: v_lshlrev_b32_e32 v2, 2, v0
180175
; GFX10-NEXT: v_mov_b32_e32 v0, 0
176+
; GFX10-NEXT: v_mov_b32_e32 v2, 0x7b
177+
; GFX10-NEXT: v_mov_b32_e32 v3, 0
181178
; GFX10-NEXT: v_mov_b32_e32 v1, 0
182-
; GFX10-NEXT: ds_write_b32 v2, v3 offset:65532
179+
; GFX10-NEXT: ds_write_b32 v3, v2
183180
; GFX10-NEXT: s_waitcnt lgkmcnt(0)
184181
; GFX10-NEXT: v_div_fmas_f32 v4, s0, s0, s0
185182
; GFX10-NEXT: global_store_dword v[0:1], v4, off
@@ -189,13 +186,11 @@ define amdgpu_kernel void @write_ds_sub_max_offset_global_clamp_bit(float %dummy
189186
; GFX11-LABEL: write_ds_sub_max_offset_global_clamp_bit:
190187
; GFX11: ; %bb.0:
191188
; GFX11-NEXT: s_load_b32 s0, s[0:1], 0x0
192-
; GFX11-NEXT: v_not_b32_e32 v0, v0
193189
; GFX11-NEXT: s_mov_b32 vcc_lo, 0
194-
; GFX11-NEXT: s_delay_alu instid0(VALU_DEP_1)
195-
; GFX11-NEXT: v_dual_mov_b32 v3, 0x7b :: v_dual_lshlrev_b32 v2, 2, v0
196190
; GFX11-NEXT: v_mov_b32_e32 v0, 0
191+
; GFX11-NEXT: v_dual_mov_b32 v2, 0x7b :: v_dual_mov_b32 v3, 0
197192
; GFX11-NEXT: v_mov_b32_e32 v1, 0
198-
; GFX11-NEXT: ds_store_b32 v2, v3 offset:65532
193+
; GFX11-NEXT: ds_store_b32 v3, v2
199194
; GFX11-NEXT: s_waitcnt lgkmcnt(0)
200195
; GFX11-NEXT: v_div_fmas_f32 v4, s0, s0, s0
201196
; GFX11-NEXT: global_store_b32 v[0:1], v4, off dlc

llvm/test/Transforms/InstCombine/fold-log2-ceil-idiom.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ define i64 @log2_ceil_idiom_zext(i32 %x) {
4343
; CHECK-NEXT: [[TMP1:%.*]] = add i32 [[X]], -1
4444
; CHECK-NEXT: [[TMP2:%.*]] = call i32 @llvm.ctlz.i32(i32 [[TMP1]], i1 false), !range [[RNG0]]
4545
; CHECK-NEXT: [[TMP3:%.*]] = sub nuw nsw i32 32, [[TMP2]]
46-
; CHECK-NEXT: [[RET:%.*]] = zext i32 [[TMP3]] to i64
46+
; CHECK-NEXT: [[RET:%.*]] = zext nneg i32 [[TMP3]] to i64
4747
; CHECK-NEXT: ret i64 [[RET]]
4848
;
4949
%ctlz = tail call i32 @llvm.ctlz.i32(i32 %x, i1 true)

llvm/test/Transforms/InstCombine/icmp-sub.ll

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ define i1 @test_nuw_nsw_and_unsigned_pred(i64 %x) {
3636

3737
define i1 @test_nuw_nsw_and_signed_pred(i64 %x) {
3838
; CHECK-LABEL: @test_nuw_nsw_and_signed_pred(
39-
; CHECK-NEXT: [[Z:%.*]] = icmp sgt i64 [[X:%.*]], 7
39+
; CHECK-NEXT: [[Z:%.*]] = icmp ugt i64 [[X:%.*]], 7
4040
; CHECK-NEXT: ret i1 [[Z]]
4141
;
4242
%y = sub nuw nsw i64 10, %x
@@ -46,8 +46,7 @@ define i1 @test_nuw_nsw_and_signed_pred(i64 %x) {
4646

4747
define i1 @test_negative_nuw_and_signed_pred(i64 %x) {
4848
; CHECK-LABEL: @test_negative_nuw_and_signed_pred(
49-
; CHECK-NEXT: [[NOTSUB:%.*]] = add nuw i64 [[X:%.*]], -11
50-
; CHECK-NEXT: [[Z:%.*]] = icmp sgt i64 [[NOTSUB]], -4
49+
; CHECK-NEXT: [[Z:%.*]] = icmp ugt i64 [[X:%.*]], 7
5150
; CHECK-NEXT: ret i1 [[Z]]
5251
;
5352
%y = sub nuw i64 10, %x

llvm/test/Transforms/InstCombine/sub.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2367,7 +2367,7 @@ define <2 x i8> @sub_to_and_vector3(<2 x i8> %x) {
23672367
; CHECK-LABEL: @sub_to_and_vector3(
23682368
; CHECK-NEXT: [[SUB:%.*]] = sub nuw <2 x i8> <i8 71, i8 71>, [[X:%.*]]
23692369
; CHECK-NEXT: [[AND:%.*]] = and <2 x i8> [[SUB]], <i8 120, i8 undef>
2370-
; CHECK-NEXT: [[R:%.*]] = sub <2 x i8> <i8 44, i8 44>, [[AND]]
2370+
; CHECK-NEXT: [[R:%.*]] = sub nsw <2 x i8> <i8 44, i8 44>, [[AND]]
23712371
; CHECK-NEXT: ret <2 x i8> [[R]]
23722372
;
23732373
%sub = sub nuw <2 x i8> <i8 71, i8 71>, %x

0 commit comments

Comments
 (0)