@@ -57,181 +57,69 @@ KnownBits KnownBits::computeForAddCarry(
57
57
KnownBits KnownBits::computeForAddSub (bool Add, bool NSW, bool NUW,
58
58
const KnownBits &LHS,
59
59
const KnownBits &RHS) {
60
- // This can be a relatively expensive helper, so optimistically save some
61
- // work.
62
- if (LHS.isUnknown () && RHS.isUnknown ())
63
- return LHS;
64
- KnownBits KnownOut;
65
- if (Add) {
66
- // Sum = LHS + RHS + 0
67
- KnownOut =
68
- ::computeForAddCarry (LHS, RHS, /* CarryZero*/ true , /* CarryOne*/ false );
69
- } else {
70
- // Sum = LHS + ~RHS + 1
71
- KnownBits NotRHS = RHS;
72
- std::swap (NotRHS.Zero , NotRHS.One );
73
- KnownOut = ::computeForAddCarry (LHS, NotRHS, /* CarryZero*/ false ,
74
- /* CarryOne*/ true );
75
- }
76
- if (!NSW && !NUW)
77
- return KnownOut;
78
-
79
- auto GetMinMaxVal = [Add](bool ForNSW, bool ForMax, const KnownBits &L,
80
- const KnownBits &R, bool &OV) {
81
- APInt LVal = ForMax ? L.getMaxValue () : L.getMinValue ();
82
- APInt RVal = Add == ForMax ? R.getMaxValue () : R.getMinValue ();
83
-
84
- if (ForNSW) {
85
- LVal.clearSignBit ();
86
- RVal.clearSignBit ();
87
- }
88
- APInt Res = Add ? LVal.uadd_ov (RVal, OV) : LVal.usub_ov (RVal, OV);
89
- if (ForNSW) {
90
- OV = Res.isSignBitSet ();
91
- Res.clearSignBit ();
92
- if (Res.getBitWidth () > 1 && Res[Res.getBitWidth () - 2 ])
93
- Res.setSignBit ();
60
+ unsigned BitWidth = LHS.getBitWidth ();
61
+ KnownBits KnownOut (LHS.getBitWidth ());
62
+ if (!LHS.isUnknown () && !RHS.isUnknown ()) {
63
+ if (Add) {
64
+ // Sum = LHS + RHS + 0
65
+ KnownOut = ::computeForAddCarry (LHS, RHS, /* CarryZero=*/ true ,
66
+ /* CarryOne=*/ false );
67
+ } else {
68
+ // Sum = LHS + ~RHS + 1
69
+ KnownBits NotRHS = RHS;
70
+ std::swap (NotRHS.Zero , NotRHS.One );
71
+ KnownOut = ::computeForAddCarry (LHS, NotRHS, /* CarryZero=*/ false ,
72
+ /* CarryOne=*/ true );
94
73
}
95
- return Res;
96
- };
97
-
98
- auto GetMaxVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
99
- const KnownBits &R, bool &OV) {
100
- return GetMinMaxVal (ForNSW, /* ForMax=*/ true , L, R, OV);
101
- };
102
-
103
- auto GetMinVal = [&GetMinMaxVal](bool ForNSW, const KnownBits &L,
104
- const KnownBits &R, bool &OV) {
105
- return GetMinMaxVal (ForNSW, /* ForMax=*/ false , L, R, OV);
106
- };
107
-
108
- auto ForceNegative = [](KnownBits &Known) {
109
- Known.Zero .clearSignBit ();
110
- Known.One .setSignBit ();
111
- };
112
-
113
- auto ForcePositive = [](KnownBits &Known) {
114
- Known.One .clearSignBit ();
115
- Known.Zero .setSignBit ();
116
- };
74
+ }
117
75
118
- // Handle add/sub given nsw and/or nuw.
119
- //
120
- // Possible TODO: Add/Sub implementations mirror one another in many ways.
121
- // They could probably be compressed into a single implementation of roughly
122
- // half the total LOC. Leaving seperate for now to increase clarity.
123
- // NB: We handle NSW by essentially treating as nuw of bitwidth - 1 then
124
- // deducing bits based on the known sign result.
125
- if (Add) {
126
- if (NUW || (LHS.isNonNegative () && RHS.isNonNegative ())) {
127
- bool OverflowMin;
128
- APInt MinVal;
76
+ if (NUW) {
77
+ if (Add) {
78
+ // (add nuw X, Y)
79
+ APInt MinVal = LHS.getMinValue ().uadd_sat (RHS.getMinValue ());
129
80
if (NSW) {
130
- MinVal = GetMinVal (/* ForNSW=*/ true , LHS, RHS, OverflowMin);
131
- // (add nsw nuw) or (add nsw PosX, PosY)
132
-
133
- // None of the adds can end up overflowing, so min consecutive
134
- // highbits in minimum possible of X + Y must all remain set.
135
- KnownOut.One .setHighBits (MinVal.countLeadingOnes ());
136
-
137
- // NSW and Positive arguments leads to positive result.
138
- if (LHS.isNonNegative () && RHS.isNonNegative ())
139
- ForcePositive (KnownOut);
81
+ unsigned NumBits = MinVal.trunc (BitWidth - 1 ).countl_one ();
82
+ KnownOut.One .setBits (BitWidth - 1 - NumBits, BitWidth - 1 );
140
83
}
141
- if (NUW) {
142
- KnownOut. One . clearSignBit ();
143
- // (add nuw X, Y)
144
- MinVal = GetMinVal ( /* ForNSW= */ false , LHS, RHS, OverflowMin );
145
- // Same as (add nsw PosX, PosY), basically since we can't overflow,
146
- // the high bits of minimum possible X + Y must remain set.
147
- KnownOut.One . setHighBits (MinVal. countLeadingOnes () );
84
+ KnownOut. One . setHighBits (MinVal. countLeadingOnes ());
85
+ } else {
86
+ // (sub nuw X, Y)
87
+ APInt MaxVal = LHS. getMaxValue (). usub_sat ( RHS. getMinValue () );
88
+ if (NSW) {
89
+ unsigned NumBits = MaxVal. trunc (BitWidth - 1 ). countl_zero ();
90
+ KnownOut.Zero . setBits (BitWidth - 1 - NumBits, BitWidth - 1 );
148
91
}
149
- } else if (LHS.isNegative () && RHS.isNegative ()) {
150
- bool OverflowMax;
151
- APInt MaxVal = GetMaxVal (/* ForNSW=*/ true , LHS, RHS, OverflowMax);
152
- // (add nsw NegX, NegY)
153
-
154
- // We need to re-overflow the signbit, so we are looking for sequence
155
- // of 0s from consecutive overflows.
156
92
KnownOut.Zero .setHighBits (MaxVal.countLeadingZeros ());
157
- ForceNegative (KnownOut);
158
- } else if (!KnownOut.isSignUnknown ()) {
159
- // Pass, avoid extra work if we already know the sign bit.
160
- } else if (LHS.isNonNegative () || RHS.isNonNegative ()) {
161
- bool OverflowMin;
162
- (void )GetMinVal (/* ForNSW=*/ true , LHS, RHS, OverflowMin);
163
- // (add nsw PosX, ?Y)
164
-
165
- // If the minimal possible of X + Y overflows the signbit, then Y must
166
- // have been signed (which will cause unsigned overflow otherwise nsw
167
- // will be violated) leading to unsigned result.
168
- if (OverflowMin)
169
- KnownOut.makeNonNegative ();
170
- } else if (LHS.isNegative () || RHS.isNegative ()) {
171
- bool OverflowMax;
172
- (void )GetMaxVal (/* ForNSW=*/ true , LHS, RHS, OverflowMax);
173
- // (add nsw NegX, ?Y)
174
-
175
- // If the maximum possible of X + Y doesn't overflows the signbit,
176
- // then Y must have been unsigned (otherwise nsw violated) so NegX +
177
- // PosY w.o overflowing the signbit results in Negative.
178
- if (!OverflowMax)
179
- KnownOut.makeNegative ();
180
93
}
181
- } else {
182
- if (NUW || (LHS.isNegative () && RHS.isNonNegative ())) {
183
- bool OverflowMax;
184
- APInt MaxVal;
185
- if (NSW) {
186
- MaxVal = GetMaxVal (/* ForNSW=*/ true , LHS, RHS, OverflowMax);
187
- // (sub nsw nuw) or (sub nsw NegX, PosY)
188
-
189
- // None of the subs can overflow at any point, so any common high bits
190
- // will subtract away and result in zeros.
191
- KnownOut.Zero .setHighBits (MaxVal.countLeadingZeros ());
192
- if (LHS.isNegative () && RHS.isNonNegative ())
193
- ForceNegative (KnownOut);
194
- }
195
- if (NUW) {
196
- KnownOut.Zero .clearSignBit ();
197
- // (sub nuw X, Y)
198
- MaxVal = GetMaxVal (/* ForNSW=*/ false , LHS, RHS, OverflowMax);
199
-
200
- // Basically all common high bits between X/Y will cancel out as
201
- // leading zeros.
202
- KnownOut.Zero .setHighBits (MaxVal.countLeadingZeros ());
203
- }
204
- } else if (LHS.isNonNegative () && RHS.isNegative ()) {
205
- bool OverflowMin;
206
- APInt MinVal = GetMinVal (/* ForNSW=*/ true , LHS, RHS, OverflowMin);
207
- // (sub nsw PosX, NegY)
94
+ }
208
95
209
- // Opposite case of above, we must "re-overflow" the signbit, so
210
- // minimal set of high bits will be fixed.
211
- KnownOut.One .setHighBits (MinVal.countLeadingOnes ());
212
- ForcePositive (KnownOut);
213
- } else if (!KnownOut.isSignUnknown ()) {
214
- // Pass, avoid extra work if we already know the sign bit.
215
- } else if (LHS.isNegative () || RHS.isNonNegative ()) {
216
- bool OverflowMax;
217
- (void )GetMaxVal (/* ForNSW=*/ true , LHS, RHS, OverflowMax);
218
- // (sub nsw NegX/?X, ?Y/PosY)
219
- if (OverflowMax)
220
- KnownOut.makeNegative ();
221
- } else if (LHS.isNonNegative () || RHS.isNegative ()) {
222
- bool OverflowMin;
223
- (void )GetMinVal (/* ForNSW=*/ true , LHS, RHS, OverflowMin);
224
- // (sub nsw PosX/?X, ?Y/NegY)
225
- if (!OverflowMin)
226
- KnownOut.makeNonNegative ();
96
+ if (NSW) {
97
+ APInt MinVal;
98
+ APInt MaxVal;
99
+ if (Add) {
100
+ // (add nsw X, Y)
101
+ MinVal = LHS.getSignedMinValue ().sadd_sat (RHS.getSignedMinValue ());
102
+ MaxVal = LHS.getSignedMaxValue ().sadd_sat (RHS.getSignedMaxValue ());
103
+ } else {
104
+ // (sub nsw X, Y)
105
+ MinVal = LHS.getSignedMinValue ().ssub_sat (RHS.getSignedMaxValue ());
106
+ MaxVal = LHS.getSignedMaxValue ().ssub_sat (RHS.getSignedMinValue ());
107
+ }
108
+ if (MinVal.isNonNegative ()) {
109
+ unsigned NumBits = MinVal.trunc (BitWidth - 1 ).countl_one ();
110
+ KnownOut.One .setBits (BitWidth - 1 - NumBits, BitWidth - 1 );
111
+ KnownOut.Zero .setSignBit ();
112
+ }
113
+ if (MaxVal.isNegative ()) {
114
+ unsigned NumBits = MaxVal.trunc (BitWidth - 1 ).countl_zero ();
115
+ KnownOut.Zero .setBits (BitWidth - 1 - NumBits, BitWidth - 1 );
116
+ KnownOut.One .setSignBit ();
227
117
}
228
118
}
229
119
230
120
// Just return 0 if the nsw/nuw is violated and we have poison.
231
- if (KnownOut.hasConflict ()) {
121
+ if (KnownOut.hasConflict ())
232
122
KnownOut.setAllZero ();
233
- return KnownOut;
234
- }
235
123
236
124
return KnownOut;
237
125
}
0 commit comments