Skip to content

Commit b146d44

Browse files
committed
Implement extended GCD and modular inverse
1 parent 65eb4c5 commit b146d44

File tree

1 file changed

+71
-1
lines changed

1 file changed

+71
-1
lines changed

src/lib.rs

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ extern crate num_traits as traits;
2424

2525
use core::ops::Add;
2626
use core::mem;
27+
use core::cmp::Ordering;
2728

28-
use traits::{Num, Signed};
29+
use traits::{Num, NumRef, RefNum, Signed};
2930

3031
mod roots;
3132
pub use roots::Roots;
@@ -684,6 +685,57 @@ impl_integer_for_usize!(usize, test_integer_usize);
684685
#[cfg(has_i128)]
685686
impl_integer_for_usize!(u128, test_integer_u128);
686687

688+
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
689+
pub struct GcdResult<T> {
690+
/// Greatest common divisor.
691+
pub gcd: T,
692+
/// Coefficients such that: gcd(a, b) = c1*a + c2*b
693+
pub c1: T, pub c2: T,
694+
}
695+
696+
/// Calculate greatest common divisor and the corresponding coefficients.
697+
pub fn extended_gcd<T: Integer + NumRef>(a: T, b: T) -> GcdResult<T>
698+
where for<'a> &'a T: RefNum<T>
699+
{
700+
// Euclid's extended algorithm
701+
let (mut s, mut old_s) = (T::zero(), T::one());
702+
let (mut t, mut old_t) = (T::one(), T::zero());
703+
let (mut r, mut old_r) = (b, a);
704+
705+
while r != T::zero() {
706+
let quotient = &old_r / &r;
707+
old_r = old_r - &quotient * &r; std::mem::swap(&mut old_r, &mut r);
708+
old_s = old_s - &quotient * &s; std::mem::swap(&mut old_s, &mut s);
709+
old_t = old_t - quotient * &t; std::mem::swap(&mut old_t, &mut t);
710+
}
711+
712+
let _quotients = (t, s); // == (a, b) / gcd
713+
714+
GcdResult { gcd: old_r, c1: old_s, c2: old_t }
715+
}
716+
717+
/// Find the standard representation of a (mod n).
718+
pub fn normalize<T: Integer + NumRef>(a: T, n: &T) -> T {
719+
let a = a % n;
720+
match a.cmp(&T::zero()) {
721+
Ordering::Less => a + n,
722+
_ => a,
723+
}
724+
}
725+
726+
/// Calculate the inverse of a (mod n).
727+
pub fn inverse<T: Integer + NumRef + Clone>(a: T, n: &T) -> Option<T>
728+
where for<'a> &'a T: RefNum<T>
729+
{
730+
let GcdResult { gcd, c1: c, .. } = extended_gcd(a, n.clone());
731+
if gcd == T::one() {
732+
Some(normalize(c, n))
733+
} else {
734+
None
735+
}
736+
}
737+
738+
687739
/// An iterator over binomial coefficients.
688740
pub struct IterBinomial<T> {
689741
a: T,
@@ -831,6 +883,24 @@ fn test_lcm_overflow() {
831883
check!(u64, 0x8000_0000_0000_0000, 0x02, 0x8000_0000_0000_0000);
832884
}
833885

886+
#[test]
887+
fn test_extended_gcd() {
888+
assert_eq!(extended_gcd(240, 46), GcdResult { gcd: 2, c1: -9, c2: 47});
889+
}
890+
891+
#[test]
892+
fn test_normalize() {
893+
assert_eq!(normalize(10, &7), 3);
894+
assert_eq!(normalize(7, &7), 0);
895+
assert_eq!(normalize(5, &7), 5);
896+
assert_eq!(normalize(-3, &7), 4);
897+
}
898+
899+
#[test]
900+
fn test_inverse() {
901+
assert_eq!(inverse(5, &7).unwrap(), 3);
902+
}
903+
834904
#[test]
835905
fn test_iter_binomial() {
836906
macro_rules! check_simple {

0 commit comments

Comments
 (0)