diff --git a/src/optimizer/core/histogram.rs b/src/optimizer/core/histogram.rs index 2141a96c..883ab1c4 100644 --- a/src/optimizer/core/histogram.rs +++ b/src/optimizer/core/histogram.rs @@ -367,6 +367,23 @@ impl Histogram { } } + fn top_n_count_or_fallback( + &self, + value: &DataValue, + sketch: &CountMinSketch, + top_n: &ColumnTopN, + ) -> usize { + let Some(entry) = top_n.get_entry(value) else { + return self.equal_count(value, sketch); + }; + if entry.error() == 0 || entry.error() < self.average_count() { + return entry.count(); + } + + let lower = entry.count().saturating_sub(entry.error()); + self.equal_count(value, sketch).clamp(lower, entry.count()) + } + pub fn collect_count( &self, ranges: &[Range], @@ -604,10 +621,8 @@ impl Histogram { Range::Eq(value) => { *count += if value.is_null() { self.meta.null_count - } else if let Some(count) = top_n.get(value) { - count } else { - self.equal_count(value, sketch) + self.top_n_count_or_fallback(value, sketch, top_n) }; *binary_i += 1 } @@ -1130,6 +1145,36 @@ mod tests { Ok(()) } + #[test] + fn test_eq_count_falls_back_when_top_n_error_is_large() -> Result<(), DatabaseError> { + let mut builder = HistogramBuilder::new(&index_meta(), ANALYZE_STATISTICS_RELATIVE_ERROR)?; + + for value in 0..10_000 { + builder.append(DataValue::Int32(value))?; + } + + let (histogram, sketch, _) = builder.build(100)?; + + let mut top_n = ColumnTopN::default(); + top_n.add_with_size(1, DataValue::Int32(1), 10); + top_n.add_with_size(1, DataValue::Int32(7), 1); + let entry = top_n.get_entry(&DataValue::Int32(7)).unwrap(); + assert_eq!(entry.count(), 11); + assert_eq!(entry.error(), 10); + + let fallback = histogram.equal_count(&DataValue::Int32(7), &sketch); + let lower = entry.count().saturating_sub(entry.error()); + let expected = fallback.clamp(lower, entry.count()); + assert!(expected < entry.count()); + + assert_eq!( + histogram.collect_count(&[Range::Eq(DataValue::Int32(7))], &sketch, &top_n)?, + expected + ); + + Ok(()) + } + #[test] fn test_collect_count_ignores_tuple_prefix_endpoint_count() -> Result<(), DatabaseError> { let mut builder = HistogramBuilder::new(&index_meta(), ANALYZE_STATISTICS_RELATIVE_ERROR)?; diff --git a/src/optimizer/core/top_n.rs b/src/optimizer/core/top_n.rs index eab8858c..62773c5e 100644 --- a/src/optimizer/core/top_n.rs +++ b/src/optimizer/core/top_n.rs @@ -70,14 +70,6 @@ impl ColumnTopN { self.add_with_options(top_n_size, value, count, 0); } - pub fn merge_with_size(&mut self, other: ColumnTopN, top_n_size: usize) { - for entry in other.values { - if self.should_insert(&entry.value, entry.count, entry.error) { - self.insert_new_with_options(top_n_size, entry); - } - } - } - pub fn finish_with_size(mut self, top_n_size: usize) -> Self { self.prune_to_capacity(top_n_size); self @@ -134,11 +126,19 @@ impl ColumnTopN { } } - fn insert_new_with_options(&mut self, capacity: usize, entry: ColumnTopNEntry) { + fn insert_new_with_options(&mut self, capacity: usize, mut entry: ColumnTopNEntry) { if capacity == 0 { return; } + if self.values.len() >= capacity { + let Some(min_entry) = self.prune_min() else { + return; + }; + entry.count = min_entry.count.saturating_add(entry.count); + entry.error = min_entry.count.saturating_add(entry.error); + } + let index = self.find(&entry.value).unwrap_or_else(|index| index); self.values.insert(index, entry); self.on_insert(index); @@ -159,15 +159,12 @@ impl ColumnTopN { } } - fn prune_min(&mut self) -> Option<()> { - let (min_index, next_min_index) = match self.min_index.take() { - Some(index) if index < self.values.len() => (index, None), - _ => self.find_min_and_next_index()?, - }; - self.values.remove(min_index); + fn prune_min(&mut self) -> Option { + let (min_index, next_min_index) = self.find_min_and_next_index()?; + let entry = self.values.remove(min_index); self.min_index = next_min_index.map(|index| if index > min_index { index - 1 } else { index }); - Some(()) + Some(entry) } fn on_insert(&mut self, index: usize) { @@ -232,19 +229,17 @@ mod tests { use crate::types::value::DataValue; #[test] - fn top_n_prunes_to_capacity() { + fn top_n_replaces_min_counter_when_full() { let mut top_n = ColumnTopN::default(); top_n.add_with_size(2, DataValue::Int32(1), 5); top_n.add_with_size(2, DataValue::Int32(2), 3); top_n.add_with_size(2, DataValue::Int32(3), 1); - assert_eq!(top_n.len(), 2); - - top_n.add_with_size(2, DataValue::Int32(4), 1); - assert_eq!(top_n.len(), 2); assert_eq!(top_n.get(&DataValue::Int32(1)), Some(5)); - assert_eq!(top_n.get(&DataValue::Int32(2)), Some(3)); + let entry = top_n.get_entry(&DataValue::Int32(3)).unwrap(); + assert_eq!(entry.count(), 4); + assert_eq!(entry.error(), 3); } #[test] @@ -261,23 +256,9 @@ mod tests { .windows(2) .all(|pair| pair[0].value() < pair[1].value())); assert_eq!(top_n.get(&DataValue::Int32(1)), Some(5)); - assert_eq!(top_n.get(&DataValue::Int32(2)), Some(1)); - } - - #[test] - fn top_n_merge_accumulates_entries() { - let mut left = ColumnTopN::default(); - left.add_with_size(2, DataValue::Int32(1), 3); - left.add_with_size(2, DataValue::Int32(2), 2); - - let mut right = ColumnTopN::default(); - right.add_with_size(2, DataValue::Int32(1), 4); - right.add_with_size(2, DataValue::Int32(3), 1); - - left.merge_with_size(right, 2); - - assert_eq!(left.get(&DataValue::Int32(1)), Some(7)); - assert_eq!(left.len(), 2); + let entry = top_n.get_entry(&DataValue::Int32(3)).unwrap(); + assert_eq!(entry.count(), 2); + assert_eq!(entry.error(), 1); } #[test]