@@ -54,34 +54,184 @@ KnownBits KnownBits::computeForAddCarry(
54
54
LHS, RHS, Carry.Zero .getBoolValue (), Carry.One .getBoolValue ());
55
55
}
56
56
57
- KnownBits KnownBits::computeForAddSub (bool Add, bool NSW, bool /* NUW*/ ,
57
+ KnownBits KnownBits::computeForAddSub (bool Add, bool NSW, bool NUW,
58
58
const KnownBits &LHS, KnownBits RHS) {
59
59
KnownBits KnownOut;
60
60
if (Add) {
61
61
// Sum = LHS + RHS + 0
62
- KnownOut = :: computeForAddCarry (
63
- LHS, RHS, /* CarryZero*/ true , /* CarryOne*/ false );
62
+ KnownOut =
63
+ ::computeForAddCarry ( LHS, RHS, /* CarryZero*/ true , /* CarryOne*/ false );
64
64
} else {
65
65
// Sum = LHS + ~RHS + 1
66
- std::swap (RHS.Zero , RHS.One );
67
- KnownOut = ::computeForAddCarry (
68
- LHS, RHS, /* CarryZero*/ false , /* CarryOne*/ true );
66
+ KnownBits NotRHS = RHS;
67
+ std::swap (NotRHS.Zero , NotRHS.One );
68
+ KnownOut = ::computeForAddCarry (LHS, NotRHS, /* CarryZero*/ false ,
69
+ /* CarryOne*/ true );
69
70
}
71
+ if (!NSW && !NUW)
72
+ return KnownOut;
70
73
71
- // Are we still trying to solve for the sign bit?
72
- if (!KnownOut.isNegative () && !KnownOut.isNonNegative ()) {
74
+ // We truncate out the signbit during nsw handling so just handle this special
75
+ // case to avoid dealing with it later.
76
+ if (LHS.getBitWidth () == 1 ) {
77
+ return LHS | RHS;
78
+ }
79
+
80
+ auto GetMinMaxVal = [Add](bool ForNSW, bool ForMax, const KnownBits &L,
81
+ const KnownBits &R, bool &OV) {
82
+ APInt LVal = ForMax ? L.getMaxValue () : L.getMinValue ();
83
+ APInt RVal = Add == ForMax ? R.getMaxValue () : R.getMinValue ();
84
+
85
+ if (ForNSW) {
86
+ LVal = LVal.trunc (LVal.getBitWidth () - 1 );
87
+ RVal = RVal.trunc (RVal.getBitWidth () - 1 );
88
+ }
89
+ APInt Res = Add ? LVal.uadd_ov (RVal, OV) : LVal.usub_ov (RVal, OV);
90
+ if (ForNSW)
91
+ Res = Res.sext (Res.getBitWidth () + 1 );
92
+ return Res;
93
+ };
94
+
95
+ auto GetMaxVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
96
+ const KnownBits &R, bool &OV) {
97
+ return GetMinMaxVal (ForNSW, /* ForMax=*/ true , L, R, OV);
98
+ };
99
+
100
+ auto GetMinVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
101
+ const KnownBits &R, bool &OV) {
102
+ return GetMinMaxVal (ForNSW, /* ForMax=*/ false , L, R, OV);
103
+ };
104
+
105
+ std::optional<bool > Negative;
106
+ bool Poison = false ;
107
+ // Handle add/sub given nsw and/or nuw.
108
+ //
109
+ // Possible TODO: Add/Sub implementations mirror one another in many ways.
110
+ // They could probably be compressed into a single implementation of roughly
111
+ // half the total LOC. Leaving seperate for now to increase clarity.
112
+ // NB: We handle NSW by truncating sign bits then deducing bits based on
113
+ // the known sign result.
114
+ if (Add) {
115
+ if (NSW) {
116
+ bool OverflowMax, OverflowMin;
117
+ APInt MaxVal = GetMaxVal (/* ForNSW=*/ true , LHS, RHS, OverflowMax);
118
+ APInt MinVal = GetMinVal (/* ForNSW=*/ true , LHS, RHS, OverflowMin);
119
+
120
+ if (NUW || (LHS.isNonNegative () && RHS.isNonNegative ())) {
121
+ // (add nuw) or (add nsw PosX, PosY)
122
+
123
+ // None of the adds can end up overflowing, so min consecutive highbits
124
+ // in minimum possible of X + Y must all remain set.
125
+ KnownOut.One .setHighBits (MinVal.countLeadingOnes ());
126
+
127
+ // NSW and Positive arguments leads to positive result.
128
+ if (LHS.isNonNegative () && RHS.isNonNegative ())
129
+ Negative = false ;
130
+ else
131
+ KnownOut.One .clearSignBit ();
132
+
133
+ Poison = OverflowMin;
134
+ } else if (LHS.isNegative () && RHS.isNegative ()) {
135
+ // (add nsw NegX, NegY)
136
+
137
+ // We need to re-overflow the signbit, so we are looking for sequence of
138
+ // 0s from consecutive overflows.
139
+ KnownOut.Zero .setHighBits (MaxVal.countLeadingZeros ());
140
+ Negative = true ;
141
+ Poison = !OverflowMax;
142
+ } else if (LHS.isNonNegative () || RHS.isNonNegative ()) {
143
+ // (add nsw PosX, ?Y)
144
+
145
+ // If the minimal possible of X + Y overflows the signbit, then Y must
146
+ // have been signed (which will cause unsigned overflow otherwise nsw
147
+ // will be violated) leading to unsigned result.
148
+ if (OverflowMin)
149
+ Negative = false ;
150
+ } else if (LHS.isNegative () || RHS.isNegative ()) {
151
+ // (add nsw NegX, ?Y)
152
+
153
+ // If the maximum possible of X + Y doesn't overflows the signbit, then
154
+ // Y must have been unsigned (otherwise nsw violated) so NegX + PosY w.o
155
+ // overflowing the signbit results in Negative.
156
+ if (!OverflowMax)
157
+ Negative = true ;
158
+ }
159
+ }
160
+ if (NUW) {
161
+ // (add nuw X, Y)
162
+ bool OverflowMax, OverflowMin;
163
+ APInt MaxVal = GetMaxVal (/* ForNSW=*/ false , LHS, RHS, OverflowMax);
164
+ APInt MinVal = GetMinVal (/* ForNSW=*/ false , LHS, RHS, OverflowMin);
165
+ // Same as (add nsw PosX, PosY), basically since we can't overflow, the
166
+ // high bits of minimum possible X + Y must remain set.
167
+ KnownOut.One .setHighBits (MinVal.countLeadingOnes ());
168
+ Poison = OverflowMin;
169
+ }
170
+ } else {
73
171
if (NSW) {
74
- // Adding two non-negative numbers, or subtracting a negative number from
75
- // a non-negative one, can't wrap into negative.
76
- if (LHS.isNonNegative () && RHS.isNonNegative ())
77
- KnownOut.makeNonNegative ();
78
- // Adding two negative numbers, or subtracting a non-negative number from
79
- // a negative one, can't wrap into non-negative.
80
- else if (LHS.isNegative () && RHS.isNegative ())
81
- KnownOut.makeNegative ();
172
+ bool OverflowMax, OverflowMin;
173
+ APInt MaxVal = GetMaxVal (/* ForNSW=*/ true , LHS, RHS, OverflowMax);
174
+ APInt MinVal = GetMinVal (/* ForNSW=*/ true , LHS, RHS, OverflowMin);
175
+ if (NUW || (LHS.isNegative () && RHS.isNonNegative ())) {
176
+ // (sub nuw) or (sub nsw NegX, PosY)
177
+
178
+ // None of the subs can overflow at any point, so any common high bits
179
+ // will subtract away and result in zeros.
180
+ KnownOut.Zero .setHighBits (MaxVal.countLeadingZeros ());
181
+ if (LHS.isNegative () && RHS.isNonNegative ())
182
+ Negative = true ;
183
+ else
184
+ KnownOut.Zero .clearSignBit ();
185
+
186
+ Poison = OverflowMax;
187
+ } else if (LHS.isNonNegative () && RHS.isNegative ()) {
188
+ // (sub nsw PosX, NegY)
189
+ Negative = false ;
190
+
191
+ // Opposite case of above, we must "re-overflow" the signbit, so minimal
192
+ // set of high bits will be fixed.
193
+ KnownOut.One .setHighBits (MinVal.countLeadingOnes ());
194
+ Poison = !OverflowMin;
195
+ } else if (LHS.isNegative () || RHS.isNonNegative ()) {
196
+ // (sub nsw NegX/?X, ?Y/PosY)
197
+ if (OverflowMax)
198
+ Negative = true ;
199
+ } else if (LHS.isNonNegative () || RHS.isNegative ()) {
200
+ // (sub nsw PosX/?X, ?Y/NegY)
201
+ if (!OverflowMin)
202
+ Negative = false ;
203
+ }
204
+ }
205
+ if (NUW) {
206
+ // (sub nuw X, Y)
207
+ bool OverflowMax, OverflowMin;
208
+ APInt MaxVal = GetMaxVal (/* ForNSW=*/ false , LHS, RHS, OverflowMax);
209
+ APInt MinVal = GetMinVal (/* ForNSW=*/ false , LHS, RHS, OverflowMin);
210
+
211
+ // Basically all common high bits between X/Y will cancel out as leading
212
+ // zeros.
213
+ KnownOut.Zero .setHighBits (MaxVal.countLeadingZeros ());
214
+ Poison = OverflowMax;
82
215
}
83
216
}
84
217
218
+ // Handle any proven sign bit.
219
+ if (Negative.has_value ()) {
220
+ KnownOut.One .clearSignBit ();
221
+ KnownOut.Zero .clearSignBit ();
222
+
223
+ if (*Negative)
224
+ KnownOut.makeNegative ();
225
+ else
226
+ KnownOut.makeNonNegative ();
227
+ }
228
+
229
+ // Just return 0 if the nsw/nuw is violated and we have poison.
230
+ if (Poison || KnownOut.hasConflict ()) {
231
+ KnownOut.setAllZero ();
232
+ return KnownOut;
233
+ }
234
+
85
235
return KnownOut;
86
236
}
87
237
0 commit comments