diff --git a/src/algorithms.rs b/src/algorithms.rs index febd8457..c0a0549f 100644 --- a/src/algorithms.rs +++ b/src/algorithms.rs @@ -900,6 +900,10 @@ fn biguint_shr2(n: Cow<'_, BigUint>, digits: usize, shift: u8) -> BigUint { biguint_from_vec(data) } +pub(crate) fn invmod(base: &BigInt, modulus: &BigInt) -> BigInt { + base.invmod(modulus) +} + pub(crate) fn cmp_slice(a: &[BigDigit], b: &[BigDigit]) -> Ordering { debug_assert!(a.last() != Some(&0)); debug_assert!(b.last() != Some(&0)); @@ -914,7 +918,8 @@ pub(crate) fn cmp_slice(a: &[BigDigit], b: &[BigDigit]) -> Ordering { mod algorithm_tests { use crate::big_digit::BigDigit; use crate::{BigInt, BigUint}; - use num_traits::Num; + use num_integer::Integer; + use num_traits::{Num, One}; #[test] fn test_sub_sign() { @@ -933,4 +938,47 @@ mod algorithm_tests { assert_eq!(sub_sign_i(&a.data[..], &b.data[..]), &a_i - &b_i); assert_eq!(sub_sign_i(&b.data[..], &a.data[..]), &b_i - &a_i); } + + #[test] + fn test_invmod() { + use super::invmod; + + let one = BigInt::one(); + + let test_modulus = |k| { + let nums = [ + /*[11].to_vec(), */ [43].to_vec(), + [67].to_vec(), + [0x1, 0xa3].to_vec(), + [0x8f, 0xcb, 0x3a, 0xa2, 0x4b, 0x05, 0x5b, 0x4b, 0xfb].to_vec(), + ]; + for i in nums.iter() { + let e = BigInt::from_bytes_be(super::Plus, &i); + + let mut inv_e = invmod(&e, &k); + inv_e *= &e; + inv_e = inv_e.mod_floor(&k); + + assert_eq!(inv_e, one); + } + + // test 1 / (k + 1) % k + let mut e = k.clone(); + e += 1; + let mut inv_e = invmod(&e, &k); + inv_e *= &e; + inv_e = inv_e.mod_floor(&k); + + assert_eq!(inv_e, one); + }; + + // Test even modulus + // TODO: for some reason 11 has no inverse mod 265252859812191058636308480000000 + let mut k = BigInt::from_str_radix("265252859812191058636308480000000", 10).unwrap(); + test_modulus(k.clone()); + + // Test odd modulus + k += 1_u32; + test_modulus(k); + } } diff --git a/src/bigint.rs b/src/bigint.rs index caa6d0eb..17ba00c5 100644 --- a/src/bigint.rs +++ b/src/bigint.rs @@ -3231,6 +3231,223 @@ impl BigInt { BigInt::from_biguint(sign, mag) } + /// Returns `(1 / self) % modulus` using Extended Euclidean Algorithm + /// + /// Panics if self and modulus are both even, either is zero, |self| == modulus, or no inverse is found + /// + /// From wolfSSL implementation: fp_invmod, fp_invmod_slow + /// https://github.com/wolfSSL/wolfssl/blob/master/wolfcrypt/src/tfm.c + pub fn invmod(&self, modulus: &Self) -> Self { + if modulus.is_even() { + return self.invmod_slow(modulus); + } + + if self.is_zero() || modulus.is_zero() { + panic!("base and/or modulus cannot be zero"); + } + + if self.abs() == *modulus { + panic!("cannot invert |a| == modulus"); + } + + let x = modulus; + let y = if self > modulus { + self.mod_floor(modulus) + } else { + self.clone() + }; + + /* 3. u=x, v=y, B=0, D=1 */ + + /* x == modulus, y == value to invert */ + let mut u = x.clone(); + /* we need y = |a| */ + let mut v = y.abs(); + + let mut bb = Self::zero(); + let mut bd = Self::one(); + + // here an infinite loop takes the place of `goto top` + // where a condition calls for `goto top`, simply continue + // + // NOTE: need to be cautious to always break/return, else infinite loop + loop { + /* 4. while u is even do */ + while u.is_even() { + /* 4.1 u = u / 2 */ + u /= 2_u32; + + /* 4.2 if B is odd then */ + if bb.is_odd() { + /* B = (B-x)/2 */ + bb -= x; + } + + /* B = B/2 */ + bb /= 2_u32; + } + + /* 5. while v is even do */ + while v.is_even() { + /* 5.1 v = v/2 */ + v /= 2_u32; + + /* 5.2 if D is odd then */ + if bd.is_odd() { + /* D = (D-x)/2 */ + bd -= x; + } + + /* D = D/2 */ + bd /= 2_u32; + } + + /* 6. if u >= v then */ + if u >= v { + /* u = u - v, B = B - D */ + u -= &v; + bb -= &bd; + } else { + /* v = v - u, D = D - B */ + v -= &u; + bd -= &bb; + } + + /* if u != 0, goto step 4 */ + if !u.is_zero() { + continue; + } + + /* now a = B, b = D, gcd == g*v */ + if !v.is_one() { + // if v != 1, there is no inverse + panic!("no inverse, GCD != 1"); + } + + /* while D is too low */ + while bd.sign() == Minus { + bd += modulus; + } + + /* while D is too big */ + let mod_mag = modulus.magnitude(); + while bd.magnitude() >= mod_mag { + bd -= modulus; + } + + if self.sign() == Minus { + bd = bd.neg(); + } + + /* D is now the inverse */ + break; + } + + bd + } + + fn invmod_slow(&self, modulus: &Self) -> Self { + if self.is_even() && modulus.is_even() { + panic!("base and modulus are both even"); + } + + if self.is_zero() || modulus.is_zero() { + panic!("base and/or modulus cannot be zero"); + } + + let x = self.mod_floor(modulus); + let y = modulus; + + /* 3. u=x, v=y, A=1, B=0, C=0, D-1 */ + let mut u = x.clone(); + let mut v = y.clone(); + let mut ba = Self::one(); + let mut bb = Self::zero(); + let mut bc = Self::zero(); + let mut bd = Self::one(); + + // here an infinite loop takes the place of `goto top` + // where a condition calls for `goto top`, simply continue + // + // NOTE: need to be cautious to always break/return, else infinite loop + loop { + /* 4. while u is even do */ + while u.is_even() { + /* 4.1 u = u / 2 */ + u /= 2_u32; + + /* 4.2 if A or B is odd then */ + if ba.is_odd() || bb.is_odd() { + /* A = (A+y)/2, B = (B-x)/2*/ + // div 2 happens unconditionally below + ba += y; + bb -= &x; + } + + ba /= 2_u32; + bb /= 2_u32; + } + + /* 5. while v is even do */ + while v.is_even() { + /* 5.1 v = v / 2 */ + v /= 2_u32; + + /* 5.2 if C or D is odd then */ + if bc.is_odd() || bd.is_odd() { + /* C = (C+y)/2, D = (D-x)/2 */ + // div 2 happens unconditionally below + bc += y; + bd -= &x; + } + + /* C = C/2, D = D/2 */ + bc /= 2_u32; + bd /= 2_u32; + } + + /* 6. if u >= v then */ + if u >= v { + /* u = u - v, A = A - C, B = B - D */ + u -= &v; + ba -= &bc; + bb -= &bd; + } else { + /* v = v - u, C = C - A, D = D - B */ + v -= &u; + bc -= &ba; + bd -= &bb; + } + + /* if u != 0, goto step 4 */ + if !u.is_zero() { + continue; + } + + /* now a = C, b = D, gcd == g*v */ + if !v.is_one() { + // if v != 1, there is no inverse + panic!("no inverse, GCD != 1"); + } + + /* while C is too low */ + while bc.sign() == Minus { + bc += y; + } + + /* while C is too big */ + let mod_mag = y.magnitude(); + while bc.magnitude() > mod_mag { + bc -= y; + } + + /* C is now the inverse */ + break; + } + + bc + } + /// Returns the truncated principal square root of `self` -- /// see [Roots::sqrt](https://docs.rs/num-integer/0.1/num_integer/trait.Roots.html#method.sqrt). pub fn sqrt(&self) -> Self { diff --git a/src/biguint.rs b/src/biguint.rs index 432ff849..8a4239e0 100644 --- a/src/biguint.rs +++ b/src/biguint.rs @@ -32,6 +32,7 @@ mod algorithms; #[path = "monty.rs"] mod monty; +use self::algorithms::invmod; use self::algorithms::{__add2, __sub2rev, add2, sub2, sub2rev}; use self::algorithms::{biguint_shl, biguint_shr}; use self::algorithms::{cmp_slice, fls, ilog2}; @@ -2674,6 +2675,15 @@ impl BigUint { } } + /// Return `(1 / self) % modulus` using Extended Euclidean Algorithm. + /// + /// Panics if self and the modulus are both even, either is zero, |self| == modulus, or no inverse is found + pub fn invmod(&self, modulus: &Self) -> Self { + invmod(&self.clone().into(), &modulus.clone().into()) + .into_parts() + .1 + } + /// Returns the truncated principal square root of `self` -- /// see [Roots::sqrt](https://docs.rs/num-integer/0.1/num_integer/trait.Roots.html#method.sqrt) pub fn sqrt(&self) -> Self {