Skip to content

Commit

Permalink
feat: radix2 fft
Browse files Browse the repository at this point in the history
  • Loading branch information
0xWOLAND committed Oct 21, 2023
1 parent af170d8 commit 04a1328
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 57 deletions.
87 changes: 46 additions & 41 deletions src/ntt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,6 @@ pub struct Constants {
pub w: i64,
}

fn gcd(a: i64, b: i64) -> i64 {
let mut a = a;
let mut b = b;

if a < b {
swap(&mut a, &mut b);
}
if a % b == 0 {
return b;
}
while b > 0 {
a = a % b;
swap(&mut a, &mut b);
}
a
}

fn extended_gcd(a: i64, b: i64) -> i64 {
let mut a = a;
let mut b = b;
Expand Down Expand Up @@ -95,25 +78,50 @@ pub fn working_modulus(n: i64, M: i64) -> Constants {
Constants { k, N, w }
}

pub fn forward(inp: Vec<i64>, c: &Constants) -> Vec<i64> {
let mut pre = vec![-1; inp.len().pow(2)];
(0..inp.len()).for_each(|col| {
(0..=col).for_each(|row| {
if pre[row * inp.len() + col] == -1 {
pre[row * inp.len() + col] = mod_exp(c.w, (row * col) as i64, c.N) as i64;
}
})
fn order_reverse(inp: &mut Vec<i64>) {
let mut j = 0;
let n = inp.len();
(1..n).for_each(|i| {
let mut bit = n >> 1;
while (j & bit) > 0 {
j ^= bit;
bit >>= 1;
}
j ^= bit;

if i < j {
inp.swap(i, j);
}
});
}

(0..inp.len())
.map(|k| {
inp.iter().enumerate().fold(0, |acc, (i, cur)| {
let row = k.min(i);
let col = k.max(i);
(acc + cur * pre[row * inp.len() + col]) % c.N as i64
}) % c.N as i64
})
.collect()
pub fn forward(inp: Vec<i64>, c: &Constants) -> Vec<i64> {
let mut inp = inp.clone();
let N = inp.len();
let mut pre = vec![1; N / 2];

(1..N / 2).for_each(|i| pre[i] = (pre[i - 1] * c.w).rem_euclid(c.N));
order_reverse(&mut inp);

let mut len = 2;

while len <= N {
let half = len / 2;
let pre_step = N / len;
(0..N).step_by(len).for_each(|i| {
let mut k = 0;
(i..i + half).for_each(|j| {
let l = j + half;
let left = inp[j];
let right = inp[l] * pre[k];
inp[j] = (left + right).rem_euclid(c.N);
inp[l] = (left - right).rem_euclid(c.N);
k += pre_step;
})
});
len <<= 1;
}
inp
}

pub fn inverse(inp: Vec<i64>, c: &Constants) -> Vec<i64> {
Expand Down Expand Up @@ -146,10 +154,12 @@ pub fn inverse(inp: Vec<i64>, c: &Constants) -> Vec<i64> {
mod tests {
use rand::Rng;

use crate::ntt::{extended_gcd, forward, gcd, inverse, working_modulus, Constants};
use crate::ntt::{extended_gcd, forward, inverse, working_modulus, Constants};

#[test]
fn test_forward() {
let n = rand::thread_rng().gen::<i64>().abs() % 10;
// let n = rand::thread_rng().gen::<i64>().abs() % 10;
let n = 8;
let v: Vec<i64> = (0..n)
.map(|_| rand::thread_rng().gen::<i64>().abs() % (1 << 6))
.collect();
Expand All @@ -166,9 +176,4 @@ mod tests {
assert_eq!((x * inv) % 11, 1);
});
}
#[test]
fn test_gcd() {
assert_eq!(gcd(10, 5), 5);
assert_eq!(gcd(10, 7), 1);
}
}
43 changes: 27 additions & 16 deletions src/polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,21 @@ pub struct Polynomial {
}

impl Polynomial {
pub fn new(coef: Vec<i64>) -> Self {
let n = coef.len();

// if is not power of 2
if !(n & (n - 1) == 0) {
let pad = n.next_power_of_two() - n;
return Self {
coef: vec![0; pad]
.into_iter()
.chain(coef.into_iter())
.collect_vec(),
};
}
Self { coef }
}
pub fn diff(mut self) -> Self {
let N = self.coef.len();
for n in (1..N).rev() {
Expand Down Expand Up @@ -70,7 +85,7 @@ impl Mul<Polynomial> for Polynomial {
fn mul(self, rhs: Polynomial) -> Self::Output {
let mut v1 = self.coef;
let mut v2 = rhs.coef;
let n = (v1.len() + v2.len()) as i64;
let n = (v1.len() + v2.len()).next_power_of_two() as i64;
let v1_deg = v1.len() - 1;
let v2_deg = v2.len() - 1;

Expand Down Expand Up @@ -123,9 +138,13 @@ impl Mul<Polynomial> for Polynomial {
let coef = inverse(mul, &c)
.iter()
.map(|&x| if x > M / 2 { -(M - x.rem_euclid(M)) } else { x })
.collect::<Vec<i64>>()[..=(v1_deg + v2_deg)]
.collect::<Vec<i64>>()
.to_vec();
Polynomial { coef }
let start = coef.iter().position(|&x| x != 0).unwrap();

Polynomial {
coef: coef[start..(start + v1_deg + v2_deg)].to_vec(),
}
}
}

Expand All @@ -135,29 +154,21 @@ mod tests {

#[test]
fn add() {
let a = Polynomial {
coef: vec![1, 2, 3, 4],
};
let b = Polynomial { coef: vec![1, 2] };
let a = Polynomial::new(vec![1, 2, 3, 4]);
let b = Polynomial::new(vec![1, 2]);
println!("{:?}", a + b);
}

#[test]
fn mul() {
let a = Polynomial {
coef: vec![1, 2, -3],
};
let b = Polynomial {
coef: vec![1, -5, 4, -8],
};
let a = Polynomial::new(vec![1, -2, 3]);
let b = Polynomial::new(vec![1, -3]);
println!("{:?}", a * b);
}

#[test]
fn diff() {
let a = Polynomial {
coef: vec![3, 2, 1],
};
let a = Polynomial::new(vec![3, 2, 1]);
let da = a.diff();
println!("{:?}", da);
}
Expand Down

0 comments on commit 04a1328

Please sign in to comment.