diff --git a/src/secp256k1/curve.rs b/src/secp256k1/curve.rs index b295a5cd..f87deb8a 100644 --- a/src/secp256k1/curve.rs +++ b/src/secp256k1/curve.rs @@ -275,63 +275,4 @@ crate::curve_testing_suite!(Secp256k1); crate::curve_testing_suite!(Secp256k1, "endo_consistency"); #[cfg(test)] -mod extra_tests { - use super::*; - use ff::FromUniformBytes; - use rand_core::OsRng; - - #[test] - fn ecdsa_example() { - fn mod_n(x: Fp) -> Fq { - let mut x_repr = [0u8; 32]; - x_repr.copy_from_slice(x.to_repr().as_ref()); - let mut x_bytes = [0u8; 64]; - x_bytes[..32].copy_from_slice(&x_repr[..]); - Fq::from_uniform_bytes(&x_bytes) - } - - let g = Secp256k1::generator(); - - for _ in 0..1000 { - // Generate a key pair - let sk = Fq::random(OsRng); - let pk = (g * sk).to_affine(); - - // Generate a valid signature - // Suppose `m_hash` is the message hash - let msg_hash = Fq::random(OsRng); - - let (r, s) = { - // Draw arandomness - let k = Fq::random(OsRng); - let k_inv = k.invert().unwrap(); - - // Calculate `r` - let r_point = (g * k).to_affine().coordinates().unwrap(); - let x = r_point.x(); - let r = mod_n(*x); - - // Calculate `s` - let s = k_inv * (msg_hash + (r * sk)); - - (r, s) - }; - - { - // Verify - let s_inv = s.invert().unwrap(); - let u_1 = msg_hash * s_inv; - let u_2 = r * s_inv; - - let v_1 = g * u_1; - let v_2 = pk * u_2; - - let r_point = (v_1 + v_2).to_affine().coordinates().unwrap(); - let x_candidate = r_point.x(); - let r_candidate = mod_n(*x_candidate); - - assert_eq!(r, r_candidate); - } - } - } -} +crate::curve_testing_suite!(Secp256k1, "ecdsa_example"); diff --git a/src/secp256r1/curve.rs b/src/secp256r1/curve.rs index b225f313..0b71b1f1 100644 --- a/src/secp256r1/curve.rs +++ b/src/secp256r1/curve.rs @@ -95,65 +95,4 @@ impl Secp256r1 { crate::curve_testing_suite!(Secp256r1); #[cfg(test)] -mod extra_tests { - use super::*; - use crate::group::Curve; - use crate::secp256r1::{Fp, Fq, Secp256r1}; - use ff::FromUniformBytes; - use rand_core::OsRng; - - #[test] - fn ecdsa_example() { - fn mod_n(x: Fp) -> Fq { - let mut x_repr = [0u8; 32]; - x_repr.copy_from_slice(x.to_repr().as_ref()); - let mut x_bytes = [0u8; 64]; - x_bytes[..32].copy_from_slice(&x_repr[..]); - Fq::from_uniform_bytes(&x_bytes) - } - - let g = Secp256r1::generator(); - - for _ in 0..1000 { - // Generate a key pair - let sk = Fq::random(OsRng); - let pk = (g * sk).to_affine(); - - // Generate a valid signature - // Suppose `m_hash` is the message hash - let msg_hash = Fq::random(OsRng); - - let (r, s) = { - // Draw arandomness - let k = Fq::random(OsRng); - let k_inv = k.invert().unwrap(); - - // Calculate `r` - let r_point = (g * k).to_affine().coordinates().unwrap(); - let x = r_point.x(); - let r = mod_n(*x); - - // Calculate `s` - let s = k_inv * (msg_hash + (r * sk)); - - (r, s) - }; - - { - // Verify - let s_inv = s.invert().unwrap(); - let u_1 = msg_hash * s_inv; - let u_2 = r * s_inv; - - let v_1 = g * u_1; - let v_2 = pk * u_2; - - let r_point = (v_1 + v_2).to_affine().coordinates().unwrap(); - let x_candidate = r_point.x(); - let r_candidate = mod_n(*x_candidate); - - assert_eq!(r, r_candidate); - } - } - } -} +crate::curve_testing_suite!(Secp256r1, "ecdsa_example"); diff --git a/src/tests/curve.rs b/src/tests/curve.rs index 5fda91de..f04c1228 100644 --- a/src/tests/curve.rs +++ b/src/tests/curve.rs @@ -465,4 +465,64 @@ macro_rules! curve_testing_suite { } } }; + + ($curve: ident, "ecdsa_example") => { + #[test] + fn ecdsa_example() { + use ff::FromUniformBytes; + use rand_core::OsRng; + + fn mod_n(x: <$curve as CurveExt>::Base) -> <$curve as CurveExt>::ScalarExt { + let mut x_repr = [0u8; 32]; + x_repr.copy_from_slice(x.to_repr().as_ref()); + let mut x_bytes = [0u8; 64]; + x_bytes[..32].copy_from_slice(&x_repr[..]); + <$curve as CurveExt>::ScalarExt::from_uniform_bytes(&x_bytes) + } + + let g = $curve::generator(); + + for _ in 0..1000 { + // Generate a key pair + let sk = <$curve as CurveExt>::ScalarExt::random(OsRng); + let pk = (g * sk).to_affine(); + + // Generate a valid signature + // Suppose `m_hash` is the message hash + let msg_hash = <$curve as CurveExt>::ScalarExt::random(OsRng); + + let (r, s) = { + // Draw arandomness + let k = <$curve as CurveExt>::ScalarExt::random(OsRng); + let k_inv = k.invert().unwrap(); + + // Calculate `r` + let r_point = (g * k).to_affine().coordinates().unwrap(); + let x = r_point.x(); + let r = mod_n(*x); + + // Calculate `s` + let s = k_inv * (msg_hash + (r * sk)); + + (r, s) + }; + + { + // Verify + let s_inv = s.invert().unwrap(); + let u_1 = msg_hash * s_inv; + let u_2 = r * s_inv; + + let v_1 = g * u_1; + let v_2 = pk * u_2; + + let r_point = (v_1 + v_2).to_affine().coordinates().unwrap(); + let x_candidate = r_point.x(); + let r_candidate = mod_n(*x_candidate); + + assert_eq!(r, r_candidate); + } + } + } + } }