88 lines
2.6 KiB
Rust
88 lines
2.6 KiB
Rust
use num_bigint::BigInt;
|
|
use num_integer::Integer;
|
|
use num_traits::{One, Zero};
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
pub enum Point {
|
|
Infinity,
|
|
Point { x: BigInt, y: BigInt },
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct WeierstrassCurve {
|
|
a: BigInt,
|
|
p: BigInt,
|
|
}
|
|
|
|
impl WeierstrassCurve {
|
|
pub fn new(a: BigInt, p: BigInt) -> Self {
|
|
WeierstrassCurve { a, p }
|
|
}
|
|
|
|
fn mod_inverse(a: &BigInt, p: &BigInt) -> BigInt {
|
|
let egcd = a.extended_gcd(p);
|
|
egcd.x.mod_floor(p)
|
|
}
|
|
|
|
fn double_point(&self, point: &Point) -> Point {
|
|
match point {
|
|
Point::Point { x, y } => {
|
|
if y.is_zero() {
|
|
Point::Infinity
|
|
} else {
|
|
let three = BigInt::from(3);
|
|
let two = BigInt::from(2);
|
|
|
|
let lambda = (three * x * x + &self.a) * Self::mod_inverse(&(two * y), &self.p);
|
|
let lamba_sqr = (&lambda * &lambda).mod_floor(&self.p);
|
|
let x3 = (&lamba_sqr - x - x).mod_floor(&self.p);
|
|
let y3 = (&lambda * (x - &x3) - y).mod_floor(&self.p);
|
|
|
|
Point::Point { x: x3, y: y3 }
|
|
}
|
|
}
|
|
Point::Infinity => Point::Infinity,
|
|
}
|
|
}
|
|
|
|
pub fn add_points(&self, point1: &Point, point2: &Point) -> Point {
|
|
match (point1, point2) {
|
|
(Point::Point { x: x1, y: y1 }, Point::Point { x: x2, y: y2 }) => {
|
|
if point1 == point2 {
|
|
self.double_point(point1)
|
|
} else {
|
|
let lambda = (y2 - y1) * Self::mod_inverse(&(x2 - x1), &self.p);
|
|
let x3 = ((&lambda * &lambda) - x1 - x2).mod_floor(&self.p);
|
|
let y3: BigInt = ((&lambda * (x1 - &x3)) - y1).mod_floor(&self.p);
|
|
|
|
Point::Point { x: x3, y: y3 }
|
|
}
|
|
}
|
|
(Point::Point { x, y }, Point::Infinity) | (Point::Infinity, Point::Point { x, y }) => {
|
|
Point::Point {
|
|
x: x.clone(),
|
|
y: y.clone(),
|
|
}
|
|
}
|
|
(Point::Infinity, Point::Infinity) => Point::Infinity,
|
|
}
|
|
}
|
|
|
|
pub fn multiply_point(&self, s: &BigInt, point: &Point) -> Point {
|
|
let mut res = Point::Infinity;
|
|
let mut temp = point.clone();
|
|
|
|
let mut s = s.clone();
|
|
|
|
while s > BigInt::zero() {
|
|
if (&s % BigInt::from(2)) == BigInt::one() {
|
|
res = self.add_points(&res, &temp);
|
|
}
|
|
temp = self.double_point(&temp);
|
|
|
|
s >>= 1;
|
|
}
|
|
|
|
res
|
|
}
|
|
}
|