diff --git a/src/optimizer/core/top_n.rs b/src/optimizer/core/top_n.rs index eab8858c..01b59459 100644 --- a/src/optimizer/core/top_n.rs +++ b/src/optimizer/core/top_n.rs @@ -134,11 +134,19 @@ impl ColumnTopN { } } - fn insert_new_with_options(&mut self, capacity: usize, entry: ColumnTopNEntry) { + fn insert_new_with_options(&mut self, capacity: usize, mut entry: ColumnTopNEntry) { if capacity == 0 { return; } + if self.values.len() >= capacity { + let Some(min_entry) = self.prune_min() else { + return; + }; + entry.count = min_entry.count.saturating_add(entry.count); + entry.error = min_entry.count.saturating_add(entry.error); + } + let index = self.find(&entry.value).unwrap_or_else(|index| index); self.values.insert(index, entry); self.on_insert(index); @@ -159,15 +167,12 @@ impl ColumnTopN { } } - fn prune_min(&mut self) -> Option<()> { - let (min_index, next_min_index) = match self.min_index.take() { - Some(index) if index < self.values.len() => (index, None), - _ => self.find_min_and_next_index()?, - }; - self.values.remove(min_index); + fn prune_min(&mut self) -> Option { + let (min_index, next_min_index) = self.find_min_and_next_index()?; + let entry = self.values.remove(min_index); self.min_index = next_min_index.map(|index| if index > min_index { index - 1 } else { index }); - Some(()) + Some(entry) } fn on_insert(&mut self, index: usize) { @@ -232,19 +237,17 @@ mod tests { use crate::types::value::DataValue; #[test] - fn top_n_prunes_to_capacity() { + fn top_n_replaces_min_counter_when_full() { let mut top_n = ColumnTopN::default(); top_n.add_with_size(2, DataValue::Int32(1), 5); top_n.add_with_size(2, DataValue::Int32(2), 3); top_n.add_with_size(2, DataValue::Int32(3), 1); - assert_eq!(top_n.len(), 2); - - top_n.add_with_size(2, DataValue::Int32(4), 1); - assert_eq!(top_n.len(), 2); assert_eq!(top_n.get(&DataValue::Int32(1)), Some(5)); - assert_eq!(top_n.get(&DataValue::Int32(2)), Some(3)); + let entry = top_n.get_entry(&DataValue::Int32(3)).unwrap(); + assert_eq!(entry.count(), 4); + assert_eq!(entry.error(), 3); } #[test] @@ -261,7 +264,9 @@ mod tests { .windows(2) .all(|pair| pair[0].value() < pair[1].value())); assert_eq!(top_n.get(&DataValue::Int32(1)), Some(5)); - assert_eq!(top_n.get(&DataValue::Int32(2)), Some(1)); + let entry = top_n.get_entry(&DataValue::Int32(3)).unwrap(); + assert_eq!(entry.count(), 2); + assert_eq!(entry.error(), 1); } #[test] @@ -280,6 +285,27 @@ mod tests { assert_eq!(left.len(), 2); } + #[test] + fn top_n_merge_records_error_for_absent_entries() { + let mut left = ColumnTopN::default(); + left.add_with_size(2, DataValue::Int32(1), 5); + left.add_with_size(2, DataValue::Int32(2), 3); + + let mut right = ColumnTopN::default(); + right.add_with_size(2, DataValue::Int32(3), 4); + right.add_with_size(2, DataValue::Int32(4), 1); + + left.merge_with_size(right, 2); + + assert_eq!(left.len(), 2); + let entry = left.get_entry(&DataValue::Int32(3)).unwrap(); + assert_eq!(entry.count(), 7); + assert_eq!(entry.error(), 3); + let entry = left.get_entry(&DataValue::Int32(4)).unwrap(); + assert_eq!(entry.count(), 6); + assert_eq!(entry.error(), 5); + } + #[test] fn top_n_skips_null() { let mut top_n = ColumnTopN::default();