use num_bigint::BigInt;
use num_integer::Integer;
use num_traits::{One, Zero};

/// A point on an elliptic curve in affine coordinates, or the point at infinity.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Point {
    /// The point at infinity, which acts as the identity of the group law.
    Infinity,
    /// An affine point `(x, y)`.
    Point { x: BigInt, y: BigInt },
}

/// A short Weierstrass curve `y^2 = x^3 + a*x + b` whose arithmetic is done modulo `p`.
///
/// Only `a` and `p` are stored: the addition/doubling formulas below never use `b`,
/// so callers are responsible for passing points that actually lie on the intended
/// curve. `p` is expected to be a prime greater than 3 so that every non-zero residue
/// (in particular `2*y` and `x2 - x1`) is invertible — NOTE(review): not validated here.
#[derive(Debug, Clone)]
pub struct WeierstrassCurve {
    a: BigInt,
    p: BigInt,
}

impl WeierstrassCurve {
    /// Creates a curve with linear coefficient `a`, reducing arithmetic modulo `p`.
    pub fn new(a: BigInt, p: BigInt) -> Self {
        WeierstrassCurve { a, p }
    }

    /// Returns the multiplicative inverse of `a` modulo `p`, as a value in `[0, p)`.
    ///
    /// `a` may be negative or unreduced. The inverse exists only when
    /// `gcd(a, p) == 1`; callers must rule out `a ≡ 0 (mod p)` before calling.
    fn mod_inverse(a: &BigInt, p: &BigInt) -> BigInt {
        // Reduce first so the extended GCD always works on a value in [0, p).
        let egcd = a.mod_floor(p).extended_gcd(p);
        debug_assert!(egcd.gcd.is_one(), "{} has no inverse modulo {}", a, p);
        egcd.x.mod_floor(p)
    }

    /// Returns `-P`, i.e. the reflection of `point` across the x-axis.
    fn negate_point(&self, point: &Point) -> Point {
        match point {
            Point::Infinity => Point::Infinity,
            Point::Point { x, y } => Point::Point {
                x: x.clone(),
                y: (-y).mod_floor(&self.p),
            },
        }
    }

    /// Returns `2P` using the tangent-line formula.
    ///
    /// Returns `Point::Infinity` when `point` is the identity or has `y ≡ 0 (mod p)`
    /// (vertical tangent).
    fn double_point(&self, point: &Point) -> Point {
        let (x, y) = match point {
            Point::Infinity => return Point::Infinity,
            Point::Point { x, y } => (x, y),
        };
        let p = &self.p;
        // Compare modulo p so an unreduced y such as `p` is also recognised as zero.
        if y.mod_floor(p).is_zero() {
            return Point::Infinity;
        }
        // λ = (3x² + a) / (2y)  (mod p)
        let numerator = BigInt::from(3) * x * x + &self.a;
        let lambda = (numerator * Self::mod_inverse(&(BigInt::from(2) * y), p)).mod_floor(p);
        let lambda_sqr = &lambda * &lambda;
        let x3 = (lambda_sqr - x - x).mod_floor(p);
        let y3 = (&lambda * (x - &x3) - y).mod_floor(p);
        Point::Point { x: x3, y: y3 }
    }

    /// Returns `point1 + point2` under the elliptic-curve group law.
    ///
    /// Handles the identity on either side, the doubling case (`P + P`) and the
    /// inverse case (`P + (-P) = Infinity`). Coordinates are compared modulo `p`, so
    /// equivalent but unreduced inputs are treated as the same point. Non-identity
    /// results have coordinates in `[0, p)`.
    pub fn add_points(&self, point1: &Point, point2: &Point) -> Point {
        let ((x1, y1), (x2, y2)) = match (point1, point2) {
            (Point::Infinity, other) | (other, Point::Infinity) => return other.clone(),
            (Point::Point { x: x1, y: y1 }, Point::Point { x: x2, y: y2 }) => {
                ((x1, y1), (x2, y2))
            }
        };
        let p = &self.p;

        if (x1 - x2).mod_floor(p).is_zero() {
            // Same x-coordinate: either Q == P (tangent case) or Q == -P. In the
            // latter case (which includes y1 ≡ y2 ≡ 0) the sum is the identity; the
            // chord formula would otherwise divide by zero.
            return if (y1 + y2).mod_floor(p).is_zero() {
                Point::Infinity
            } else {
                self.double_point(point1)
            };
        }

        // λ = (y2 - y1) / (x2 - x1)  (mod p)
        let lambda = ((y2 - y1) * Self::mod_inverse(&(x2 - x1), p)).mod_floor(p);
        let x3 = (&lambda * &lambda - x1 - x2).mod_floor(p);
        let y3 = (&lambda * (x1 - &x3) - y1).mod_floor(p);
        Point::Point { x: x3, y: y3 }
    }

    /// Returns the scalar multiple `s * point` using binary double-and-add.
    ///
    /// `s == 0` yields `Point::Infinity`. A negative `s` is computed as
    /// `|s| * (-point)`.
    pub fn multiply_point(&self, s: &BigInt, point: &Point) -> Point {
        let (mut k, mut addend) = if *s < BigInt::zero() {
            (-s, self.negate_point(point))
        } else {
            (s.clone(), point.clone())
        };

        let mut acc = Point::Infinity;
        while !k.is_zero() {
            if k.is_odd() {
                acc = self.add_points(&acc, &addend);
            }
            k >>= 1;
            // Skip the doubling after the last bit; its result would be unused.
            if !k.is_zero() {
                addend = self.double_point(&addend);
            }
        }
        acc
    }
}