Skip to content

Commit

Permalink
refactor(scalar): Extract binary exponentiation logic into macro
Browse files Browse the repository at this point in the history
- Add `binary_exponentiation` macro to encapsulate logic for exponentiation
- Update Mul implementations for Scalar to use new macro
  • Loading branch information
Elvis339 committed Sep 21, 2023
1 parent 9ab4fe8 commit a79dd2b
Showing 1 changed file with 56 additions and 62 deletions.
118 changes: 56 additions & 62 deletions src/ecc/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use crate::ecc::error::FieldElementError;
use std::ops::Mul;

use crate::ecc::field_element::FieldElement;
use crate::ecc::point::Point;

#[derive(Debug, Clone)]
pub struct Scalar {
Expand All @@ -15,67 +14,74 @@ impl Scalar {
}
}

impl Mul<FieldElement> for Scalar {
type Output = Result<FieldElement, FieldElementError>;

// Naive impl
fn mul(self, rhs: FieldElement) -> Self::Output {
let prime = rhs.get_prime().clone();
let num = rhs.get_num().clone();
let mut result = FieldElement::construct_from(num.clone(), prime.clone()).unwrap();
macro_rules! binary_exponentiation {
($lhs_value:expr, $rhs:expr) => {{
let prime = $rhs.get_prime();
let num = $rhs.get_num();

let end = self.value;
let mut coef = $lhs_value - 1;
let mut current = $rhs.clone();
let mut result = FieldElement::construct_from(num.clone(), prime.clone())?;

for _ in 1..end {
result = (result + rhs.clone())?
while coef > 0 {
if coef & 1 == 1 {
result = (&result + &current).unwrap();
}
current = (&current + &current).unwrap();
coef >>= 1;
}

Ok(result)
}};
}

impl Mul<FieldElement> for Scalar {
type Output = Result<FieldElement, FieldElementError>;

fn mul(self, rhs: FieldElement) -> Self::Output {
binary_exponentiation!(self.value, rhs)
}
}

impl Mul<&FieldElement> for Scalar {
type Output = Result<FieldElement, FieldElementError>;

// Naive impl
fn mul(self, rhs: &FieldElement) -> Self::Output {
let prime = rhs.get_prime();
let num = rhs.get_num();

let mut result = FieldElement::construct_from(num.clone(), prime.clone())?;

let end = self.value;

for _ in 1..end {
result = (result + rhs)?
}

Ok(result)
binary_exponentiation!(self.value, rhs)
}
}

impl Mul<Point> for Scalar {
type Output = Result<Point, FieldElementError>;

// Naive impl
fn mul(self, rhs: Point) -> Self::Output {
let mut result = Point::new(rhs.a.clone(), rhs.b.clone(), None, None).unwrap();

let end = self.value;

for _ in 0..end {
result = (result + rhs.clone())?
}
impl<'a, 'b> Mul<&'b FieldElement> for &'a Scalar {
type Output = Result<FieldElement, FieldElementError>;

Ok(result)
fn mul(self, rhs: &'b FieldElement) -> Self::Output {
binary_exponentiation!(self.value, rhs)
}
}

// impl Mul<Point> for Scalar {
// type Output = Result<Point, FieldElementError>;
//
// fn mul(self, rhs: Point) -> Self::Output {
// let mut coef = self.value;
// let mut current = rhs.clone();
// let mut result = Point::new(rhs.a.clone(), rhs.b.clone(), None, None)?;
//
// while coef > 0 {
// if coef & 1 == 1 {
// result = (&result + &current).unwrap();
// }
// current = (&current + &current).unwrap();
// coef >>= 1;
// }
//
// Ok(result)
// }
// }

#[cfg(test)]
mod tests {
use crate::ecc::field_element::FieldElement;
use crate::ecc::point::Point;
use crate::ecc::scalar::Scalar;
use super::*;

fn new_fe(num: i64, prime: i64) -> FieldElement {
FieldElement::new(num, prime).unwrap()
Expand All @@ -84,27 +90,15 @@ mod tests {
#[test]
fn multiply_field_element() {
let fe = new_fe(15, 223);
let res = Scalar::new(2) * fe;
assert_eq!(res.unwrap(), new_fe(30, 223))
}
let scalar = Scalar::new(2);

#[test]
fn scalar_multiplication_point() {
let prime = 223;
let a = new_fe(0, prime.clone());
let b = new_fe(7, prime.clone());
let x = new_fe(47, prime.clone());
let y = new_fe(71, prime.clone());
let p = Point::new(a.clone(), b.clone(), Some(x), Some(y)).unwrap();

assert_eq!(
Scalar::new(10) * p,
Point::new(
a,
b,
Some(new_fe(154, prime.clone())),
Some(new_fe(150, prime.clone())),
)
)
let res = (scalar.clone() * fe.clone()).unwrap();
assert_eq!(res, new_fe(30, 223));

let res_ref = (scalar.clone() * &fe).unwrap();
assert_eq!(res_ref, new_fe(30, 223));

let res_ref_ref = (&scalar * &fe).unwrap();
assert_eq!(res_ref_ref, new_fe(30, 223));
}
}

0 comments on commit a79dd2b

Please sign in to comment.