Skip to content

Commit

Permalink
fix: polynomial mult
Browse files Browse the repository at this point in the history
  • Loading branch information
0xWOLAND committed Oct 20, 2023
1 parent d0c2069 commit 97a47ea
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 10 deletions.
16 changes: 16 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
itertools = "0.11.0"
mod_exp = "1.0.1"
rand = "0.8.5"
37 changes: 30 additions & 7 deletions src/ntt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ use crate::prime::is_prime;
use mod_exp::mod_exp;
use std::mem::swap;

#[derive(Debug, Clone)]
pub struct Constants {
k: i64,
N: i64,
w: i64,
pub k: i64,
pub N: i64,
pub w: i64,
}

fn gcd(a: i64, b: i64) -> i64 {
Expand Down Expand Up @@ -56,7 +57,25 @@ fn extended_gcd(a: i64, b: i64) -> i64 {
(t2 + n) % n
}

fn working_modulus(n: i64, M: i64) -> Constants {
fn prime_factors(a: i64) -> Vec<i64> {
let mut ans: Vec<i64> = Vec::new();
(2..(((a as f64).sqrt() + 1.) as i64)).for_each(|x| {
if a % x == 0 {
ans.push(x);
}
});
ans
}

fn is_primitive_root(a: i64, deg: i64, N: i64) -> bool {
mod_exp(a, deg, N) == 1
&& prime_factors(deg)
.iter()
.map(|&x| mod_exp(a, deg / x, N) != 1)
.all(|x| x)
}

pub fn working_modulus(n: i64, M: i64) -> Constants {
let mut N = n + 1;
let mut k = 1;
while (!is_prime(N)) || N < M {
Expand All @@ -65,7 +84,7 @@ fn working_modulus(n: i64, M: i64) -> Constants {
}
let mut gen = 0;
for g in 2..N {
if gcd(g, N) == 1 {
if is_primitive_root(g, N - 1, N) {
gen = g;
break;
}
Expand Down Expand Up @@ -102,11 +121,15 @@ pub fn inverse(inp: Vec<i64>, c: &Constants) -> Vec<i64> {

#[cfg(test)]
mod tests {
use rand::Rng;

use crate::ntt::{extended_gcd, forward, gcd, inverse, working_modulus, Constants};
#[test]
fn test_forward() {
let v: Vec<i64> = vec![6, 0, 10, 7, 2];
let n = v.len() as i64;
let n = rand::thread_rng().gen::<i64>().abs() % 10;
let v: Vec<i64> = (0..n)
.map(|_| rand::thread_rng().gen::<i64>().abs() % (1 << 6))
.collect();
let M = v.iter().max().unwrap().pow(2) as i64 * n + 1;
let c = working_modulus(n, M);
let forward = forward(v.clone(), &c);
Expand Down
66 changes: 63 additions & 3 deletions src/polynomial.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
use std::ops::{Add, Mul, Neg, Sub};

use itertools::{EitherOrBoth::*, Itertools};

use crate::ntt::*;

#[derive(Debug, Clone)]
pub struct Polynomial {
pub coef: Vec<i64>,
}
Expand All @@ -9,7 +14,18 @@ impl Add<Polynomial> for Polynomial {

fn add(self, rhs: Polynomial) -> Self::Output {
Polynomial {
coef: self.coef.iter().zip(rhs.coef).map(|(a, b)| a + b).collect(),
coef: self
.coef
.iter()
.rev()
.zip_longest(rhs.coef.iter().rev())
.map(|p| match p {
Both(&a, &b) => a + b,
Left(&a) => a,
Right(&b) => b,
})
.rev()
.collect(),
}
}
}
Expand All @@ -36,7 +52,51 @@ impl Mul<Polynomial> for Polynomial {
type Output = Polynomial;

fn mul(self, rhs: Polynomial) -> Self::Output {
let m = *self.coef.iter().max().unwrap();
todo!()
let n = self.coef.len() as i64;
let M = self.coef.iter().max().unwrap().pow(2) as i64 * n + 1;
let c = working_modulus(n, M);
println!("consts -- {:?} {}", c, M);
let a_forward = forward(self.coef, &c);
let b_forward = forward(rhs.coef, &c);

println!("a -- {:?}", a_forward);
println!("b -- {:?}", b_forward);

let mul = a_forward
.iter()
.rev()
.zip_longest(b_forward.iter().rev())
.map(|p| match p {
Both(&a, &b) => (a * b) % c.N,
Left(&_a) => 0,
Right(&_b) => 0,
})
.rev()
.collect::<Vec<i64>>();
println!("mul -- {:?}", mul);
Polynomial {
coef: inverse(mul, &c),
}
}
}

#[cfg(test)]
mod tests {
use super::Polynomial;

#[test]
fn add() {
let a = Polynomial {
coef: vec![1, 2, 3, 4],
};
let b = Polynomial { coef: vec![1, 2] };
println!("{:?}", a + b);
}

#[test]
fn mul() {
let a = Polynomial { coef: vec![1, 2] };
let b = Polynomial { coef: vec![0, 1] };
println!("{:?}", a * b);
}
}

0 comments on commit 97a47ea

Please sign in to comment.