diff --git a/util/topk/topk.go b/util/topk/topk.go
index d3bbb2c6d..97a39a394 100644
--- a/util/topk/topk.go
+++ b/util/topk/topk.go
@@ -16,11 +16,12 @@ import (
 
 // TopK is a probabilistic counter of the top K items, using a count-min sketch
 // to keep track of item counts and a heap to track the top K of them.
-type TopK[T any] struct {
-	heap minHeap[T]
-	k    int
-	sf   SerializeFunc[T]
-	cms  CountMinSketch
+type TopK[T comparable] struct {
+	heap      minHeap[T]
+	positions map[T]int
+	k         int
+	sf        SerializeFunc[T]
+	cms       CountMinSketch
 }
 
 // HashFunc is responsible for providing a []byte serialization of a value,
@@ -31,18 +32,19 @@ type SerializeFunc[T any] func([]byte, T) []byte
 // New creates a new TopK that stores k values. Parameters for the underlying
 // count-min sketch are chosen for a 0.1% error rate and a 0.1% probability of
 // error.
-func New[T any](k int, sf SerializeFunc[T]) *TopK[T] {
+func New[T comparable](k int, sf SerializeFunc[T]) *TopK[T] {
 	hashes, buckets := PickParams(0.001, 0.001)
 	return NewWithParams(k, sf, hashes, buckets)
 }
 
 // NewWithParams creates a new TopK that stores k values, and additionally
 // allows customizing the parameters for the underlying count-min sketch.
-func NewWithParams[T any](k int, sf SerializeFunc[T], numHashes, numCols int) *TopK[T] {
+func NewWithParams[T comparable](k int, sf SerializeFunc[T], numHashes, numCols int) *TopK[T] {
 	ret := &TopK[T]{
-		heap: make(minHeap[T], 0, k),
-		k:    k,
-		sf:   sf,
+		heap:      make(minHeap[T], 0, k),
+		positions: make(map[T]int, k),
+		k:         k,
+		sf:        sf,
 	}
 	ret.cms.init(numHashes, numCols)
 	return ret
@@ -69,21 +71,37 @@ func (tk *TopK[T]) AddN(val T, count uint64) uint64 {
 
 	vcount := tk.cms.AddN(ser, count)
 
-	// If we don't have a full heap, just push it.
-	if len(tk.heap) < tk.k {
+	pos, tracked := tk.positions[val]
+	switch {
+	case tracked:
+		// The item is already in the heap: refresh its count in place
+		// rather than pushing a duplicate entry.
+		tk.heap[pos].count = vcount
+		heap.Fix(&tk.heap, pos)
+	case len(tk.heap) < tk.k:
+		// The heap is not full yet, so every new item is admitted.
 		heap.Push(&tk.heap, mhValue[T]{
 			count: vcount,
 			val:   val,
 		})
-		return vcount
-	}
-
-	// If this item's count surpasses the heap's minimum, update the heap.
-	if vcount > tk.heap[0].count {
+	case vcount > tk.heap[0].count:
+		// The item's count surpasses the heap's minimum: evict the
+		// minimum and put this item in its place.
+		delete(tk.positions, tk.heap[0].val)
 		tk.heap[0] = mhValue[T]{
 			count: vcount,
 			val:   val,
 		}
 		heap.Fix(&tk.heap, 0)
+	default:
+		// Not a top-K item; the heap is untouched.
+		return vcount
+	}
+
+	// heap.Push and heap.Fix reorder elements through the heap's own Swap
+	// method, which has no access to tk.positions. Re-record every
+	// element's index so that later lookups never use a stale position.
+	for i := range tk.heap {
+		tk.positions[tk.heap[i].val] = i
 	}
 	return vcount
diff --git a/util/topk/topk_test.go b/util/topk/topk_test.go
index d30342e90..a5d0defef 100644
--- a/util/topk/topk_test.go
+++ b/util/topk/topk_test.go
@@ -68,6 +68,57 @@ func TestTopK(t *testing.T) {
 	t.Errorf("top K mismatch\ngot: %v\nwant: %v", got, want)
 }
 
+func TestTopKNoDuplicates(t *testing.T) {
+	// Create a TopK that tracks top 5 elements
+	topk := New[string](5, func(in []byte, val string) []byte {
+		return append(in, []byte(val)...)
+	})
+
+	// Add a single element many times
+	const commonElement = "very-common"
+	for i := 0; i < 500; i++ {
+		topk.Add(commonElement)
+	}
+
+	// We should only have a single "top" element here, despite having
+	// added the same element 500 times.
+	if n := len(topk.Top()); n != 1 {
+		t.Errorf("expected only one element, got %d", n)
+	}
+
+	// Add some less frequent elements
+	for i := 0; i < 5; i++ {
+		topk.Add(fmt.Sprintf("less-common-%d", i))
+	}
+
+	// Add common element again
+	for i := 0; i < 500; i++ {
+		topk.Add(commonElement)
+	}
+
+	// Get the top elements
+	results := topk.Top()
+
+	// Count occurrences of the common element
+	commonCount := 0
+	for _, res := range results {
+		if res == commonElement {
+			commonCount++
+		}
+	}
+
+	if commonCount > 1 {
+		t.Errorf("common element appeared %d times in results, want 1", commonCount)
+	} else if commonCount == 0 {
+		t.Error("common element did not appear in results")
+	}
+
+	// We expect that the common element is last (i.e. "top") in the returned list.
+	if idx := len(results) - 1; results[idx] != commonElement {
+		t.Errorf("common element not last in results: %q", results[idx])
+	}
+}
+
 func TestPickParams(t *testing.T) {
 	hashes, buckets := PickParams(
 		0.001, // 0.1% error rate