Skip to content

Commit

Permalink
remove some more async_trait and tokio::test
Browse files Browse the repository at this point in the history
  • Loading branch information
woocash2 committed Jul 26, 2023
1 parent 98bbde3 commit a75cfab
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 76 deletions.
71 changes: 21 additions & 50 deletions consensus/src/runway/backup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ mod tests {
const NODE_ID: NodeIndex = NodeIndex(0);
const N_MEMBERS: NodeCount = NodeCount(4);

async fn produce_units(rounds: usize, session_id: SessionId) -> Vec<Vec<UncheckedSignedUnit>> {
fn produce_units(rounds: usize, session_id: SessionId) -> Vec<Vec<UncheckedSignedUnit>> {
let mut creators = creator_set(N_MEMBERS);
let keychains: Vec<_> = (0..N_MEMBERS.0)
.map(|id| Keychain::new(N_MEMBERS, NodeIndex(id)))
Expand Down Expand Up @@ -346,7 +346,7 @@ mod tests {
units.iter().map(|u| u.encode()).collect()
}

async fn prepare_test<'a>(
fn prepare_test(
encoded_units: Vec<u8>,
) -> (
impl futures::Future,
Expand Down Expand Up @@ -377,7 +377,7 @@ mod tests {
#[tokio::test]
async fn nothing_loaded_nothing_collected_succeeds() {
let (task, loaded_unit_rx, highest_response_tx, starting_round_rx) =
prepare_test(Vec::new()).await;
prepare_test(Vec::new());

let handle = tokio::spawn(async {
task.await;
Expand All @@ -393,15 +393,11 @@ mod tests {

#[tokio::test]
async fn something_loaded_nothing_collected_succeeds() {
let units: Vec<_> = produce_units(5, SESSION_ID)
.await
.into_iter()
.flatten()
.collect();
let units: Vec<_> = produce_units(5, SESSION_ID).into_iter().flatten().collect();
let encoded_units = encode_all(units.clone()).into_iter().flatten().collect();

let (task, loaded_unit_rx, highest_response_tx, starting_round_rx) =
prepare_test(encoded_units).await;
prepare_test(encoded_units);

let handle = tokio::spawn(async {
task.await;
Expand All @@ -417,15 +413,11 @@ mod tests {

#[tokio::test]
async fn something_loaded_something_collected_succeeds() {
let units: Vec<_> = produce_units(5, SESSION_ID)
.await
.into_iter()
.flatten()
.collect();
let units: Vec<_> = produce_units(5, SESSION_ID).into_iter().flatten().collect();
let encoded_units = encode_all(units.clone()).into_iter().flatten().collect();

let (task, loaded_unit_rx, highest_response_tx, starting_round_rx) =
prepare_test(encoded_units).await;
prepare_test(encoded_units);

let handle = tokio::spawn(async {
task.await;
Expand All @@ -442,7 +434,7 @@ mod tests {
#[tokio::test]
async fn nothing_loaded_something_collected_fails() {
let (task, loaded_unit_rx, highest_response_tx, starting_round_rx) =
prepare_test(Vec::new()).await;
prepare_test(Vec::new());

let handle = tokio::spawn(async {
task.await;
Expand All @@ -458,15 +450,11 @@ mod tests {

#[tokio::test]
async fn loaded_smaller_then_collected_fails() {
let units: Vec<_> = produce_units(3, SESSION_ID)
.await
.into_iter()
.flatten()
.collect();
let units: Vec<_> = produce_units(3, SESSION_ID).into_iter().flatten().collect();
let encoded_units = encode_all(units.clone()).into_iter().flatten().collect();

let (task, loaded_unit_rx, highest_response_tx, starting_round_rx) =
prepare_test(encoded_units).await;
prepare_test(encoded_units);

let handle = tokio::spawn(async {
task.await;
Expand All @@ -482,15 +470,11 @@ mod tests {

#[tokio::test]
async fn dropped_collection_fails() {
let units: Vec<_> = produce_units(3, SESSION_ID)
.await
.into_iter()
.flatten()
.collect();
let units: Vec<_> = produce_units(3, SESSION_ID).into_iter().flatten().collect();
let encoded_units = encode_all(units.clone()).into_iter().flatten().collect();

let (task, loaded_unit_rx, highest_response_tx, starting_round_rx) =
prepare_test(encoded_units).await;
prepare_test(encoded_units);

let handle = tokio::spawn(async {
task.await;
Expand All @@ -506,18 +490,14 @@ mod tests {

#[tokio::test]
async fn backup_with_corrupted_encoding_fails() {
let units = produce_units(5, SESSION_ID)
.await
.into_iter()
.flatten()
.collect();
let units = produce_units(5, SESSION_ID).into_iter().flatten().collect();
let mut unit_encodings = encode_all(units);
let unit2_encoding_len = unit_encodings[2].len();
unit_encodings[2].resize(unit2_encoding_len - 1, 0); // remove the last byte
let encoded_units = unit_encodings.into_iter().flatten().collect();

let (task, loaded_unit_rx, highest_response_tx, starting_round_rx) =
prepare_test(encoded_units).await;
prepare_test(encoded_units);
let handle = tokio::spawn(async {
task.await;
});
Expand All @@ -532,16 +512,12 @@ mod tests {

#[tokio::test]
async fn backup_with_missing_parent_fails() {
let mut units: Vec<_> = produce_units(5, SESSION_ID)
.await
.into_iter()
.flatten()
.collect();
let mut units: Vec<_> = produce_units(5, SESSION_ID).into_iter().flatten().collect();
units.remove(2); // it is a parent of all units of round 3
let encoded_units = encode_all(units).into_iter().flatten().collect();

let (task, loaded_unit_rx, highest_response_tx, starting_round_rx) =
prepare_test(encoded_units).await;
prepare_test(encoded_units);
let handle = tokio::spawn(async {
task.await;
});
Expand All @@ -556,17 +532,13 @@ mod tests {

#[tokio::test]
async fn backup_with_duplicate_unit_succeeds() {
let mut units: Vec<_> = produce_units(5, SESSION_ID)
.await
.into_iter()
.flatten()
.collect();
let mut units: Vec<_> = produce_units(5, SESSION_ID).into_iter().flatten().collect();
let unit2_duplicate = units[2].clone();
units.insert(3, unit2_duplicate);
let encoded_units = encode_all(units.clone()).into_iter().flatten().collect();

let (task, loaded_unit_rx, highest_response_tx, starting_round_rx) =
prepare_test(encoded_units).await;
prepare_test(encoded_units);

let handle = tokio::spawn(async {
task.await;
Expand All @@ -582,11 +554,11 @@ mod tests {

#[tokio::test]
async fn backup_with_units_of_one_creator_fails() {
let units = units_of_creator(produce_units(5, SESSION_ID).await, NodeIndex(NODE_ID.0 + 1));
let units = units_of_creator(produce_units(5, SESSION_ID), NodeIndex(NODE_ID.0 + 1));
let encoded_units = encode_all(units).into_iter().flatten().collect();

let (task, loaded_unit_rx, highest_response_tx, starting_round_rx) =
prepare_test(encoded_units).await;
prepare_test(encoded_units);

let handle = tokio::spawn(async {
task.await;
Expand All @@ -603,14 +575,13 @@ mod tests {
#[tokio::test]
async fn backup_with_wrong_session_fails() {
let units = produce_units(5, SESSION_ID + 1)
.await
.into_iter()
.flatten()
.collect();
let encoded_units = encode_all(units).into_iter().flatten().collect();

let (task, loaded_unit_rx, highest_response_tx, starting_round_rx) =
prepare_test(encoded_units).await;
prepare_test(encoded_units);

let handle = tokio::spawn(async {
task.await;
Expand Down
2 changes: 0 additions & 2 deletions consensus/src/testing/byzantine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use crate::{
Recipient, Round, SessionId, Signed, SpawnHandle, TaskHandle,
};
use aleph_bft_mock::{Data, Hash64, Hasher64, Keychain, NetworkHook, Router, Spawner};
use async_trait::async_trait;
use futures::{channel::oneshot, StreamExt};
use log::{debug, error, trace};
use parking_lot::Mutex;
Expand Down Expand Up @@ -223,7 +222,6 @@ impl AlertHook {
}
}

#[async_trait]
impl NetworkHook<NetworkData> for AlertHook {
fn update_state(&mut self, data: &mut NetworkData, sender: NodeIndex, recipient: NodeIndex) {
use crate::{alerts::AlertMessage::*, network::NetworkDataInner::*};
Expand Down
2 changes: 0 additions & 2 deletions consensus/src/testing/unreliable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use crate::{
Index, NodeCount, NodeIndex, Round, Signed, SpawnHandle,
};
use aleph_bft_mock::{BadSigning, Keychain, NetworkHook, Router, Spawner};
use async_trait::async_trait;
use futures::StreamExt;
use parking_lot::Mutex;
use std::sync::Arc;
Expand Down Expand Up @@ -40,7 +39,6 @@ struct NoteRequest {
requested: Arc<Mutex<bool>>,
}

#[async_trait]
impl NetworkHook<NetworkData> for NoteRequest {
fn update_state(&mut self, data: &mut NetworkData, sender: NodeIndex, _: NodeIndex) {
use NetworkDataInner::Units;
Expand Down
21 changes: 8 additions & 13 deletions crypto/src/signature.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use crate::{Index, NodeCount, NodeIndex, NodeMap};
use async_trait::async_trait;
use codec::{Codec, Decode, Encode};
use log::warn;
use std::{fmt::Debug, hash::Hash};
Expand All @@ -18,7 +17,6 @@ impl<T: Debug + Clone + Codec + Send + Sync + Eq + 'static> Signature for T {}
/// The meaning of sign is then to produce a signature `s` using the given private key,
/// and `verify(msg, s, j)` is to verify whether the signature s under the message msg is
/// correct with respect to the public key of the jth node.
#[async_trait]
pub trait Keychain: Index + Clone + Send + Sync + 'static {
type Signature: Signature;

Expand Down Expand Up @@ -430,7 +428,6 @@ mod tests {
Index, Keychain, MultiKeychain, NodeCount, NodeIndex, PartialMultisignature,
PartiallyMultisigned, Signable, SignatureSet, Signed,
};
use async_trait::async_trait;
use codec::{Decode, Encode};
use std::fmt::Debug;

Expand Down Expand Up @@ -460,7 +457,6 @@ mod tests {
}
}

#[async_trait::async_trait]
impl<K: Keychain> Keychain for DefaultMultiKeychain<K> {
type Signature = K::Signature;

Expand Down Expand Up @@ -545,7 +541,6 @@ mod tests {
}
}

#[async_trait]
impl Keychain for TestKeychain {
type Signature = TestSignature;

Expand All @@ -572,8 +567,8 @@ mod tests {
DefaultMultiKeychain::new(keychain)
}

#[tokio::test]
async fn test_valid_signatures() {
#[test]
fn test_valid_signatures() {
let node_count: NodeCount = 7.into();
let keychains: Vec<TestMultiKeychain> = (0_usize..node_count.0)
.map(|i| test_multi_keychain(node_count, i.into()))
Expand All @@ -591,8 +586,8 @@ mod tests {
}
}

#[tokio::test]
async fn test_invalid_signatures() {
#[test]
fn test_invalid_signatures() {
let node_count: NodeCount = 1.into();
let index: NodeIndex = 0.into();
let keychain = test_multi_keychain(node_count, index);
Expand All @@ -607,8 +602,8 @@ mod tests {
);
}

#[tokio::test]
async fn test_incomplete_multisignature() {
#[test]
fn test_incomplete_multisignature() {
let msg = test_message();
let index: NodeIndex = 0.into();
let node_count: NodeCount = 2.into();
Expand All @@ -621,8 +616,8 @@ mod tests {
);
}

#[tokio::test]
async fn test_multisignatures() {
#[test]
fn test_multisignatures() {
let msg = test_message();
let node_count: NodeCount = 7.into();
let keychains: Vec<TestMultiKeychain> = (0..node_count.0)
Expand Down
1 change: 0 additions & 1 deletion fuzz/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,6 @@ impl<W: Write> SpyingNetworkHook<W> {
}
}

#[async_trait::async_trait]
impl<W: Write + Send> NetworkHook<FuzzNetworkData> for SpyingNetworkHook<W> {
fn update_state(&mut self, data: &mut FuzzNetworkData, _: NodeIndex, recipient: NodeIndex) {
if self.node == recipient {
Expand Down
2 changes: 0 additions & 2 deletions mock/src/crypto/keychain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use aleph_bft_types::{
Index, Keychain as KeychainT, MultiKeychain as MultiKeychainT, NodeCount, NodeIndex,
PartialMultisignature as PartialMultisignatureT, SignatureSet,
};
use async_trait::async_trait;

#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default)]
pub struct Keychain {
Expand Down Expand Up @@ -33,7 +32,6 @@ impl Index for Keychain {
}
}

#[async_trait]
impl KeychainT for Keychain {
type Signature = Signature;

Expand Down
10 changes: 4 additions & 6 deletions mock/src/crypto/wrappers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use crate::crypto::{PartialMultisignature, Signature};
use aleph_bft_types::{
Index, Keychain as KeychainT, MultiKeychain as MultiKeychainT, NodeCount, NodeIndex,
};
use async_trait::async_trait;
use codec::{Decode, Encode};
use std::fmt::Debug;

Expand Down Expand Up @@ -34,21 +33,20 @@ impl<T: MK> Index for BadSigning<T> {
}
}

#[async_trait]
impl<T: MK> KeychainT for BadSigning<T> {
type Signature = T::Signature;

fn node_count(&self) -> NodeCount {
self.0.node_count()
}

fn sign(&self, msg: &[u8]) -> Self::Signature {
let signature = self.0.sign(msg);
let mut msg = b"BAD".to_vec();
msg.extend(signature.msg().clone());
Signature::new(msg, signature.index())
}

fn node_count(&self) -> NodeCount {
self.0.node_count()
}

fn verify(&self, msg: &[u8], sgn: &Self::Signature, index: NodeIndex) -> bool {
self.0.verify(msg, sgn, index)
}
Expand Down

0 comments on commit a75cfab

Please sign in to comment.