From 04a1328947624975e8e69daba9de3035e46ea08a Mon Sep 17 00:00:00 2001 From: bhargav Date: Sat, 21 Oct 2023 15:32:38 -0700 Subject: [PATCH] feat: radix2 fft --- src/ntt.rs | 87 +++++++++++++++++++++++++---------------------- src/polynomial.rs | 43 ++++++++++++++--------- 2 files changed, 73 insertions(+), 57 deletions(-) diff --git a/src/ntt.rs b/src/ntt.rs index 2ddaf60..f5a3ea2 100644 --- a/src/ntt.rs +++ b/src/ntt.rs @@ -10,23 +10,6 @@ pub struct Constants { pub w: i64, } -fn gcd(a: i64, b: i64) -> i64 { - let mut a = a; - let mut b = b; - - if a < b { - swap(&mut a, &mut b); - } - if a % b == 0 { - return b; - } - while b > 0 { - a = a % b; - swap(&mut a, &mut b); - } - a -} - fn extended_gcd(a: i64, b: i64) -> i64 { let mut a = a; let mut b = b; @@ -95,25 +78,50 @@ pub fn working_modulus(n: i64, M: i64) -> Constants { Constants { k, N, w } } -pub fn forward(inp: Vec, c: &Constants) -> Vec { - let mut pre = vec![-1; inp.len().pow(2)]; - (0..inp.len()).for_each(|col| { - (0..=col).for_each(|row| { - if pre[row * inp.len() + col] == -1 { - pre[row * inp.len() + col] = mod_exp(c.w, (row * col) as i64, c.N) as i64; - } - }) +fn order_reverse(inp: &mut Vec) { + let mut j = 0; + let n = inp.len(); + (1..n).for_each(|i| { + let mut bit = n >> 1; + while (j & bit) > 0 { + j ^= bit; + bit >>= 1; + } + j ^= bit; + + if i < j { + inp.swap(i, j); + } }); +} - (0..inp.len()) - .map(|k| { - inp.iter().enumerate().fold(0, |acc, (i, cur)| { - let row = k.min(i); - let col = k.max(i); - (acc + cur * pre[row * inp.len() + col]) % c.N as i64 - }) % c.N as i64 - }) - .collect() +pub fn forward(inp: Vec, c: &Constants) -> Vec { + let mut inp = inp.clone(); + let N = inp.len(); + let mut pre = vec![1; N / 2]; + + (1..N / 2).for_each(|i| pre[i] = (pre[i - 1] * c.w).rem_euclid(c.N)); + order_reverse(&mut inp); + + let mut len = 2; + + while len <= N { + let half = len / 2; + let pre_step = N / len; + (0..N).step_by(len).for_each(|i| { + let mut k = 0; + (i..i + half).for_each(|j| { + let l = j + half; + let left = inp[j]; + let right = inp[l] * pre[k]; + inp[j] = (left + right).rem_euclid(c.N); + inp[l] = (left - right).rem_euclid(c.N); + k += pre_step; + }) + }); + len <<= 1; + } + inp } pub fn inverse(inp: Vec, c: &Constants) -> Vec { @@ -146,10 +154,12 @@ pub fn inverse(inp: Vec, c: &Constants) -> Vec { mod tests { use rand::Rng; - use crate::ntt::{extended_gcd, forward, gcd, inverse, working_modulus, Constants}; + use crate::ntt::{extended_gcd, forward, inverse, working_modulus, Constants}; + #[test] fn test_forward() { - let n = rand::thread_rng().gen::().abs() % 10; + // let n = rand::thread_rng().gen::().abs() % 10; + let n = 8; let v: Vec = (0..n) .map(|_| rand::thread_rng().gen::().abs() % (1 << 6)) .collect(); @@ -166,9 +176,4 @@ mod tests { assert_eq!((x * inv) % 11, 1); }); } - #[test] - fn test_gcd() { - assert_eq!(gcd(10, 5), 5); - assert_eq!(gcd(10, 7), 1); - } } diff --git a/src/polynomial.rs b/src/polynomial.rs index 51d5489..5f30515 100644 --- a/src/polynomial.rs +++ b/src/polynomial.rs @@ -13,6 +13,21 @@ pub struct Polynomial { } impl Polynomial { + pub fn new(coef: Vec) -> Self { + let n = coef.len(); + + // if is not power of 2 + if !(n & (n - 1) == 0) { + let pad = n.next_power_of_two() - n; + return Self { + coef: vec![0; pad] + .into_iter() + .chain(coef.into_iter()) + .collect_vec(), + }; + } + Self { coef } + } pub fn diff(mut self) -> Self { let N = self.coef.len(); for n in (1..N).rev() { @@ -70,7 +85,7 @@ impl Mul for Polynomial { fn mul(self, rhs: Polynomial) -> Self::Output { let mut v1 = self.coef; let mut v2 = rhs.coef; - let n = (v1.len() + v2.len()) as i64; + let n = (v1.len() + v2.len()).next_power_of_two() as i64; let v1_deg = v1.len() - 1; let v2_deg = v2.len() - 1; @@ -123,9 +138,13 @@ impl Mul for Polynomial { let coef = inverse(mul, &c) .iter() .map(|&x| if x > M / 2 { -(M - x.rem_euclid(M)) } else { x }) - .collect::>()[..=(v1_deg + v2_deg)] + .collect::>() .to_vec(); - Polynomial { coef } + let start = coef.iter().position(|&x| x != 0).unwrap(); + + Polynomial { + coef: coef[start..(start + v1_deg + v2_deg)].to_vec(), + } } } @@ -135,29 +154,21 @@ mod tests { #[test] fn add() { - let a = Polynomial { - coef: vec![1, 2, 3, 4], - }; - let b = Polynomial { coef: vec![1, 2] }; + let a = Polynomial::new(vec![1, 2, 3, 4]); + let b = Polynomial::new(vec![1, 2]); println!("{:?}", a + b); } #[test] fn mul() { - let a = Polynomial { - coef: vec![1, 2, -3], - }; - let b = Polynomial { - coef: vec![1, -5, 4, -8], - }; + let a = Polynomial::new(vec![1, -2, 3]); + let b = Polynomial::new(vec![1, -3]); println!("{:?}", a * b); } #[test] fn diff() { - let a = Polynomial { - coef: vec![3, 2, 1], - }; + let a = Polynomial::new(vec![3, 2, 1]); let da = a.diff(); println!("{:?}", da); }