diff --git a/src/ecc/field_element.rs b/src/ecc/field_element.rs index 4c0b8d2..5d28d0f 100644 --- a/src/ecc/field_element.rs +++ b/src/ecc/field_element.rs @@ -27,6 +27,14 @@ impl FieldElement { }) } + pub fn get_num(&self) -> &BigInt { + &self.num + } + + pub fn get_prime(&self) -> &BigInt { + &self.prime + } + pub fn pow_mod(&self, exponent: BigInt) -> FieldElement { let mut n = exponent; let prime = &self.prime; @@ -183,7 +191,7 @@ mod tests { } #[test] - fn pow_test() { + fn pow_mod_test() { let prime = 31; let a = new_fe(17, prime.clone()); assert_eq!(a.pow_mod(BigInt::from(3)), new_fe(15, prime.clone())); @@ -195,7 +203,7 @@ mod tests { } #[test] - fn pow_negative_test() { + fn pow_mod_negative_test() { let prime = 31; let a = new_fe(17, prime.clone()); assert_eq!(a.pow_mod(BigInt::from(-3)), new_fe(29, prime.clone())); diff --git a/src/ecc/point.rs b/src/ecc/point.rs index 5939d95..b5d1a83 100644 --- a/src/ecc/point.rs +++ b/src/ecc/point.rs @@ -1,22 +1,25 @@ +use crate::ecc::field_element::FieldElement; use num_bigint::BigInt; -use std::fmt; -use std::ops::{Add, Div, Mul, Sub}; +use std::fmt::{Display, Formatter}; +use std::ops::{Add, Deref, Div, Mul, Sub}; #[derive(Debug, Clone)] pub struct Point { - a: BigInt, - b: BigInt, - x: Option, - y: Option, + a: FieldElement, + b: FieldElement, + x: Option, + y: Option, } impl Point { - pub fn new(_a: i64, _b: i64, _x: Option, _y: Option) -> Result { - let a = BigInt::from(_a); - let b = BigInt::from(_b); - - if _x.is_none() && _y.is_none() { - return Ok(Point { + pub fn new( + a: FieldElement, + b: FieldElement, + x: Option, + y: Option, + ) -> Result { + if x.is_none() && y.is_none() { + return Ok(Self { a, b, x: None, @@ -24,243 +27,239 @@ impl Point { }); } - let x = _x.map(|v| BigInt::from(v)); - let y = _y.map(|v| BigInt::from(v)); - - if y.is_none() { - return Ok(Point { a, b, x, y: None }); + let binding_x = x.clone().unwrap(); + let binding_y = y.clone().unwrap(); + + // x^3 + a*x + b + let curve = binding_x + .pow_mod(BigInt::from(3)) + .add(a.clone().mul(binding_x.clone())?)? + .add(b.clone())?; + + if binding_y.clone().pow_mod(BigInt::from(2)) != curve { + return Err(format!( + "({}, {}) is not on the curve!", + x.unwrap(), + y.unwrap() + )); } - return match x { - Some(x_val) => { - // x^3 + ax + b - let curve_form = &x_val.pow(3).add(a.clone().mul(&x_val.clone())).add(&b); - if &y.clone().unwrap().pow(2) != curve_form { - return Err(format!( - "Point(a={}, b={}, x={}, y={}) is not on the curve.", - a, - b, - x_val, - y.unwrap(), - )); - } - return Ok(Point { - a, - b, - x: Some(x_val), - y, - }); - } - None => Ok(Point { a, b, x: None, y }), - }; + return Ok(Point { a, b, x, y }); } } -impl fmt::Display for Point { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "Point({}={} + {} + {})", - &self.y.clone().unwrap_or(BigInt::from(0)), - &self.x.clone().unwrap_or(BigInt::from(0)).pow(3), - self.a - .clone() - .mul(&self.x.clone().unwrap_or(BigInt::from(1))), - &self.b - ) +impl PartialEq for Point { + fn eq(&self, other: &Self) -> bool { + &self.x == &other.x && &self.y == &other.y && &self.a == &other.a && self.b == other.b + } + + fn ne(&self, other: &Self) -> bool { + self != other } } -impl PartialEq for Point { - fn eq(&self, other: &Self) -> bool { - self.x.eq(&other.x) && self.y.eq(&other.y) && self.a.eq(&other.a) && self.b.eq(&other.b) +impl Display for Point { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + if self.x.is_none() { + return write!(f, "Point(infinity)"); + } + + let prime = self.a.get_prime(); + let y = self + .y + .clone() + .map(|v| format!("FieldElement_{}({})", v.get_num(), v.get_prime())) + .unwrap_or_else(|| format!("FieldElement_None({})", prime)); + + return write!( + f, + "Point({}, {})_{}_{} FieldElement({})", + &self.x.clone().unwrap(), + y, + &self.a, + &self.b, + prime, + ); } } impl Add for Point { type Output = Result; - fn add(self, rhs: Self) -> Self::Output { - if self.a != rhs.a || self.b != rhs.b { - return Err(format!("{}, {} are not on the same curve.", self, rhs)); + fn add(self, other: Self) -> Self::Output { + if self.a != other.a && self.b != other.b { + return Err(format!( + "Points {}, {} are not on the same curve", + self, other + )); } - // If either is none - // Some is the point at infinity, or the additive identity. + let (_, prime) = self.a.clone().convert::()?; + + // Self is point at infinity if self.x.is_none() { - return Ok(rhs); + return Ok(other); } - if rhs.x.is_none() { + // Other is point at infinity + if other.x.is_none() { return Ok(self); } - // - let x = self.x.clone().unwrap(); - let y = self.y.clone().unwrap(); - - let other_x = rhs.x.clone().unwrap_or(BigInt::from(0)); - let other_y = rhs.y.clone().unwrap_or(BigInt::from(0)); + // Same x but different y (Additive inverse) + if self.x == other.x && self.y != other.y { + return Point::new(self.a, self.b, self.x, self.y); + } - // Additive inverse (same x but different y, causing a vertical line) - if x == y && y != other_y { - if self.y.is_none() { - return Ok(rhs); - } + // Different x + // P1 + P2 = P3 + if self.x != other.x { + let y1 = self.y.clone().expect("y1 is None"); + let y2 = other.y.clone().expect("y2 is None"); - if rhs.y.is_none() { - return Ok(self); - } - } + let x1 = self.x.clone().expect("x1 is None"); + let x2 = other.x.clone().expect("x2 is None"); - // Point addition when x1 != x2 - if x != other_x { - let slope = other_y.sub(&y).div(&other_x.clone().sub(&x)); - let new_x = slope.pow(2).sub(&x).sub(&other_x); - let new_y = slope.mul(&x.sub(&new_x)).sub(&y); - return Ok(Point { - a: self.a, - b: self.b, - x: Some(new_x), - y: Some(new_y), - }); - } + // slope = (y2 - y1) / (x2 - x1) + let slope = y2 + .clone() + .sub(y1.clone())? + .div(x2.clone().sub(x1.clone())?)?; - return Err("Invalid point.".to_string()); - } -} + // x3 = slope^2 - x1 - x2 + let x3 = slope + .clone() + .pow_mod(BigInt::from(2)) + .sub(x1.clone())? + .sub(x2)?; -// Support &Point + &Point -impl<'a, 'b> Add<&'b Point> for &'a Point { - type Output = Result; + // y3 = slope * (x1 - x3) - y1 + let y3 = slope.mul(x1.sub(x3.clone())?)?.sub(y1)?; - fn add(self, rhs: &'b Point) -> Self::Output { - if self.a != rhs.a || self.b != rhs.b { - return Err(format!("{}, {} are not on the same curve.", self, rhs)); + return Point::new(self.a, self.b, Some(x3), Some(y3)); } - // If either is none - if self.x.is_none() { - return Ok(rhs.clone()); + let zero = self + .y + .clone() + .map(|y1| { + let binding = self.x.clone().unwrap(); + let x = binding.get_num(); + *y1.get_num() == BigInt::from(0).mul(x) + }) + .unwrap_or(false); + + // If we are tangent to the vertical line, we return point at infinity + if self == other && zero { + return Point::new(self.a, self.b, None, None); } - if rhs.x.is_none() { - return Ok(self.clone()); - } + // P1 + P1 = P2 + // Adding same point + if self == other { + let x1 = self.x.clone().expect("x1 is None"); + let y1 = self.y.clone().expect("y1 is None"); - let x = self.x.clone().unwrap(); - let y = self.y.clone().unwrap(); + if prime < 3 { + return Err(format!("Prime {} too low!", prime)); + } - let other_x = rhs.x.clone().unwrap_or(BigInt::from(0)); - let other_y = rhs.y.clone().unwrap_or(BigInt::from(0)); + let two = FieldElement::new(2, prime.clone())?; + let three = FieldElement::new(3, prime.clone())?; - // Additive inverse (same x but different y, causing a vertical line) - if x == other_x && y != other_y { - if self.y.is_none() { - return Ok(rhs.clone()); - } + // (3 * x1^2 + a) + let quotient = three + .mul(x1.clone().pow_mod(BigInt::from(2)))? + .add(self.a.clone())?; - if rhs.y.is_none() { - return Ok(self.clone()); - } - } + // (2 * y1) + let divident = two.clone().mul(y1.clone())?; - // Point addition when x1 != x2 - if x != other_x { - let slope = other_y.sub(&y).div(&other_x.clone().sub(&x)); - let new_x = slope.pow(2).sub(&x).sub(&other_x); - let new_y = slope.mul(&x.sub(&new_x)).sub(&y); - return Ok(Point { - a: self.a.clone(), - b: self.b.clone(), - x: Some(new_x), - y: Some(new_y), - }); + // s = (3 * x1^2 + a) / (2 * y1) + let s = quotient.div(divident)?; + + // x3 = s^2 - 2 * x1 + let x3 = s + .clone() + .pow_mod(BigInt::from(2)) + .sub(two)? + .mul(x1.clone())?; + + // y3 = s * (x1 - x3) - y1 + let y3 = s.mul(x1.sub(x3.clone())?)?.sub(y1)?; + + return Point::new(self.a, self.b, Some(x3), Some(y3)); } - return Err("Invalid point.".to_string()); + return Err(format!("Invalid")); } } #[cfg(test)] mod tests { + use crate::ecc::field_element::FieldElement; use crate::ecc::point::Point; - #[test] - fn error_when_points_are_not_on_the_curve() { - let p1 = Point::new(-1, -1, Some(5), Some(7)); - assert_eq!( - p1.unwrap_err(), - "Point(a=-1, b=-1, x=5, y=7) is not on the curve.".to_string() - ) + fn new_fe(num: i64, prime: i64) -> FieldElement { + FieldElement::new(num, prime).unwrap() } #[test] - fn points_on_the_curve() { - // y^2 = x^3 + 5x + 7 - let not = [(2, 4), (5, 7)]; + fn test_on_curve() { + let prime = 223; + let a = new_fe(0, prime.clone()); + let b = new_fe(7, prime.clone()); + + let valid_points = [(192, 105), (17, 56), (1, 193)]; + for v in valid_points { + let x = new_fe(v.0, prime.clone()); + let y = new_fe(v.1, prime.clone()); + let p1 = Point::new(a.clone(), b.clone(), Some(x), Some(y)); + assert_eq!(p1.is_ok(), true,); + } - for n in not { - let not_point = Point::new(5, 7, Some(n.0), Some(n.1)); + let invalid_points = [(200, 119), (42, 99)]; + for i in invalid_points { + let x = new_fe(i.0, prime.clone()); + let y = new_fe(i.1, prime.clone()); assert_eq!( - not_point.unwrap_err(), - format!( - "Point(a={}, b={}, x={}, y={}) is not on the curve.", - 5, 7, n.0, n.1 - ) + Point::new(a.clone(), b.clone(), Some(x), Some(y)).is_err(), + true ); } - - let on = [(-1, -1), (18, 77)]; - for o in on { - let x = o.0; - let y = o.1; - let point = Point::new(5, 7, Some(x.clone()), Some(y.clone())); - assert_eq!(point.unwrap(), Point::new(5, 7, Some(x), Some(y)).unwrap()) - } } #[test] - fn error_additive_identity() { - let p1 = Point::new(-15, -1, Some(5), Some(7)); - let p2 = Point::new(0, -76, Some(5), Some(7)); - - assert_eq!( - p1.unwrap() + p2.unwrap(), - Err( - "Point(7=125 + -75 + -1), Point(7=125 + 0 + -76) are not on the same curve." - .to_string() - ) - ); - } + fn add_test() { + // y^2 = x^3 - 7 over F-223 + let prime = 223; + let a = new_fe(0, prime.clone()); + let b = new_fe(7, prime.clone()); - #[test] - fn point_at_infinity() { - // When x is None - let p1 = Point::new(-15, -1, None, Some(7)); - let p2 = Point::new(-15, -1, Some(5), Some(7)); + let additions = [ + //(x1, y2, x2, y2, x3, y3) + (192, 105, 17, 56, 170, 142), + (47, 71, 117, 141, 60, 139), + (143, 98, 76, 66, 47, 71), + ]; - assert_eq!(p1.as_ref().unwrap() + p2.as_ref().unwrap(), p2,); - } + for item in additions { + let x1 = new_fe(item.0, prime.clone()); + let y1 = new_fe(item.1, prime.clone()); - #[test] - fn additive_inverse() { - // When p1.x == p2.x but y is None - let p1 = Point::new(-15, -1, Some(5), Some(7)); - let p2 = Point::new(-15, -1, Some(5), None); + let x2 = new_fe(item.2, prime.clone()); + let y2 = new_fe(item.3, prime.clone()); - assert_eq!(p1.as_ref().unwrap() + p2.as_ref().unwrap(), p1,); - } + let x3 = new_fe(item.4, prime.clone()); + let y3 = new_fe(item.5, prime.clone()); - #[test] - // Addition when x1 != x2 - fn add_diff_xs() { - let p1 = Point::new(5, 7, Some(2), Some(5)); - let p2 = Point::new(5, 7, Some(-1), Some(-1)); - - assert_eq!( - p1.unwrap() + p2.unwrap(), - Point::new(5, 7, Some(3), Some(-7)) - ); + assert_eq!( + (Point::new(a.clone(), b.clone(), Some(x1.clone()), Some(y1.clone())).unwrap() + + Point::new(a.clone(), b.clone(), Some(x2.clone()), Some(y2.clone())) + .unwrap()) + .unwrap(), + Point::new(a.clone(), b.clone(), Some(x3.clone()), Some(y3.clone())).unwrap() + ); + } } }