@@ -54,32 +54,183 @@ KnownBits KnownBits::computeForAddCarry(
54
54
LHS, RHS, Carry.Zero .getBoolValue (), Carry.One .getBoolValue ());
55
55
}
56
56
57
- KnownBits KnownBits::computeForAddSub (bool Add, bool NSW, bool /* NUW*/ ,
58
- const KnownBits &LHS, KnownBits RHS) {
57
+ KnownBits KnownBits::computeForAddSub (bool Add, bool NSW, bool NUW,
58
+ const KnownBits &LHS,
59
+ const KnownBits &RHS) {
60
+ // This can be a relatively expensive helper, so optimistically save some
61
+ // work.
62
+ if (LHS.isUnknown () && RHS.isUnknown ())
63
+ return LHS;
59
64
KnownBits KnownOut;
60
65
if (Add) {
61
66
// Sum = LHS + RHS + 0
62
- KnownOut = :: computeForAddCarry (
63
- LHS, RHS, /* CarryZero*/ true , /* CarryOne*/ false );
67
+ KnownOut =
68
+ ::computeForAddCarry ( LHS, RHS, /* CarryZero*/ true , /* CarryOne*/ false );
64
69
} else {
65
70
// Sum = LHS + ~RHS + 1
66
- std::swap (RHS.Zero , RHS.One );
67
- KnownOut = ::computeForAddCarry (
68
- LHS, RHS, /* CarryZero*/ false , /* CarryOne*/ true );
71
+ KnownBits NotRHS = RHS;
72
+ std::swap (NotRHS.Zero , NotRHS.One );
73
+ KnownOut = ::computeForAddCarry (LHS, NotRHS, /* CarryZero*/ false ,
74
+ /* CarryOne*/ true );
69
75
}
76
+ if (!NSW && !NUW)
77
+ return KnownOut;
70
78
71
- // Are we still trying to solve for the sign bit?
72
- if (!KnownOut.isNegative () && !KnownOut.isNonNegative ()) {
73
- if (NSW) {
74
- // Adding two non-negative numbers, or subtracting a negative number from
75
- // a non-negative one, can't wrap into negative.
76
- if (LHS.isNonNegative () && RHS.isNonNegative ())
79
+ auto GetMinMaxVal = [Add](bool ForNSW, bool ForMax, const KnownBits &L,
80
+ const KnownBits &R, bool &OV) {
81
+ APInt LVal = ForMax ? L.getMaxValue () : L.getMinValue ();
82
+ APInt RVal = Add == ForMax ? R.getMaxValue () : R.getMinValue ();
83
+
84
+ if (ForNSW) {
85
+ LVal.clearSignBit ();
86
+ RVal.clearSignBit ();
87
+ }
88
+ APInt Res = Add ? LVal.uadd_ov (RVal, OV) : LVal.usub_ov (RVal, OV);
89
+ if (ForNSW) {
90
+ OV = Res.isSignBitSet ();
91
+ Res.clearSignBit ();
92
+ if (Res.getBitWidth () > 1 && Res[Res.getBitWidth () - 2 ])
93
+ Res.setSignBit ();
94
+ }
95
+ return Res;
96
+ };
97
+
98
+ auto GetMaxVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
99
+ const KnownBits &R, bool &OV) {
100
+ return GetMinMaxVal (ForNSW, /* ForMax=*/ true , L, R, OV);
101
+ };
102
+
103
+ auto GetMinVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
104
+ const KnownBits &R, bool &OV) {
105
+ return GetMinMaxVal (ForNSW, /* ForMax=*/ false , L, R, OV);
106
+ };
107
+
108
+ auto ForceNegative = [](KnownBits &Known) {
109
+ Known.Zero .clearSignBit ();
110
+ Known.One .setSignBit ();
111
+ };
112
+
113
+ auto ForcePositive = [](KnownBits &Known) {
114
+ Known.One .clearSignBit ();
115
+ Known.Zero .setSignBit ();
116
+ };
117
+
118
+ // Handle add/sub given nsw and/or nuw.
119
+ //
120
+ // Possible TODO: Add/Sub implementations mirror one another in many ways.
121
+ // They could probably be compressed into a single implementation of roughly
122
+ // half the total LOC. Leaving seperate for now to increase clarity.
123
+ // NB: We handle NSW by essentially treating as nuw of bitwidth - 1 then
124
+ // deducing bits based on the known sign result.
125
+ if (Add) {
126
+ if (NUW || (LHS.isNonNegative () && RHS.isNonNegative ())) {
127
+ bool OverflowMin;
128
+ APInt MinVal;
129
+ if (NSW) {
130
+ MinVal = GetMinVal (/* ForNSW=*/ true , LHS, RHS, OverflowMin);
131
+ // (add nsw nuw) or (add nsw PosX, PosY)
132
+
133
+ // None of the adds can end up overflowing, so min consecutive
134
+ // highbits in minimum possible of X + Y must all remain set.
135
+ KnownOut.One .setHighBits (MinVal.countLeadingOnes ());
136
+
137
+ // NSW and Positive arguments leads to positive result.
138
+ if (LHS.isNonNegative () && RHS.isNonNegative ())
139
+ ForcePositive (KnownOut);
140
+ }
141
+ if (NUW) {
142
+ KnownOut.One .clearSignBit ();
143
+ // (add nuw X, Y)
144
+ MinVal = GetMinVal (/* ForNSW=*/ false , LHS, RHS, OverflowMin);
145
+ // Same as (add nsw PosX, PosY), basically since we can't overflow,
146
+ // the high bits of minimum possible X + Y must remain set.
147
+ KnownOut.One .setHighBits (MinVal.countLeadingOnes ());
148
+ }
149
+ } else if (LHS.isNegative () && RHS.isNegative ()) {
150
+ bool OverflowMax;
151
+ APInt MaxVal = GetMaxVal (/* ForNSW=*/ true , LHS, RHS, OverflowMax);
152
+ // (add nsw NegX, NegY)
153
+
154
+ // We need to re-overflow the signbit, so we are looking for sequence
155
+ // of 0s from consecutive overflows.
156
+ KnownOut.Zero .setHighBits (MaxVal.countLeadingZeros ());
157
+ ForceNegative (KnownOut);
158
+ } else if (!KnownOut.isSignUnknown ()) {
159
+ // Pass, avoid extra work if we already know the sign bit.
160
+ } else if (LHS.isNonNegative () || RHS.isNonNegative ()) {
161
+ bool OverflowMin;
162
+ (void )GetMinVal (/* ForNSW=*/ true , LHS, RHS, OverflowMin);
163
+ // (add nsw PosX, ?Y)
164
+
165
+ // If the minimal possible of X + Y overflows the signbit, then Y must
166
+ // have been signed (which will cause unsigned overflow otherwise nsw
167
+ // will be violated) leading to unsigned result.
168
+ if (OverflowMin)
77
169
KnownOut.makeNonNegative ();
78
- // Adding two negative numbers, or subtracting a non-negative number from
79
- // a negative one, can't wrap into non-negative.
80
- else if (LHS.isNegative () && RHS.isNegative ())
170
+ } else if (LHS.isNegative () || RHS.isNegative ()) {
171
+ bool OverflowMax;
172
+ (void )GetMaxVal (/* ForNSW=*/ true , LHS, RHS, OverflowMax);
173
+ // (add nsw NegX, ?Y)
174
+
175
+ // If the maximum possible of X + Y doesn't overflows the signbit,
176
+ // then Y must have been unsigned (otherwise nsw violated) so NegX +
177
+ // PosY w.o overflowing the signbit results in Negative.
178
+ if (!OverflowMax)
81
179
KnownOut.makeNegative ();
82
180
}
181
+ } else {
182
+ if (NUW || (LHS.isNegative () && RHS.isNonNegative ())) {
183
+ bool OverflowMax;
184
+ APInt MaxVal;
185
+ if (NSW) {
186
+ MaxVal = GetMaxVal (/* ForNSW=*/ true , LHS, RHS, OverflowMax);
187
+ // (sub nsw nuw) or (sub nsw NegX, PosY)
188
+
189
+ // None of the subs can overflow at any point, so any common high bits
190
+ // will subtract away and result in zeros.
191
+ KnownOut.Zero .setHighBits (MaxVal.countLeadingZeros ());
192
+ if (LHS.isNegative () && RHS.isNonNegative ())
193
+ ForceNegative (KnownOut);
194
+ }
195
+ if (NUW) {
196
+ KnownOut.Zero .clearSignBit ();
197
+ // (sub nuw X, Y)
198
+ MaxVal = GetMaxVal (/* ForNSW=*/ false , LHS, RHS, OverflowMax);
199
+
200
+ // Basically all common high bits between X/Y will cancel out as
201
+ // leading zeros.
202
+ KnownOut.Zero .setHighBits (MaxVal.countLeadingZeros ());
203
+ }
204
+ } else if (LHS.isNonNegative () && RHS.isNegative ()) {
205
+ bool OverflowMin;
206
+ APInt MinVal = GetMinVal (/* ForNSW=*/ true , LHS, RHS, OverflowMin);
207
+ // (sub nsw PosX, NegY)
208
+
209
+ // Opposite case of above, we must "re-overflow" the signbit, so
210
+ // minimal set of high bits will be fixed.
211
+ KnownOut.One .setHighBits (MinVal.countLeadingOnes ());
212
+ ForcePositive (KnownOut);
213
+ } else if (!KnownOut.isSignUnknown ()) {
214
+ // Pass, avoid extra work if we already know the sign bit.
215
+ } else if (LHS.isNegative () || RHS.isNonNegative ()) {
216
+ bool OverflowMax;
217
+ (void )GetMaxVal (/* ForNSW=*/ true , LHS, RHS, OverflowMax);
218
+ // (sub nsw NegX/?X, ?Y/PosY)
219
+ if (OverflowMax)
220
+ KnownOut.makeNegative ();
221
+ } else if (LHS.isNonNegative () || RHS.isNegative ()) {
222
+ bool OverflowMin;
223
+ (void )GetMinVal (/* ForNSW=*/ true , LHS, RHS, OverflowMin);
224
+ // (sub nsw PosX/?X, ?Y/NegY)
225
+ if (!OverflowMin)
226
+ KnownOut.makeNonNegative ();
227
+ }
228
+ }
229
+
230
+ // Just return 0 if the nsw/nuw is violated and we have poison.
231
+ if (KnownOut.hasConflict ()) {
232
+ KnownOut.setAllZero ();
233
+ return KnownOut;
83
234
}
84
235
85
236
return KnownOut;
0 commit comments