Skip to content

Commit 7492eec

Browse files
bors[bot]janmarthedalcuviper
authored
Merge #183
183: Get/set n-th bit of `BigUint` and `BigInt` r=cuviper a=janmarthedal This PR implements `bit` and `set_bit` for `BigUint` and `BigInt`. The method names have been chosen to match those of Ramp (https://docs.rs/ramp/0.5.9/ramp/int/struct.Int.html#method.bit). For `BigInt` the implementation uses the number's two's complement representation when the number is negative. This matches what libraries like Ramp or languages like Python do. Resolves #172 Co-authored-by: Jan Marthedal Rasmussen <[email protected]> Co-authored-by: Josh Stone <[email protected]>
2 parents 1368a42 + 69f654e commit 7492eec

File tree

5 files changed

+283
-1
lines changed

5 files changed

+283
-1
lines changed

ci/big_quickcheck/Cargo.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@ edition = "2018"
77
[dependencies]
88
num-integer = "0.1.42"
99
num-traits = "0.2.11"
10-
quickcheck = "0.9"
1110
quickcheck_macros = "0.9"
1211

12+
[dependencies.quickcheck]
13+
version = "0.9"
14+
default-features = false
15+
1316
[dependencies.num-bigint]
1417
features = ["quickcheck"]
1518
path = "../.."

src/bigint.rs

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3254,6 +3254,129 @@ impl BigInt {
32543254
pub fn trailing_zeros(&self) -> Option<u64> {
32553255
self.data.trailing_zeros()
32563256
}
3257+
3258+
/// Returns whether the bit in position `bit` is set,
3259+
/// using the two's complement for negative numbers
3260+
pub fn bit(&self, bit: u64) -> bool {
3261+
if self.is_negative() {
3262+
// Let the binary representation of a number be
3263+
// ... 0 x 1 0 ... 0
3264+
// Then the two's complement is
3265+
// ... 1 !x 1 0 ... 0
3266+
// where !x is obtained from x by flipping each bit
3267+
if bit >= u64::from(big_digit::BITS) * self.len() as u64 {
3268+
true
3269+
} else {
3270+
let trailing_zeros = self.data.trailing_zeros().unwrap();
3271+
match Ord::cmp(&bit, &trailing_zeros) {
3272+
Less => false,
3273+
Equal => true,
3274+
Greater => !self.data.bit(bit),
3275+
}
3276+
}
3277+
} else {
3278+
self.data.bit(bit)
3279+
}
3280+
}
3281+
3282+
/// Sets or clears the bit in the given position,
3283+
/// using the two's complement for negative numbers
3284+
///
3285+
/// Note that setting/clearing a bit (for positive/negative numbers,
3286+
/// respectively) greater than the current bit length, a reallocation
3287+
/// may be needed to store the new digits
3288+
pub fn set_bit(&mut self, bit: u64, value: bool) {
3289+
match self.sign {
3290+
Sign::Plus => self.data.set_bit(bit, value),
3291+
Sign::NoSign => {
3292+
if value {
3293+
self.data.set_bit(bit, true);
3294+
self.sign = Sign::Plus;
3295+
} else {
3296+
// Clearing a bit for zero is a no-op
3297+
}
3298+
}
3299+
Sign::Minus => {
3300+
let bits_per_digit = u64::from(big_digit::BITS);
3301+
if bit >= bits_per_digit * self.len() as u64 {
3302+
if !value {
3303+
self.data.set_bit(bit, true);
3304+
}
3305+
} else {
3306+
// If the Uint number is
3307+
// ... 0 x 1 0 ... 0
3308+
// then the two's complement is
3309+
// ... 1 !x 1 0 ... 0
3310+
// |-- bit at position 'trailing_zeros'
3311+
// where !x is obtained from x by flipping each bit
3312+
let trailing_zeros = self.data.trailing_zeros().unwrap();
3313+
if bit > trailing_zeros {
3314+
self.data.set_bit(bit, !value);
3315+
} else if bit == trailing_zeros && !value {
3316+
// Clearing the bit at position `trailing_zeros` is dealt with by doing
3317+
// similarly to what `bitand_neg_pos` does, except we start at digit
3318+
// `bit_index`. All digits below `bit_index` are guaranteed to be zero,
3319+
// so initially we have `carry_in` = `carry_out` = 1. Furthermore, we
3320+
// stop traversing the digits when there are no more carries.
3321+
let bit_index = (bit / bits_per_digit).to_usize().unwrap();
3322+
let bit_mask = (1 as BigDigit) << (bit % bits_per_digit);
3323+
let mut digit_iter = self.digits_mut().iter_mut().skip(bit_index);
3324+
let mut carry_in = 1;
3325+
let mut carry_out = 1;
3326+
3327+
let digit = digit_iter.next().unwrap();
3328+
let twos_in = negate_carry(*digit, &mut carry_in);
3329+
let twos_out = twos_in & !bit_mask;
3330+
*digit = negate_carry(twos_out, &mut carry_out);
3331+
3332+
for digit in digit_iter {
3333+
if carry_in == 0 && carry_out == 0 {
3334+
// Exit the loop since no more digits can change
3335+
break;
3336+
}
3337+
let twos = negate_carry(*digit, &mut carry_in);
3338+
*digit = negate_carry(twos, &mut carry_out);
3339+
}
3340+
3341+
if carry_out != 0 {
3342+
// All digits have been traversed and there is a carry
3343+
debug_assert_eq!(carry_in, 0);
3344+
self.digits_mut().push(1);
3345+
}
3346+
} else if bit < trailing_zeros && value {
3347+
// Flip each bit from position 'bit' to 'trailing_zeros', both inclusive
3348+
// ... 1 !x 1 0 ... 0 ... 0
3349+
// |-- bit at position 'bit'
3350+
// |-- bit at position 'trailing_zeros'
3351+
// bit_mask: 1 1 ... 1 0 .. 0
3352+
// This is done by xor'ing with the bit_mask
3353+
let index_lo = (bit / bits_per_digit).to_usize().unwrap();
3354+
let index_hi = (trailing_zeros / bits_per_digit).to_usize().unwrap();
3355+
let bit_mask_lo = big_digit::MAX << (bit % bits_per_digit);
3356+
let bit_mask_hi = big_digit::MAX
3357+
>> (bits_per_digit - 1 - (trailing_zeros % bits_per_digit));
3358+
let digits = self.digits_mut();
3359+
3360+
if index_lo == index_hi {
3361+
digits[index_lo] ^= bit_mask_lo & bit_mask_hi;
3362+
} else {
3363+
digits[index_lo] = bit_mask_lo;
3364+
for index in (index_lo + 1)..index_hi {
3365+
digits[index] = big_digit::MAX;
3366+
}
3367+
digits[index_hi] ^= bit_mask_hi;
3368+
}
3369+
} else {
3370+
// We end up here in two cases:
3371+
// bit == trailing_zeros && value: Bit is already set
3372+
// bit < trailing_zeros && !value: Bit is already cleared
3373+
}
3374+
}
3375+
}
3376+
}
3377+
// The top bit may have been cleared, so normalize
3378+
self.normalize();
3379+
}
32573380
}
32583381

32593382
impl_sum_iter_type!(BigInt);

src/biguint.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2725,6 +2725,43 @@ impl BigUint {
27252725
pub fn count_ones(&self) -> u64 {
27262726
self.data.iter().map(|&d| u64::from(d.count_ones())).sum()
27272727
}
2728+
2729+
/// Returns whether the bit in the given position is set
2730+
pub fn bit(&self, bit: u64) -> bool {
2731+
let bits_per_digit = u64::from(big_digit::BITS);
2732+
if let Some(digit_index) = (bit / bits_per_digit).to_usize() {
2733+
if let Some(digit) = self.data.get(digit_index) {
2734+
let bit_mask = (1 as BigDigit) << (bit % bits_per_digit);
2735+
return (digit & bit_mask) != 0;
2736+
}
2737+
}
2738+
false
2739+
}
2740+
2741+
/// Sets or clears the bit in the given position
2742+
///
2743+
/// Note that setting a bit greater than the current bit length, a reallocation may be needed
2744+
/// to store the new digits
2745+
pub fn set_bit(&mut self, bit: u64, value: bool) {
2746+
// Note: we're saturating `digit_index` and `new_len` -- any such case is guaranteed to
2747+
// fail allocation, and that's more consistent than adding our own overflow panics.
2748+
let bits_per_digit = u64::from(big_digit::BITS);
2749+
let digit_index = (bit / bits_per_digit)
2750+
.to_usize()
2751+
.unwrap_or(core::usize::MAX);
2752+
let bit_mask = (1 as BigDigit) << (bit % bits_per_digit);
2753+
if value {
2754+
if digit_index >= self.data.len() {
2755+
let new_len = digit_index.saturating_add(1);
2756+
self.data.resize(new_len, 0);
2757+
}
2758+
self.data[digit_index] |= bit_mask;
2759+
} else if digit_index < self.data.len() {
2760+
self.data[digit_index] &= !bit_mask;
2761+
// the top bit may have been cleared, so normalize
2762+
self.normalize();
2763+
}
2764+
}
27282765
}
27292766

27302767
fn plain_modpow(base: &BigUint, exp_data: &[BigDigit], modulus: &BigUint) -> BigUint {

tests/bigint.rs

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1307,3 +1307,96 @@ fn test_pow() {
13071307
check!(u64);
13081308
check!(usize);
13091309
}
1310+
1311+
#[test]
1312+
fn test_bit() {
1313+
// 12 = (1100)_2
1314+
assert!(!BigInt::from(0b1100u8).bit(0));
1315+
assert!(!BigInt::from(0b1100u8).bit(1));
1316+
assert!(BigInt::from(0b1100u8).bit(2));
1317+
assert!(BigInt::from(0b1100u8).bit(3));
1318+
assert!(!BigInt::from(0b1100u8).bit(4));
1319+
assert!(!BigInt::from(0b1100u8).bit(200));
1320+
assert!(!BigInt::from(0b1100u8).bit(u64::MAX));
1321+
// -12 = (...110100)_2
1322+
assert!(!BigInt::from(-12i8).bit(0));
1323+
assert!(!BigInt::from(-12i8).bit(1));
1324+
assert!(BigInt::from(-12i8).bit(2));
1325+
assert!(!BigInt::from(-12i8).bit(3));
1326+
assert!(BigInt::from(-12i8).bit(4));
1327+
assert!(BigInt::from(-12i8).bit(200));
1328+
assert!(BigInt::from(-12i8).bit(u64::MAX));
1329+
}
1330+
1331+
#[test]
1332+
fn test_set_bit() {
1333+
let mut x: BigInt;
1334+
1335+
// zero
1336+
x = BigInt::zero();
1337+
x.set_bit(200, true);
1338+
assert_eq!(x, BigInt::one() << 200);
1339+
x = BigInt::zero();
1340+
x.set_bit(200, false);
1341+
assert_eq!(x, BigInt::zero());
1342+
1343+
// positive numbers
1344+
x = BigInt::from_biguint(Plus, BigUint::one() << 200);
1345+
x.set_bit(10, true);
1346+
x.set_bit(200, false);
1347+
assert_eq!(x, BigInt::one() << 10);
1348+
x.set_bit(10, false);
1349+
x.set_bit(5, false);
1350+
assert_eq!(x, BigInt::zero());
1351+
1352+
// negative numbers
1353+
x = BigInt::from(-12i8);
1354+
x.set_bit(200, true);
1355+
assert_eq!(x, BigInt::from(-12i8));
1356+
x.set_bit(200, false);
1357+
assert_eq!(
1358+
x,
1359+
BigInt::from_biguint(Minus, BigUint::from(12u8) | (BigUint::one() << 200))
1360+
);
1361+
x.set_bit(6, false);
1362+
assert_eq!(
1363+
x,
1364+
BigInt::from_biguint(Minus, BigUint::from(76u8) | (BigUint::one() << 200))
1365+
);
1366+
x.set_bit(6, true);
1367+
assert_eq!(
1368+
x,
1369+
BigInt::from_biguint(Minus, BigUint::from(12u8) | (BigUint::one() << 200))
1370+
);
1371+
x.set_bit(200, true);
1372+
assert_eq!(x, BigInt::from(-12i8));
1373+
1374+
x = BigInt::from_biguint(Minus, BigUint::one() << 30);
1375+
x.set_bit(10, true);
1376+
assert_eq!(
1377+
x,
1378+
BigInt::from_biguint(Minus, (BigUint::one() << 30) - (BigUint::one() << 10))
1379+
);
1380+
1381+
x = BigInt::from_biguint(Minus, BigUint::one() << 200);
1382+
x.set_bit(40, true);
1383+
assert_eq!(
1384+
x,
1385+
BigInt::from_biguint(Minus, (BigUint::one() << 200) - (BigUint::one() << 40))
1386+
);
1387+
1388+
x = BigInt::from_biguint(Minus, (BigUint::one() << 200) | (BigUint::one() << 100));
1389+
x.set_bit(100, false);
1390+
assert_eq!(
1391+
x,
1392+
BigInt::from_biguint(Minus, (BigUint::one() << 200) | (BigUint::one() << 101))
1393+
);
1394+
1395+
x = BigInt::from_biguint(Minus, (BigUint::one() << 63) | (BigUint::one() << 62));
1396+
x.set_bit(62, false);
1397+
assert_eq!(x, BigInt::from_biguint(Minus, BigUint::one() << 64));
1398+
1399+
x = BigInt::from_biguint(Minus, (BigUint::one() << 200) - BigUint::one());
1400+
x.set_bit(0, false);
1401+
assert_eq!(x, BigInt::from_biguint(Minus, BigUint::one() << 200));
1402+
}

tests/biguint.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1808,3 +1808,29 @@ fn test_count_ones() {
18081808
let x: BigUint = (BigUint::from(3u8) << 128) | BigUint::from(3u8);
18091809
assert_eq!(x.count_ones(), 4);
18101810
}
1811+
1812+
#[test]
1813+
fn test_bit() {
1814+
assert!(!BigUint::from(0u8).bit(0));
1815+
assert!(!BigUint::from(0u8).bit(100));
1816+
assert!(!BigUint::from(42u8).bit(4));
1817+
assert!(BigUint::from(42u8).bit(5));
1818+
let x: BigUint = (BigUint::from(3u8) << 128) | BigUint::from(3u8);
1819+
assert!(x.bit(129));
1820+
assert!(!x.bit(130));
1821+
}
1822+
1823+
#[test]
1824+
fn test_set_bit() {
1825+
let mut x = BigUint::from(3u8);
1826+
x.set_bit(128, true);
1827+
x.set_bit(129, true);
1828+
assert_eq!(x, (BigUint::from(3u8) << 128) | BigUint::from(3u8));
1829+
x.set_bit(0, false);
1830+
x.set_bit(128, false);
1831+
x.set_bit(130, false);
1832+
assert_eq!(x, (BigUint::from(2u8) << 128) | BigUint::from(2u8));
1833+
x.set_bit(129, false);
1834+
x.set_bit(1, false);
1835+
assert_eq!(x, BigUint::zero());
1836+
}

0 commit comments

Comments
 (0)