From 0c01af944a3b0b7a0c9c4e4ffbd1b2aacba3fc9f Mon Sep 17 00:00:00 2001 From: Kris Nuttycombe Date: Wed, 25 Sep 2024 13:39:46 -0600 Subject: [PATCH] incrementalmerkletree-testing: Always rewind to a checkpoint. The previous semantics of the `rewind` operation would remove the last checkpoint, if any, but would not further modify the tree. However, these semantics are error prone - if you rewind to a checkpoint, you are not able to rewind to the same checkpoint again; also, in practice, it doesn't make sense to shift the location of a checkpoint in the note commitment tree. This change alters `rewind` to (a) take an explicit checkpoint depth, with depth `0` meaning that any state added since the last checkpoint should be discarded, and (b) only allow a rewind operation to succeed if a checkpoint actually exists at the specified depth. --- incrementalmerkletree-testing/CHANGELOG.md | 29 +++ .../src/complete_tree.rs | 85 +++++--- incrementalmerkletree-testing/src/lib.rs | 202 +++++++++++------- 3 files changed, 202 insertions(+), 114 deletions(-) diff --git a/incrementalmerkletree-testing/CHANGELOG.md b/incrementalmerkletree-testing/CHANGELOG.md index 3cac157..1d1a9f3 100644 --- a/incrementalmerkletree-testing/CHANGELOG.md +++ b/incrementalmerkletree-testing/CHANGELOG.md @@ -7,5 +7,34 @@ and this project adheres to Rust's notion of ## Unreleased +This release includes a significant refactoring and rework of several methods +of the `incrementalmerkletree_testing::Tree` trait. Please read the notes for +this release carefully as the semantics of important methods have changed. +These changes may require changes to example tests that rely on this crate; in +particular, additional checkpoints may be required in circumstances where +rewind operations are being applied. + +### Added +- `incrementalmerkletree_testing::Tree::checkpoint_count` + +### Changed +- `incrementalmerkletree_testing::Tree` + - `Tree::root` now takes its `checkpoint_depth` argument as `Option` + instead of `usize`. Passing `None` to this method will now compute the root + given all of the leaves in the tree; if a `Some` value is passed, + implementations of this method must treat the wrapped `usize` as a reverse + index into the checkpoints added to the tree, or return `None` if no + checkpoint exists at the specified index. This effectively modifies this + method to use zero-based indexing instead of a muddled 1-based indexing + scheme. + - `Tree::rewind` now takes an additional `checkpoint_depth` argument, which + is non-optional. Rewinding the tree may now only be performed if there is + a checkpoint at the specified depth to rewind to. This depth should be + treated as a zero-based reverse index into the checkpoints of the tree. + Rewinding no longer removes the checkpoint being rewound to; instead, it + removes the effect all state changes to the tree resulting from + operations performed since the checkpoint was created, but leaves the + checkpoint itself in place. + ## [0.1.0] - 2024-09-25 Initial release. diff --git a/incrementalmerkletree-testing/src/complete_tree.rs b/incrementalmerkletree-testing/src/complete_tree.rs index 2e4e2f7..b4f4460 100644 --- a/incrementalmerkletree-testing/src/complete_tree.rs +++ b/incrementalmerkletree-testing/src/complete_tree.rs @@ -156,11 +156,14 @@ impl CompleteTr } } + // Creates a new checkpoint with the specified identifier and the given tree position; if `pos` + // is not provided, the position of the most recently appended leaf is used, or a new + // checkpoint of the empty tree is added if appropriate. fn checkpoint(&mut self, id: C, pos: Option) { self.checkpoints.insert( id, Checkpoint::at_length(pos.map_or_else( - || 0, + || self.leaves.len(), |p| usize::try_from(p).expect(MAX_COMPLETE_SIZE_ERROR) + 1, )), ); @@ -170,16 +173,12 @@ impl CompleteTr } fn leaves_at_checkpoint_depth(&self, checkpoint_depth: usize) -> Option { - if checkpoint_depth == 0 { - Some(self.leaves.len()) - } else { - self.checkpoints - .iter() - .rev() - .skip(checkpoint_depth - 1) - .map(|(_, c)| c.leaves_len) - .next() - } + self.checkpoints + .iter() + .rev() + .skip(checkpoint_depth) + .map(|(_, c)| c.leaves_len) + .next() } /// Removes the oldest checkpoint. Returns true if successful and false if @@ -237,21 +236,20 @@ impl Option { - self.leaves_at_checkpoint_depth(checkpoint_depth) - .and_then(|len| root(&self.leaves[0..len], DEPTH)) + fn root(&self, checkpoint_depth: Option) -> Option { + checkpoint_depth.map_or_else( + || root(&self.leaves[..], DEPTH), + |depth| { + self.leaves_at_checkpoint_depth(depth) + .and_then(|len| root(&self.leaves[0..len], DEPTH)) + }, + ) } fn witness(&self, position: Position, checkpoint_depth: usize) -> Option> { - if self.marks.contains(&position) && checkpoint_depth <= self.checkpoints.len() { + if self.marks.contains(&position) { let leaves_len = self.leaves_at_checkpoint_depth(checkpoint_depth)?; - let c_idx = self.checkpoints.len() - checkpoint_depth; - if self - .checkpoints - .iter() - .skip(c_idx) - .any(|(_, c)| c.marked.contains(&position)) - { + if u64::from(position) >= u64::try_from(leaves_len).unwrap() { // The requested position was marked after the checkpoint was created, so we // cannot create a witness. None @@ -299,14 +297,35 @@ impl bool { - if let Some((id, c)) = self.checkpoints.iter().rev().next() { - self.leaves.truncate(c.leaves_len); - for pos in c.marked.iter() { - self.marks.remove(pos); + fn checkpoint_count(&self) -> usize { + self.checkpoints.len() + } + + fn rewind(&mut self, depth: usize) -> bool { + if self.checkpoints.len() > depth { + let mut to_delete = vec![]; + for (idx, (id, c)) in self + .checkpoints + .iter_mut() + .rev() + .enumerate() + .take(depth + 1) + { + for pos in c.marked.iter() { + self.marks.remove(pos); + } + if idx < depth { + to_delete.push(id.clone()); + } else { + self.leaves.truncate(c.leaves_len); + c.marked.clear(); + c.forgotten.clear(); + } } - let id = id.clone(); // needed to avoid mutable/immutable borrow conflict - self.checkpoints.remove(&id); + for cid in to_delete.iter() { + self.checkpoints.remove(cid); + } + true } else { false @@ -334,7 +353,7 @@ mod tests { } let tree = CompleteTree::::new(100); - assert_eq!(tree.root(0).unwrap(), expected); + assert_eq!(tree.root(None), Some(expected)); } #[test] @@ -362,7 +381,7 @@ mod tests { ), ); - assert_eq!(tree.root(0).unwrap(), expected); + assert_eq!(tree.root(None), Some(expected)); } #[test] @@ -408,7 +427,9 @@ mod tests { ), ); - assert_eq!(tree.root(0).unwrap(), expected); + assert_eq!(tree.root(None), Some(expected.clone())); + tree.checkpoint((), None); + assert_eq!(tree.root(Some(0)), Some(expected.clone())); for i in 0u64..(1 << DEPTH) { let position = Position::try_from(i).unwrap(); diff --git a/incrementalmerkletree-testing/src/lib.rs b/incrementalmerkletree-testing/src/lib.rs index 2641aa2..b9fc02d 100644 --- a/incrementalmerkletree-testing/src/lib.rs +++ b/incrementalmerkletree-testing/src/lib.rs @@ -34,11 +34,13 @@ pub trait Tree { /// Return a set of all the positions for which we have marked. fn marked_positions(&self) -> BTreeSet; - /// Obtains the root of the Merkle tree at the specified checkpoint depth - /// by hashing against empty nodes up to the maximum height of the tree. - /// Returns `None` if there are not enough checkpoints available to reach the - /// requested checkpoint depth. - fn root(&self, checkpoint_depth: usize) -> Option; + /// Obtains the root of the Merkle tree at the specified checkpoint depth by hashing against + /// empty nodes up to the maximum height of the tree. If the provided checkpoint depth is + /// `None`, the root is computed for all leaves in the tree. + /// + /// Returns `None` if a checkpoint depth is provided but if there are not enough checkpoints + /// available to reach the requested checkpoint depth. + fn root(&self, checkpoint_depth: Option) -> Option; /// Obtains a witness for the value at the specified leaf position, as of the tree state at the /// given checkpoint depth. Returns `None` if there is no witness information for the requested @@ -57,12 +59,19 @@ pub trait Tree { /// less than or equal to the maximum checkpoint identifier observed. fn checkpoint(&mut self, id: C) -> bool; - /// Rewinds the tree state to the previous checkpoint, and then removes that checkpoint record. + /// Returns the number of checkpoints in the tree. + fn checkpoint_count(&self) -> usize; + + /// Rewinds the tree state to the checkpoint at the specified depth. + /// + /// A depth value of `0` removes all data added to the tree after the most recently-added + /// checkpoint. The checkpoint at the specified depth is retained, but all data and metadata + /// related to operations on the tree that occurred after the checkpoint was created is + /// discarded. /// - /// If there are multiple checkpoints at a given tree state, the tree state will not be altered - /// until all checkpoints at that tree state have been removed using `rewind`. This function - /// will return false and leave the tree unmodified if no checkpoints exist. - fn rewind(&mut self) -> bool; + /// Returns `true` if the request was satisfied, or `false` if insufficient checkpoints were + /// available to satisfy the request. + fn rewind(&mut self, checkpoint_depth: usize) -> bool; } // @@ -100,7 +109,7 @@ pub enum Operation { MarkedPositions, Unmark(Position), Checkpoint(C), - Rewind, + Rewind(usize), Witness(Position, usize), GarbageCollect, } @@ -137,8 +146,8 @@ impl Operation { tree.checkpoint(id.clone()); None } - Rewind => { - assert!(tree.rewind(), "rewind failed"); + Rewind(depth) => { + assert_eq!(tree.checkpoint_count() > *depth, tree.rewind(*depth)); None } Witness(p, d) => tree.witness(*p, *d).map(|xs| (*p, xs)), @@ -165,7 +174,7 @@ impl Operation { MarkedPositions => MarkedPositions, Unmark(p) => Unmark(*p), Checkpoint(id) => Checkpoint(f(id)), - Rewind => Rewind, + Rewind(depth) => Rewind(*depth), Witness(p, d) => Witness(*p, *d), GarbageCollect => GarbageCollect, } @@ -209,7 +218,7 @@ where pos_gen.clone().prop_map(Operation::MarkedLeaf), pos_gen.clone().prop_map(Operation::Unmark), Just(Operation::Checkpoint(())), - Just(Operation::Rewind), + (0usize..10).prop_map(Operation::Rewind), pos_gen.prop_flat_map(|i| (0usize..10).prop_map(move |depth| Operation::Witness(i, depth))), ] } @@ -225,8 +234,8 @@ pub fn apply_operation>(tree: &mut T, op: Operation) { Checkpoint(id) => { tree.checkpoint(id); } - Rewind => { - tree.rewind(); + Rewind(depth) => { + tree.rewind(depth); } CurrentPosition => {} Witness(_, _) => {} @@ -283,40 +292,43 @@ pub fn check_operations { - if tree.rewind() { - prop_assert!(!tree_checkpoints.is_empty()); - let checkpointed_tree_size = tree_checkpoints.pop().unwrap(); - tree_values.truncate(checkpointed_tree_size); - tree_size = checkpointed_tree_size; + Rewind(depth) => { + if tree.rewind(*depth) { + let retained = tree_checkpoints.len() - depth; + if *depth > 0 { + // The last checkpoint will have been dropped, and the tree will have been + // truncated to a previous checkpoint. + tree_checkpoints.truncate(retained); + } + + let checkpointed_tree_size = tree_checkpoints + .last() + .expect("at least one checkpoint must exist in order to be able to rewind"); + tree_values.truncate(*checkpointed_tree_size); + tree_size = *checkpointed_tree_size; } } Witness(position, depth) => { if let Some(path) = tree.witness(*position, *depth) { let value: H = tree_values[::try_from(*position).unwrap()].clone(); - let tree_root = tree.root(*depth); - - if tree_checkpoints.len() >= *depth { - let mut extended_tree_values = tree_values.clone(); - if *depth > 0 { - // prune the tree back to the checkpointed size. - if let Some(checkpointed_tree_size) = - tree_checkpoints.get(tree_checkpoints.len() - depth) - { - extended_tree_values.truncate(*checkpointed_tree_size); - } - } - - // compute the root - let expected_root = - complete_tree::root::(&extended_tree_values, tree.depth()); - prop_assert_eq!(&tree_root.unwrap(), &expected_root); - - prop_assert_eq!( - &compute_root_from_witness(value, *position, &path), - &expected_root - ); - } + let tree_root = tree.root(Some(*depth)).expect( + "we must be able to compute a root anywhere we can compute a witness.", + ); + let mut extended_tree_values = tree_values.clone(); + // prune the tree back to the checkpointed size. + let checkpointed_tree_size = + tree_checkpoints[tree_checkpoints.len() - (depth + 1)]; + extended_tree_values.truncate(checkpointed_tree_size); + + // compute the root + let expected_root = + complete_tree::root::(&extended_tree_values, tree.depth()); + prop_assert_eq!(&tree_root, &expected_root); + + prop_assert_eq!( + &compute_root_from_witness(value, *position, &path), + &expected_root + ); } } GarbageCollect => {} @@ -382,7 +394,7 @@ impl, E: Tree> a } - fn root(&self, checkpoint_depth: usize) -> Option { + fn root(&self, checkpoint_depth: Option) -> Option { let a = self.inefficient.root(checkpoint_depth); let b = self.efficient.root(checkpoint_depth); assert_eq!(a, b); @@ -431,9 +443,16 @@ impl, E: Tree> a } - fn rewind(&mut self) -> bool { - let a = self.inefficient.rewind(); - let b = self.efficient.rewind(); + fn checkpoint_count(&self) -> usize { + let a = self.inefficient.checkpoint_count(); + let b = self.efficient.checkpoint_count(); + assert_eq!(a, b); + a + } + + fn rewind(&mut self, checkpoint_depth: usize) -> bool { + let a = self.inefficient.rewind(checkpoint_depth); + let b = self.efficient.rewind(checkpoint_depth); assert_eq!(a, b); a } @@ -478,7 +497,7 @@ impl TestCheckpoint for usize { } trait TestTree { - fn assert_root(&self, checkpoint_depth: usize, values: &[u64]); + fn assert_root(&self, checkpoint_depth: Option, values: &[u64]); fn assert_append(&mut self, value: u64, retention: Retention); @@ -486,7 +505,7 @@ trait TestTree { } impl> TestTree for T { - fn assert_root(&self, checkpoint_depth: usize, values: &[u64]) { + fn assert_root(&self, checkpoint_depth: Option, values: &[u64]) { assert_eq!( self.root(checkpoint_depth).unwrap(), H::combine_all(self.depth(), values) @@ -518,13 +537,13 @@ pub fn check_root_hashes, F: F { let mut tree = new_tree(100); - tree.assert_root(0, &[]); + tree.assert_root(None, &[]); tree.assert_append(0, Ephemeral); - tree.assert_root(0, &[0]); + tree.assert_root(None, &[0]); tree.assert_append(1, Ephemeral); - tree.assert_root(0, &[0, 1]); + tree.assert_root(None, &[0, 1]); tree.assert_append(2, Ephemeral); - tree.assert_root(0, &[0, 1, 2]); + tree.assert_root(None, &[0, 1, 2]); } { @@ -539,7 +558,7 @@ pub fn check_root_hashes, F: F for _ in 0..3 { t.assert_append(0, Ephemeral); } - t.assert_root(0, &[0, 0, 0, 0]); + t.assert_root(None, &[0, 0, 0, 0]); } } @@ -584,12 +603,14 @@ pub fn check_witnesses, F: Fn( let mut tree = new_tree(100); tree.assert_append(0, Ephemeral); tree.assert_append(1, Marked); + tree.checkpoint(C::from_u64(0)); assert_eq!(tree.witness(Position::from(0), 0), None); } { let mut tree = new_tree(100); tree.assert_append(0, Marked); + tree.checkpoint(C::from_u64(0)); assert_eq!( tree.witness(Position::from(0), 0), Some(vec![ @@ -601,6 +622,7 @@ pub fn check_witnesses, F: Fn( ); tree.assert_append(1, Ephemeral); + tree.checkpoint(C::from_u64(1)); assert_eq!( tree.witness(0.into(), 0), Some(vec![ @@ -612,6 +634,7 @@ pub fn check_witnesses, F: Fn( ); tree.assert_append(2, Marked); + tree.checkpoint(C::from_u64(2)); assert_eq!( tree.witness(Position::from(2), 0), Some(vec![ @@ -623,6 +646,7 @@ pub fn check_witnesses, F: Fn( ); tree.assert_append(3, Ephemeral); + tree.checkpoint(C::from_u64(3)); assert_eq!( tree.witness(Position::from(2), 0), Some(vec![ @@ -634,6 +658,7 @@ pub fn check_witnesses, F: Fn( ); tree.assert_append(4, Ephemeral); + tree.checkpoint(C::from_u64(4)); assert_eq!( tree.witness(Position::from(2), 0), Some(vec![ @@ -653,6 +678,7 @@ pub fn check_witnesses, F: Fn( } tree.assert_append(6, Marked); tree.assert_append(7, Ephemeral); + tree.checkpoint(C::from_u64(0)); assert_eq!( tree.witness(0.into(), 0), @@ -674,6 +700,7 @@ pub fn check_witnesses, F: Fn( tree.assert_append(4, Marked); tree.assert_append(5, Marked); tree.assert_append(6, Ephemeral); + tree.checkpoint(C::from_u64(0)); assert_eq!( tree.witness(Position::from(5), 0), @@ -693,6 +720,7 @@ pub fn check_witnesses, F: Fn( } tree.assert_append(10, Marked); tree.assert_append(11, Ephemeral); + tree.checkpoint(C::from_u64(0)); assert_eq!( tree.witness(Position::from(10), 0), @@ -714,7 +742,7 @@ pub fn check_witnesses, F: Fn( marking: Marking::Marked, }, ); - assert!(tree.rewind()); + assert!(tree.rewind(0)); for i in 1..4 { tree.assert_append(i, Ephemeral); } @@ -722,6 +750,7 @@ pub fn check_witnesses, F: Fn( for i in 5..8 { tree.assert_append(i, Ephemeral); } + tree.checkpoint(C::from_u64(2)); assert_eq!( tree.witness(0.into(), 0), Some(vec![ @@ -749,7 +778,7 @@ pub fn check_witnesses, F: Fn( }, ); tree.assert_append(7, Ephemeral); - assert!(tree.rewind()); + assert!(tree.rewind(0)); assert_eq!( tree.witness(Position::from(2), 0), Some(vec![ @@ -770,6 +799,7 @@ pub fn check_witnesses, F: Fn( tree.assert_append(13, Marked); tree.assert_append(14, Ephemeral); tree.assert_append(15, Ephemeral); + tree.checkpoint(C::from_u64(0)); assert_eq!( tree.witness(Position::from(12), 0), @@ -786,7 +816,13 @@ pub fn check_witnesses, F: Fn( let ops = (0..=11) .map(|i| Append(H::from_u64(i), Marked)) .chain(Some(Append(H::from_u64(12), Ephemeral))) - .chain(Some(Append(H::from_u64(13), Ephemeral))) + .chain(Some(Append( + H::from_u64(13), + Checkpoint { + id: C::from_u64(0), + marking: Marking::None, + }, + ))) .chain(Some(Witness(11u64.into(), 0))) .collect::>(); @@ -840,7 +876,7 @@ pub fn check_witnesses, F: Fn( marking: Marking::None, }, ), - Witness(3u64.into(), 5), + Witness(3u64.into(), 4), ]; let mut tree = new_tree(100); assert_eq!( @@ -881,7 +917,7 @@ pub fn check_witnesses, F: Fn( ), Append(H::from_u64(0), Ephemeral), Append(H::from_u64(0), Ephemeral), - Witness(Position::from(3), 1), + Witness(Position::from(3), 0), ]; let mut tree = new_tree(100); assert_eq!( @@ -918,8 +954,7 @@ pub fn check_witnesses, F: Fn( marking: Marking::None, }, ), - Rewind, - Rewind, + Rewind(2), Witness(Position::from(7), 2), ]; let mut tree = new_tree(100); @@ -944,7 +979,7 @@ pub fn check_witnesses, F: Fn( marking: Marking::None, }, ), - Witness(Position::from(2), 2), + Witness(Position::from(2), 1), ]; let mut tree = new_tree(100); assert_eq!( @@ -966,40 +1001,43 @@ pub fn check_checkpoint_rewind, F: Fn(usiz new_tree: F, ) { let mut t = new_tree(100); - assert!(!t.rewind()); + assert!(!t.rewind(0)); let mut t = new_tree(100); t.assert_checkpoint(1); - assert!(t.rewind()); + assert!(t.rewind(0)); + assert!(!t.rewind(1)); let mut t = new_tree(100); t.append("a".to_string(), Retention::Ephemeral); t.assert_checkpoint(1); t.append("b".to_string(), Retention::Marked); - assert!(t.rewind()); + assert_eq!(Some(Position::from(1)), t.current_position()); + assert!(t.rewind(0)); assert_eq!(Some(Position::from(0)), t.current_position()); let mut t = new_tree(100); t.append("a".to_string(), Retention::Marked); t.assert_checkpoint(1); - assert!(t.rewind()); + assert!(t.rewind(0)); + assert_eq!(Some(Position::from(0)), t.current_position()); let mut t = new_tree(100); t.append("a".to_string(), Retention::Marked); t.assert_checkpoint(1); t.append("a".to_string(), Retention::Ephemeral); - assert!(t.rewind()); + assert!(t.rewind(0)); assert_eq!(Some(Position::from(0)), t.current_position()); let mut t = new_tree(100); t.append("a".to_string(), Retention::Ephemeral); t.assert_checkpoint(1); t.assert_checkpoint(2); - assert!(t.rewind()); + assert!(t.rewind(1)); t.append("b".to_string(), Retention::Ephemeral); - assert!(t.rewind()); + assert!(t.rewind(0)); t.append("b".to_string(), Retention::Ephemeral); - assert_eq!(t.root(0).unwrap(), "ab______________"); + assert_eq!(t.root(None).unwrap(), "ab______________"); } pub fn check_remove_mark, F: Fn(usize) -> T>(new_tree: F) { @@ -1044,7 +1082,7 @@ pub fn check_rewind_remove_mark, F: Fn(usi let mut tree = new_tree(100); tree.append("e".to_string(), Retention::Marked); tree.assert_checkpoint(1); - assert!(tree.rewind()); + assert!(tree.rewind(0)); assert!(tree.remove_mark(0u64.into())); // use a maximum number of checkpoints of 1 @@ -1071,14 +1109,14 @@ pub fn check_rewind_remove_mark, F: Fn(usi vec![ append_str("x", Retention::Marked), Checkpoint(C::from_u64(1)), - Rewind, + Rewind(0), unmark(0), ], vec![ append_str("d", Retention::Marked), Checkpoint(C::from_u64(1)), unmark(0), - Rewind, + Rewind(0), unmark(0), ], vec![ @@ -1086,22 +1124,22 @@ pub fn check_rewind_remove_mark, F: Fn(usi Checkpoint(C::from_u64(1)), Checkpoint(C::from_u64(2)), unmark(0), - Rewind, - Rewind, + Rewind(0), + Rewind(1), ], vec![ append_str("s", Retention::Marked), append_str("m", Retention::Ephemeral), Checkpoint(C::from_u64(1)), unmark(0), - Rewind, + Rewind(0), unmark(0), unmark(0), ], vec![ append_str("a", Retention::Marked), Checkpoint(C::from_u64(1)), - Rewind, + Rewind(0), append_str("a", Retention::Marked), ], ]; @@ -1194,7 +1232,7 @@ pub fn check_witness_consistency, F: Fn(us Checkpoint(C::from_u64(1)), unmark(0), Checkpoint(C::from_u64(2)), - Rewind, + Rewind(0), append_str("b", Retention::Ephemeral), witness(0, 0), ], @@ -1202,7 +1240,7 @@ pub fn check_witness_consistency, F: Fn(us append_str("a", Retention::Marked), Checkpoint(C::from_u64(1)), Checkpoint(C::from_u64(2)), - Rewind, + Rewind(1), append_str("a", Retention::Ephemeral), unmark(0), witness(0, 1),