From 7a35f3df70a5215954e14ee9b1b5105c0bb5709b Mon Sep 17 00:00:00 2001 From: Vinzent Steinberg Date: Mon, 2 Jul 2018 16:24:42 +0200 Subject: [PATCH 1/6] Implement extended GCD and modular inverse --- src/lib.rs | 72 +++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 71 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 0281954..7785026 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,8 +23,9 @@ extern crate num_traits as traits; use core::mem; use core::ops::Add; +use core::cmp::Ordering; -use traits::{Num, Signed, Zero}; +use traits::{Num, NumRef, RefNum, Signed, Zero}; mod roots; pub use roots::Roots; @@ -1013,6 +1014,57 @@ impl_integer_for_usize!(usize, test_integer_usize); #[cfg(has_i128)] impl_integer_for_usize!(u128, test_integer_u128); +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct GcdResult { + /// Greatest common divisor. + pub gcd: T, + /// Coefficients such that: gcd(a, b) = c1*a + c2*b + pub c1: T, pub c2: T, +} + +/// Calculate greatest common divisor and the corresponding coefficients. +pub fn extended_gcd(a: T, b: T) -> GcdResult + where for<'a> &'a T: RefNum +{ + // Euclid's extended algorithm + let (mut s, mut old_s) = (T::zero(), T::one()); + let (mut t, mut old_t) = (T::one(), T::zero()); + let (mut r, mut old_r) = (b, a); + + while r != T::zero() { + let quotient = &old_r / &r; + old_r = old_r - "ient * &r; std::mem::swap(&mut old_r, &mut r); + old_s = old_s - "ient * &s; std::mem::swap(&mut old_s, &mut s); + old_t = old_t - quotient * &t; std::mem::swap(&mut old_t, &mut t); + } + + let _quotients = (t, s); // == (a, b) / gcd + + GcdResult { gcd: old_r, c1: old_s, c2: old_t } +} + +/// Find the standard representation of a (mod n). +pub fn normalize(a: T, n: &T) -> T { + let a = a % n; + match a.cmp(&T::zero()) { + Ordering::Less => a + n, + _ => a, + } +} + +/// Calculate the inverse of a (mod n). +pub fn inverse(a: T, n: &T) -> Option + where for<'a> &'a T: RefNum +{ + let GcdResult { gcd, c1: c, .. } = extended_gcd(a, n.clone()); + if gcd == T::one() { + Some(normalize(c, n)) + } else { + None + } +} + + /// An iterator over binomial coefficients. pub struct IterBinomial { a: T, @@ -1169,6 +1221,24 @@ fn test_lcm_overflow() { check!(u64, 0x8000_0000_0000_0000, 0x02, 0x8000_0000_0000_0000); } +#[test] +fn test_extended_gcd() { + assert_eq!(extended_gcd(240, 46), GcdResult { gcd: 2, c1: -9, c2: 47}); +} + +#[test] +fn test_normalize() { + assert_eq!(normalize(10, &7), 3); + assert_eq!(normalize(7, &7), 0); + assert_eq!(normalize(5, &7), 5); + assert_eq!(normalize(-3, &7), 4); +} + +#[test] +fn test_inverse() { + assert_eq!(inverse(5, &7).unwrap(), 3); +} + #[test] fn test_iter_binomial() { macro_rules! check_simple { From bba7294251eba32904a6c079e5cbfc6681ab7d24 Mon Sep 17 00:00:00 2001 From: Vinzent Steinberg Date: Mon, 2 Jul 2018 16:38:09 +0200 Subject: [PATCH 2/6] Implement modular exponentiation --- src/lib.rs | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 7785026..6fa33d4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,7 +22,7 @@ extern crate std; extern crate num_traits as traits; use core::mem; -use core::ops::Add; +use core::ops::{Add, Neg, Shr}; use core::cmp::Ordering; use traits::{Num, NumRef, RefNum, Signed, Zero}; @@ -1064,6 +1064,30 @@ pub fn inverse(a: T, n: &T) -> Option } } +/// Calculate base^exp (mod modulus). +pub fn powm(base: &T, exp: &T, modulus: &T) -> T + where T: Integer + NumRef + Clone + Neg + Shr, + for<'a> &'a T: RefNum +{ + let zero = T::zero(); + let one = T::one(); + let two = &one + &one; + let mut exp = exp.clone(); + let mut result = one.clone(); + let mut base = base % modulus; + if exp < zero { + exp = -exp; + base = inverse(base, modulus).unwrap(); + } + while exp > zero { + if &exp % &two == one { + result = (result * &base) % modulus; + } + exp = exp >> 1; + base = (&base * &base) % modulus; + } + result +} /// An iterator over binomial coefficients. pub struct IterBinomial { @@ -1239,6 +1263,12 @@ fn test_inverse() { assert_eq!(inverse(5, &7).unwrap(), 3); } +#[test] +fn test_powm() { + // `i64::pow` would overflow. + assert_eq!(powm(&11, &19, &7), 4); +} + #[test] fn test_iter_binomial() { macro_rules! check_simple { From a2d1cacba5e7a76b1ebbe76bb5339ccc256d33ea Mon Sep 17 00:00:00 2001 From: Vinzent Steinberg Date: Wed, 4 Jul 2018 10:06:22 +0200 Subject: [PATCH 3/6] Fix no_std build --- src/lib.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 6fa33d4..ec1f5e2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1033,9 +1033,9 @@ pub fn extended_gcd(a: T, b: T) -> GcdResult while r != T::zero() { let quotient = &old_r / &r; - old_r = old_r - "ient * &r; std::mem::swap(&mut old_r, &mut r); - old_s = old_s - "ient * &s; std::mem::swap(&mut old_s, &mut s); - old_t = old_t - quotient * &t; std::mem::swap(&mut old_t, &mut t); + old_r = old_r - "ient * &r; mem::swap(&mut old_r, &mut r); + old_s = old_s - "ient * &s; mem::swap(&mut old_s, &mut s); + old_t = old_t - quotient * &t; mem::swap(&mut old_t, &mut t); } let _quotients = (t, s); // == (a, b) / gcd From c56a9521e0b588c18a2cc5456682326b07613785 Mon Sep 17 00:00:00 2001 From: Vinzent Steinberg Date: Wed, 2 Dec 2020 00:03:56 +0100 Subject: [PATCH 4/6] Make adding new fields to `GcdResult` a non-breaking change --- src/lib.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index ec1f5e2..96c72d6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1020,6 +1020,8 @@ pub struct GcdResult { pub gcd: T, /// Coefficients such that: gcd(a, b) = c1*a + c2*b pub c1: T, pub c2: T, + /// Dummy field to make sure adding new fields is not a breaking change. + seal: (), } /// Calculate greatest common divisor and the corresponding coefficients. @@ -1040,7 +1042,7 @@ pub fn extended_gcd(a: T, b: T) -> GcdResult let _quotients = (t, s); // == (a, b) / gcd - GcdResult { gcd: old_r, c1: old_s, c2: old_t } + GcdResult { gcd: old_r, c1: old_s, c2: old_t, seal: () } } /// Find the standard representation of a (mod n). @@ -1247,7 +1249,7 @@ fn test_lcm_overflow() { #[test] fn test_extended_gcd() { - assert_eq!(extended_gcd(240, 46), GcdResult { gcd: 2, c1: -9, c2: 47}); + assert_eq!(extended_gcd(240, 46), GcdResult { gcd: 2, c1: -9, c2: 47, seal: () }); } #[test] From 3406bbcaccae4fbd928141e76d96760a0dc47757 Mon Sep 17 00:00:00 2001 From: Vinzent Steinberg Date: Wed, 2 Dec 2020 00:14:47 +0100 Subject: [PATCH 5/6] Apply `cargo +1.42 fmt` --- src/lib.rs | 42 +++++++++++++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 96c72d6..bdb6ef2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,9 +21,9 @@ extern crate std; extern crate num_traits as traits; +use core::cmp::Ordering; use core::mem; use core::ops::{Add, Neg, Shr}; -use core::cmp::Ordering; use traits::{Num, NumRef, RefNum, Signed, Zero}; @@ -1019,14 +1019,16 @@ pub struct GcdResult { /// Greatest common divisor. pub gcd: T, /// Coefficients such that: gcd(a, b) = c1*a + c2*b - pub c1: T, pub c2: T, + pub c1: T, + pub c2: T, /// Dummy field to make sure adding new fields is not a breaking change. seal: (), } /// Calculate greatest common divisor and the corresponding coefficients. pub fn extended_gcd(a: T, b: T) -> GcdResult - where for<'a> &'a T: RefNum +where + for<'a> &'a T: RefNum, { // Euclid's extended algorithm let (mut s, mut old_s) = (T::zero(), T::one()); @@ -1035,14 +1037,22 @@ pub fn extended_gcd(a: T, b: T) -> GcdResult while r != T::zero() { let quotient = &old_r / &r; - old_r = old_r - "ient * &r; mem::swap(&mut old_r, &mut r); - old_s = old_s - "ient * &s; mem::swap(&mut old_s, &mut s); - old_t = old_t - quotient * &t; mem::swap(&mut old_t, &mut t); + old_r = old_r - "ient * &r; + mem::swap(&mut old_r, &mut r); + old_s = old_s - "ient * &s; + mem::swap(&mut old_s, &mut s); + old_t = old_t - quotient * &t; + mem::swap(&mut old_t, &mut t); } let _quotients = (t, s); // == (a, b) / gcd - GcdResult { gcd: old_r, c1: old_s, c2: old_t, seal: () } + GcdResult { + gcd: old_r, + c1: old_s, + c2: old_t, + seal: (), + } } /// Find the standard representation of a (mod n). @@ -1056,7 +1066,8 @@ pub fn normalize(a: T, n: &T) -> T { /// Calculate the inverse of a (mod n). pub fn inverse(a: T, n: &T) -> Option - where for<'a> &'a T: RefNum +where + for<'a> &'a T: RefNum, { let GcdResult { gcd, c1: c, .. } = extended_gcd(a, n.clone()); if gcd == T::one() { @@ -1068,8 +1079,9 @@ pub fn inverse(a: T, n: &T) -> Option /// Calculate base^exp (mod modulus). pub fn powm(base: &T, exp: &T, modulus: &T) -> T - where T: Integer + NumRef + Clone + Neg + Shr, - for<'a> &'a T: RefNum +where + T: Integer + NumRef + Clone + Neg + Shr, + for<'a> &'a T: RefNum, { let zero = T::zero(); let one = T::one(); @@ -1249,7 +1261,15 @@ fn test_lcm_overflow() { #[test] fn test_extended_gcd() { - assert_eq!(extended_gcd(240, 46), GcdResult { gcd: 2, c1: -9, c2: 47, seal: () }); + assert_eq!( + extended_gcd(240, 46), + GcdResult { + gcd: 2, + c1: -9, + c2: 47, + seal: () + } + ); } #[test] From 111ad76459bd3ea2222e90fb515e52193882eb25 Mon Sep 17 00:00:00 2001 From: Vinzent Steinberg Date: Fri, 22 Jan 2021 19:59:19 +0100 Subject: [PATCH 6/6] Use existing `ExtendedGcd` instead of defining `GcdResult` --- src/lib.rs | 31 ++++++++++--------------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index bdb6ef2..c8c551d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1014,19 +1014,8 @@ impl_integer_for_usize!(usize, test_integer_usize); #[cfg(has_i128)] impl_integer_for_usize!(u128, test_integer_u128); -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub struct GcdResult { - /// Greatest common divisor. - pub gcd: T, - /// Coefficients such that: gcd(a, b) = c1*a + c2*b - pub c1: T, - pub c2: T, - /// Dummy field to make sure adding new fields is not a breaking change. - seal: (), -} - /// Calculate greatest common divisor and the corresponding coefficients. -pub fn extended_gcd(a: T, b: T) -> GcdResult +pub fn extended_gcd(a: T, b: T) -> ExtendedGcd where for<'a> &'a T: RefNum, { @@ -1047,11 +1036,11 @@ where let _quotients = (t, s); // == (a, b) / gcd - GcdResult { + ExtendedGcd { gcd: old_r, - c1: old_s, - c2: old_t, - seal: (), + x: old_s, + y: old_t, + _hidden: (), } } @@ -1069,7 +1058,7 @@ pub fn inverse(a: T, n: &T) -> Option where for<'a> &'a T: RefNum, { - let GcdResult { gcd, c1: c, .. } = extended_gcd(a, n.clone()); + let ExtendedGcd { gcd, x: c, .. } = extended_gcd(a, n.clone()); if gcd == T::one() { Some(normalize(c, n)) } else { @@ -1263,11 +1252,11 @@ fn test_lcm_overflow() { fn test_extended_gcd() { assert_eq!( extended_gcd(240, 46), - GcdResult { + ExtendedGcd { gcd: 2, - c1: -9, - c2: 47, - seal: () + x: -9, + y: 47, + _hidden: () } ); }