Add simple unit test for SumTree::{edit,get}

This commit is contained in:
Antonio Scandurra 2021-04-16 16:26:40 +02:00
parent 457d945376
commit b68b0fce56

View file

@ -337,7 +337,7 @@ impl<T: KeyedItem> SumTree<T> {
return Vec::new(); return Vec::new();
} }
let mut replaced = Vec::new(); let mut removed = Vec::new();
edits.sort_unstable_by_key(|item| item.key()); edits.sort_unstable_by_key(|item| item.key());
*self = { *self = {
@ -362,7 +362,7 @@ impl<T: KeyedItem> SumTree<T> {
if let Some(old_item) = old_item { if let Some(old_item) = old_item {
if old_item.key() == new_key { if old_item.key() == new_key {
replaced.push(old_item.clone()); removed.push(old_item.clone());
cursor.next(); cursor.next();
} }
} }
@ -380,7 +380,7 @@ impl<T: KeyedItem> SumTree<T> {
new_tree new_tree
}; };
replaced removed
} }
pub fn get(&self, key: &T::Key) -> Option<&T> { pub fn get(&self, key: &T::Key) -> Option<&T> {
@ -497,6 +497,7 @@ where
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use std::cmp;
use std::ops::Add; use std::ops::Add;
#[test] #[test]
@ -780,11 +781,33 @@ mod tests {
assert_eq!(cursor.slice(&Count(6), SeekBias::Right).items(), vec![6]); assert_eq!(cursor.slice(&Count(6), SeekBias::Right).items(), vec![6]);
} }
#[test]
fn test_edit() {
let mut tree = SumTree::<u8>::new();
let removed = tree.edit(vec![Edit::Insert(1), Edit::Insert(2), Edit::Insert(0)]);
assert_eq!(tree.items(), vec![0, 1, 2]);
assert_eq!(removed, Vec::<u8>::new());
assert_eq!(tree.get(&0), Some(&0));
assert_eq!(tree.get(&1), Some(&1));
assert_eq!(tree.get(&2), Some(&2));
assert_eq!(tree.get(&4), None);
let removed = tree.edit(vec![Edit::Insert(2), Edit::Insert(4), Edit::Remove(0)]);
assert_eq!(tree.items(), vec![1, 2, 4]);
assert_eq!(removed, vec![0, 2]);
assert_eq!(tree.get(&0), None);
assert_eq!(tree.get(&1), Some(&1));
assert_eq!(tree.get(&2), Some(&2));
assert_eq!(tree.get(&4), Some(&4));
}
#[derive(Clone, Default, Debug)] #[derive(Clone, Default, Debug)]
pub struct IntegersSummary { pub struct IntegersSummary {
count: Count, count: Count,
sum: Sum, sum: Sum,
contains_even: bool, contains_even: bool,
max: u8,
} }
#[derive(Ord, PartialOrd, Default, Eq, PartialEq, Clone, Debug)] #[derive(Ord, PartialOrd, Default, Eq, PartialEq, Clone, Debug)]
@ -801,15 +824,31 @@ mod tests {
count: Count(1), count: Count(1),
sum: Sum(*self as usize), sum: Sum(*self as usize),
contains_even: (*self & 1) == 0, contains_even: (*self & 1) == 0,
max: *self,
} }
} }
} }
impl KeyedItem for u8 {
type Key = u8;
fn key(&self) -> Self::Key {
*self
}
}
impl<'a> Dimension<'a, IntegersSummary> for u8 {
fn add_summary(&mut self, summary: &IntegersSummary) {
*self = summary.max;
}
}
impl<'a> AddAssign<&'a Self> for IntegersSummary { impl<'a> AddAssign<&'a Self> for IntegersSummary {
fn add_assign(&mut self, other: &Self) { fn add_assign(&mut self, other: &Self) {
self.count.0 += &other.count.0; self.count.0 += &other.count.0;
self.sum.0 += &other.sum.0; self.sum.0 += &other.sum.0;
self.contains_even |= other.contains_even; self.contains_even |= other.contains_even;
self.max = cmp::max(self.max, other.max);
} }
} }
@ -819,15 +858,6 @@ mod tests {
} }
} }
// impl<'a> Add<&'a Self> for Count {
// type Output = Self;
//
// fn add(mut self, other: &Self) -> Self {
// self.0 += other.0;
// self
// }
// }
impl<'a> Dimension<'a, IntegersSummary> for Sum { impl<'a> Dimension<'a, IntegersSummary> for Sum {
fn add_summary(&mut self, summary: &IntegersSummary) { fn add_summary(&mut self, summary: &IntegersSummary) {
self.0 += summary.sum.0; self.0 += summary.sum.0;