Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "duckdb/common/vector/flat_vector.hpp"
#include "duckdb/common/vector/vector_iterator.hpp"
#include "duckdb/common/vector/list_vector.hpp"
#include "duckdb/common/vector/struct_vector.hpp"
#include "core_functions/aggregate/histogram_helpers.hpp"
#include "core_functions/aggregate/holistic_functions.hpp"
#include "duckdb/function/aggregate/sort_key_helpers.hpp"
Expand Down Expand Up @@ -81,7 +83,7 @@ struct InternalApproxTopKState {
filter.resize(filter_size);
}

static void CopyValue(ApproxTopKValue &value, const ApproxTopKString &input, AggregateInputData &input_data) {
static void CopyValue(ApproxTopKValue &value, const ApproxTopKString &input, ArenaAllocator &allocator) {
value.str_val.hash = input.hash;
if (input.str.IsInlined()) {
// no need to copy
Expand All @@ -92,7 +94,7 @@ struct InternalApproxTopKState {
if (value.size > value.capacity) {
// need to re-allocate for this value
value.capacity = UnsafeNumericCast<uint32_t>(NextPowerOfTwo(value.size));
value.dataptr = char_ptr_cast(input_data.allocator.Allocate(value.capacity));
value.dataptr = char_ptr_cast(allocator.Allocate(value.capacity));
}
// copy over the data
memcpy(value.dataptr, input.str.GetData(), value.size);
Expand Down Expand Up @@ -129,7 +131,7 @@ struct InternalApproxTopKState {
filter[value.str_val.hash & filter_mask] = value.count;
lookup_map.erase(value.str_val);
}
CopyValue(value, input, aggr_input);
CopyValue(value, input, aggr_input.allocator);
lookup_map.insert(make_pair(value.str_val, reference<ApproxTopKValue>(value)));
IncrementCount(value, increment);
}
Expand Down Expand Up @@ -378,6 +380,168 @@ void ApproxTopKFinalize(Vector &state_vector, AggregateFinalizeInputData &, Vect
result.Verify();
}

//===--------------------------------------------------------------------===//
// State Export
//===--------------------------------------------------------------------===//
//! Describes the exported approx_top_k state:
//!   STRUCT(k UBIGINT, values LIST(STRUCT(value <first argument type>, count UBIGINT)), filter LIST(UBIGINT))
//! "values" carries the monitored values (descending count), "filter" the Filtered Space-Saving counters.
//! Values are decoded to the input type on export and re-encoded on import.
AggregateStateLayout ApproxTopKGetStateType(AggregateLayoutInput &input) {
	// a single monitored entry: the value (typed like the aggregate's first argument) plus its counter
	child_list_t<LogicalType> entry_fields;
	entry_fields.emplace_back("value", input.function.GetArguments()[0]);
	entry_fields.emplace_back("count", LogicalType::UBIGINT);
	auto entry_type = LogicalType::STRUCT(std::move(entry_fields));

	// the top-level struct: k, the monitored entries and the filter counters
	child_list_t<LogicalType> state_fields;
	state_fields.emplace_back("k", LogicalType::UBIGINT);
	state_fields.emplace_back("values", LogicalType::LIST(entry_type));
	state_fields.emplace_back("filter", LogicalType::LIST(LogicalType::UBIGINT));

	AggregateStateLayout result;
	result.total_state_size = AlignValue<idx_t>(sizeof(ApproxTopKState));
	result.type = LogicalType::STRUCT(std::move(state_fields));
	return result;
}

//! Exports every approx_top_k state in "state_vector" into "result", whose struct entries are accessed by position:
//! [0] k (uint64), [1] a list of (value, count) structs, [2] a list of uint64 filter counters.
//! A state that has no internal state, or whose "values" is empty, is exported as NULL.
//! OP supplies HistogramFinalize, which writes each stored string into the "value" child vector.
//! "offset" must be 0 (asserted); "aggr_input_data" is not used in this function.
template <class OP = HistogramGenericFunctor>
void ApproxTopKExportState(Vector &state_vector, AggregateFinalizeInputData &aggr_input_data, Vector &result,
idx_t count, idx_t offset) {
D_ASSERT(offset == 0);
auto states = state_vector.Values<ApproxTopKState *>();

// fetch the validity masks and data/entry writers of the result struct and its three children
auto &mask = FlatVector::ValidityMutable(result);
auto &fields = StructVector::GetEntries(result);
auto k_data = FlatVector::GetDataMutable<uint64_t>(fields[0]);
auto &value_lists = fields[1];
auto &filter_lists = fields[2];
auto &k_validity = FlatVector::ValidityMutable(fields[0]);
auto &value_validity = FlatVector::ValidityMutable(value_lists);
auto &filter_validity = FlatVector::ValidityMutable(filter_lists);
auto value_entries = FlatVector::ScatterWriter<list_entry_t>(value_lists);
auto filter_entries = FlatVector::ScatterWriter<list_entry_t>(filter_lists);
// new list entries are appended after whatever the list children already contain
idx_t total_values = ListVector::GetListSize(value_lists);
idx_t total_filters = ListVector::GetListSize(filter_lists);
// first pass: write k and the list entry (offset, length) of every row and accumulate the total child sizes
for (idx_t i = 0; i < count; i++) {
auto state_ptr = states[i].GetValue()->state;
value_entries[i].offset = total_values;
filter_entries[i].offset = total_filters;
if (!state_ptr || state_ptr->values.empty()) {
// no values have been added to this state - export NULL (children of a NULL struct must also be NULL)
mask.SetInvalid(i);
k_validity.SetInvalid(i);
value_validity.SetInvalid(i);
filter_validity.SetInvalid(i);
value_entries[i].length = 0;
filter_entries[i].length = 0;
k_data[i] = 0;
continue;
}
k_data[i] = state_ptr->k;
value_entries[i].length = state_ptr->values.size();
filter_entries[i].length = state_ptr->filter.size();
total_values += state_ptr->values.size();
total_filters += state_ptr->filter.size();
}

// reserve the list children for the totals computed above; the child data pointers are only fetched afterwards
// NOTE(review): presumably because Reserve can reallocate the child buffers - confirm
ListVector::Reserve(value_lists, total_values);
ListVector::Reserve(filter_lists, total_filters);
auto &value_structs = ListVector::GetChildMutable(value_lists);
auto &value_fields = StructVector::GetEntries(value_structs);
auto &value_child = value_fields[0];
auto count_data = FlatVector::GetDataMutable<uint64_t>(value_fields[1]);
auto filter_data = FlatVector::GetDataMutable<uint64_t>(ListVector::GetChildMutable(filter_lists));
// second pass: fill the list children, skipping exactly the rows that were exported as NULL in the first pass
for (idx_t i = 0; i < count; i++) {
auto state_ptr = states[i].GetValue()->state;
if (!state_ptr || state_ptr->values.empty()) {
continue;
}
auto &state = *state_ptr;
// write the values (in descending count order) - decoding them back to the input type
idx_t value_offset = value_entries[i].offset;
for (auto &val_ref : state.values) {
auto &val = val_ref.get();
OP::template HistogramFinalize<string_t>(val.str_val.str, value_child, value_offset);
count_data[value_offset] = val.count;
value_offset++;
}
// copy the filter counters verbatim, in index order
for (idx_t filter_idx = 0; filter_idx < state.filter.size(); filter_idx++) {
filter_data[filter_entries[i].offset + filter_idx] = state.filter[filter_idx];
}
}
// publish the final sizes of the list children and of the result vectors
ListVector::SetListSize(value_lists, total_values);
ListVector::SetListSize(filter_lists, total_filters);
FlatVector::SetSize(fields[0], count);
FlatVector::SetSize(value_lists, count);
FlatVector::SetSize(filter_lists, count);
FlatVector::SetSize(result, count);
}

//! Rebuilds approx_top_k states from an exported state vector (input.input_vec).
//! Row i is written to the ApproxTopKState located at input.dest_buffer + i * input.layout.total_state_size.
//! NULL rows leave the state empty (state.state == nullptr).
//! OP supplies CreateExtraState/PrepareData, used to encode the exported values before they are stored.
//! Throws InvalidInputException when the values/filter list lengths do not match what Initialize(k) set up,
//! or when a stored value is NULL.
template <class OP = HistogramGenericFunctor>
void ApproxTopKImportState(AggregateImportInputData &input) {
const auto &layout = input.layout;
const auto count = input.input_vec.size();
// the input can be any vector type (e.g. dictionary-encoded from a compressed scan) - flatten it so the
// struct/list children can be read by position
Vector input_vec(input.input_vec, 0, count);
input_vec.Flatten();
const auto dest_buffer = input.dest_buffer;
auto &allocator = input.allocator;
const auto validity = input_vec.Validity();
// struct entries by position: [0] k, [1] list of (value, count) structs, [2] list of filter counters
const auto &fields = StructVector::GetEntries(input_vec);
auto k_data = FlatVector::GetData<uint64_t>(fields[0]);
auto &value_lists = fields[1];
auto &filter_lists = fields[2];
auto value_entries = FlatVector::GetData<list_entry_t>(value_lists);
auto filter_entries = FlatVector::GetData<list_entry_t>(filter_lists);
auto &value_structs = ListVector::GetChild(value_lists);
const auto &value_fields = StructVector::GetEntries(value_structs);

// encode the values - this maps them to the same representation the state stores (e.g. as sort keys)
auto extra_state = OP::CreateExtraState();
UnifiedVectorFormat value_data;
OP::PrepareData(value_fields[0], extra_state, value_data);
auto value_strings = UnifiedVectorFormat::GetData<string_t>(value_data);
// NOTE(review): counts and filter counters are indexed directly (no selection vector), unlike the values above -
// this assumes the list children are flat after Flatten(); confirm
auto count_data = FlatVector::GetData<uint64_t>(value_fields[1]);
auto filter_data = FlatVector::GetData<uint64_t>(ListVector::GetChild(filter_lists));

for (idx_t i = 0; i < count; i++) {
auto &state = *reinterpret_cast<ApproxTopKState *>(dest_buffer + i * layout.total_state_size);
// start from an empty state; it is only replaced once the row has been imported completely
state.state = nullptr;
if (!validity.IsValid(i)) {
// NULL input - leave the state empty
continue;
}
auto internal_state = make_uniq<InternalApproxTopKState>();
auto &target = *internal_state;
// NOTE(review): k comes straight from the imported data and no range check on it is visible before
// Initialize - confirm that very large k values are rejected somewhere
target.Initialize(k_data[i]);
// the list lengths must be consistent with the capacity/filter size derived from k
if (value_entries[i].length > target.capacity || filter_entries[i].length != target.filter.size()) {
throw InvalidInputException("Invalid approx_top_k state - "
"the values/filter sizes do not match the k value");
}
// insert the values - these are ordered by descending count, keeping "values" sorted
for (idx_t value_idx = 0; value_idx < value_entries[i].length; value_idx++) {
const auto idx = value_entries[i].offset + value_idx;
const auto sel_idx = value_data.sel->get_index(idx);
if (!value_data.validity.RowIsValid(sel_idx)) {
throw InvalidInputException("Invalid approx_top_k state - the state values cannot be NULL");
}

// take the next unused slot of stored_values and append a reference to it to "values"
auto &val = target.stored_values[target.values.size()];
val.index = target.values.size();
target.values.push_back(val);

// copy the encoded string into the slot (CopyValue is given the import allocator) and register it for lookups
const auto &str_val = value_strings[sel_idx];
ApproxTopKString topk_string(str_val, Hash(str_val));
InternalApproxTopKState::CopyValue(val, topk_string, allocator);
target.lookup_map.insert(make_pair(val.str_val, reference<ApproxTopKValue>(val)));
val.count = count_data[idx];
}
// restore the filter counters verbatim (the length was validated against target.filter.size() above)
for (idx_t filter_idx = 0; filter_idx < filter_entries[i].length; filter_idx++) {
target.filter[filter_idx] = filter_data[filter_entries[i].offset + filter_idx];
}
target.Verify();
// hand the raw pointer over to the aggregate state
// NOTE(review): presumably released again by the aggregate's destroy callback - confirm
state.state = internal_state.release();
}
}

unique_ptr<FunctionData> ApproxTopKBind(BindAggregateFunctionInput &input) {
auto &function = input.GetBoundFunction();
auto &arguments = input.GetArguments();
Expand All @@ -389,7 +553,11 @@ unique_ptr<FunctionData> ApproxTopKBind(BindAggregateFunctionInput &input) {
if (arguments[0]->GetReturnType().id() == LogicalTypeId::VARCHAR) {
function.SetStateUpdateCallback(ApproxTopKUpdate<string_t, HistogramStringFunctor>);
function.SetStateFinalizeCallback(ApproxTopKFinalize<HistogramStringFunctor>);
function.SetExportAggregateStateCallback(ApproxTopKExportState<HistogramStringFunctor>);
function.SetImportAggregateStateCallback(ApproxTopKImportState<HistogramStringFunctor>);
}
// resolve the (originally ANY) value argument type so the exported state layout uses the actual input type
function.GetArguments()[0] = arguments[0]->GetReturnType();
function.SetReturnType(LogicalType::LIST(arguments[0]->GetReturnType()));
return nullptr;
}
Expand All @@ -399,11 +567,14 @@ unique_ptr<FunctionData> ApproxTopKBind(BindAggregateFunctionInput &input) {
//! Builds the approx_top_k(ANY, BIGINT) -> LIST(ANY) aggregate function.
//! The generic (HistogramGenericFunctor) update/finalize/export/import callbacks are registered here;
//! ApproxTopKBind is installed as the bind callback.
AggregateFunction ApproxTopKFun::GetFunction() {
	using STATE = ApproxTopKState;
	using OP = ApproxTopKOperation;
	// The function must be built into a local first: returning the constructed AggregateFunction directly would
	// skip the export-callback registration below and leave the state export/import callbacks unset.
	auto fun = AggregateFunction("approx_top_k", {LogicalTypeId::ANY, LogicalType::BIGINT},
	                             LogicalType::LIST(LogicalType::ANY), AggregateFunction::StateSize<STATE>,
	                             AggregateFunction::StateInitialize<STATE, OP>, ApproxTopKUpdate,
	                             AggregateFunction::StateCombine<STATE, OP>, ApproxTopKFinalize, nullptr,
	                             ApproxTopKBind, AggregateFunction::StateDestroy<STATE, OP>);
	// state layout description plus the export/import callbacks for the generic value encoding
	fun.SetStateExportCallbacks(ApproxTopKGetStateType, ApproxTopKExportState<HistogramGenericFunctor>,
	                            ApproxTopKImportState<HistogramGenericFunctor>);
	return fun;
}

} // namespace duckdb
Loading
Loading