diff --git a/crates/sum_tree/src/sum_tree.rs b/crates/sum_tree/src/sum_tree.rs index 193786112b..efd5c43480 100644 --- a/crates/sum_tree/src/sum_tree.rs +++ b/crates/sum_tree/src/sum_tree.rs @@ -5,7 +5,7 @@ use arrayvec::ArrayVec; pub use cursor::{Cursor, FilterCursor, Iter}; use std::marker::PhantomData; use std::{cmp::Ordering, fmt, iter::FromIterator, sync::Arc}; -pub use tree_map::TreeMap; +pub use tree_map::{TreeMap, TreeSet}; #[cfg(test)] const TREE_BASE: usize = 2; diff --git a/crates/sum_tree/src/tree_map.rs b/crates/sum_tree/src/tree_map.rs index 80143aad69..5218d2b4db 100644 --- a/crates/sum_tree/src/tree_map.rs +++ b/crates/sum_tree/src/tree_map.rs @@ -20,6 +20,11 @@ pub struct MapKey(K); #[derive(Clone, Debug, Default)] pub struct MapKeyRef<'a, K>(Option<&'a K>); +#[derive(Clone)] +pub struct TreeSet(TreeMap) +where + K: Clone + Debug + Default + Ord; + impl TreeMap { pub fn from_ordered_entries(entries: impl IntoIterator) -> Self { let tree = SumTree::from_iter( @@ -136,6 +141,32 @@ where } } +impl Default for TreeSet +where + K: Clone + Debug + Default + Ord, +{ + fn default() -> Self { + Self(Default::default()) + } +} + +impl TreeSet +where + K: Clone + Debug + Default + Ord, +{ + pub fn insert(&mut self, key: K) { + self.0.insert(key, ()); + } + + pub fn contains(&self, key: &K) -> bool { + self.0.get(key).is_some() + } + + pub fn iter<'a>(&'a self) -> impl 'a + Iterator { + self.0.iter().map(|(k, _)| k) + } +} + #[cfg(test)] mod tests { use super::*;