Keeping track of your top $k$ objects using a heap
When I write code for my research, I sometimes need to keep track of the top $k$ objects, for example, the 10 inputs that yield the highest accuracy (or the lowest loss) for my machine learning model.
Of course, if all your objects fit in memory, this is easy: sort them with sorted() (or, for $k = 1$, just use min() or max()). In some cases, however, such as when you process data in batches, it is more useful to maintain a “running tally.”
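When everything fits in memory, the standard library already covers the top-$k$ case. The sketch below uses made-up (key, payload) pairs. According to the heapq documentation, heapq.nlargest() is equivalent to sorted(..., reverse=True)[:k] and works best for small $k$. For $k = 1$, min() and max() are the better choice.

```python
import heapq

# Hypothetical (key, payload) pairs, e.g. (accuracy, input id), all in memory.
results = [(0.71, "a"), (0.93, "b"), (0.88, "c"), (0.65, "d"), (0.90, "e")]

# Option 1: sort everything, then slice. O(n log n).
top_3_sorted = sorted(results, key=lambda r: r[0], reverse=True)[:3]

# Option 2: heapq.nlargest returns the same list, best first.
top_3_heap = heapq.nlargest(3, results, key=lambda r: r[0])

# For k = 1, max() is enough.
best = max(results, key=lambda r: r[0])

print(top_3_sorted)  # [(0.93, 'b'), (0.9, 'e'), (0.88, 'c')]
```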
A naive solution is to maintain a sorted list of $k$ elements: each time you process a new object, you insert it at the right position in that list and then remove the $(k+1)$-th element. This works, but while finding the right position has a time complexity of $\mathcal{O}(\log k)$ (think binary search), actually inserting the element is $\mathcal{O}(k)$, because all elements after the insertion point have to be shifted.
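For reference, the naive approach can be sketched with the standard bisect module. This simplified version stores only the keys, with no payloads, in ascending order, so the $(k+1)$-th largest key ends up at index 0. The name add_naive is a hypothetical helper. The input numbers are taken from the example further down.

```python
import bisect
from typing import List


def add_naive(top: List[int], key: int, k: int) -> None:
    """Keep `top` as an ascending sorted list of (at most) the k largest keys."""
    bisect.insort(top, key)  # O(log k) to find the slot, but O(k) to shift elements
    if len(top) > k:
        top.pop(0)           # drop the (k+1)-th largest key; also O(k)


top: List[int] = []
for x in [2, 19, 67, 37, 23, 85, 57, 82, 88, 44]:
    add_naive(top, x, k=3)
print(top)  # [82, 85, 88]
```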
Now, unless you want to keep a very long top $k$-list, that simple solution is probably still fine in practice.
However, we can do better, namely $\mathcal{O}(\log k)$ per new object, by using a heap.
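The key observation is this: if the current top $k$ (largest) elements are stored in a min-heap, the root heap[0] is the weakest of them, and that is exactly the element to evict when something better arrives. The sketch below takes a quick look at the two heapq functions that the class relies on. The numbers are arbitrary.

```python
from heapq import heappush, heappushpop

heap = []
for x in [23, 85, 57]:
    heappush(heap, x)
# heapq only provides a min-heap: the smallest element always sits at heap[0].
print(heap[0])  # 23

# heappushpop() pushes the new item, then pops and returns the smallest one,
# in a single O(log k) operation, so the heap size stays the same.
evicted = heappushpop(heap, 82)
print(evicted, sorted(heap))  # 23 [57, 82, 85]

# If the new item is smaller than heap[0], it is returned right away
# and the heap is left untouched.
print(heappushpop(heap, 2), sorted(heap))  # 2 [57, 82, 85]
```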
Here is the solution I cooked up, which also gave me a welcome excuse to revisit Python’s built-in heapq module:
from typing import Any, List, Tuple
from heapq import heappush, heappushpop


class TopK:
    """
    Keep a running tally of your top k elements (smallest or largest).

    Args:
        k: Number of elements to keep. Must be a positive integer.
        kind: Use "max" to keep track of the k largest or "min" for
            the k smallest elements.
    """

    def __init__(self, k: int, kind: str = 'max'):
        assert isinstance(k, int) and k > 0, 'k must be a positive integer!'
        assert kind in ('min', 'max'), 'kind must be "min" or "max"!'
        self.k = k
        self.kind = kind
        # heapq is a min-heap, so flip the sign of the keys for kind == 'min'
        self.sign = 1 if self.kind == 'max' else -1
        self.heap: List[Tuple[float, Any]] = []

    def add(self, key: float, value: Any) -> None:
        """
        Add a new object to the heap (but only if it's in the top k!).

        Args:
            key: The value by which to determine the top k.
            value: The "payload" (i.e., the object associated with `key`).
        """
        item = (self.sign * key, value)
        # While we don't have k elements yet, we can just push to the heap
        if len(self.heap) < self.k:
            heappush(self.heap, item)
        # Otherwise, we push the new item and pop out the (k+1)-th element
        else:
            heappushpop(self.heap, item)

    @property
    def sorted(self) -> List[Tuple[float, Any]]:
        """
        Auxiliary property to get the sorted heap such that the first
        element in the list is the top-1 (i.e., the minimum or maximum).
        """
        # Sort from best to worst and undo the sign flip
        return [
            (self.sign * k, v)
            for k, v in sorted(self.heap, reverse=True)
        ]

    def __repr__(self) -> str:
        return str(self.sorted)

    def __getitem__(self, idx: int) -> Tuple[float, Any]:
        return self.sorted[idx]
I hope most of it is self-explanatory, but two things deserve a few words.
First, the type of key.
I set it to float here because that is simple and covers my own use cases. Of course, it should actually be “anything that can be compared” (and, because of the sign trick explained below, anything that can be multiplied by -1).
However, as of the writing of this post, the proper Comparable type for mypy is still in the works.
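If you need something more general than float in the meantime, one possible stop-gap on Python 3.8+ is a structural typing.Protocol that only requires `<`. The names SupportsLessThan and smallest_k are hypothetical choices for this sketch and are not a standard API. A type checker should then accept any key type that implements __lt__. Note that this covers comparison only. With kind='min', TopK additionally negates the keys.

```python
from typing import Any, List, Protocol, TypeVar


class SupportsLessThan(Protocol):
    """Structural type: anything that implements `<`."""

    def __lt__(self, other: Any) -> bool: ...


# Hypothetical type variable for "comparable" keys
K = TypeVar("K", bound=SupportsLessThan)


def smallest_k(keys: List[K], k: int) -> List[K]:
    # Works for ints, floats, strings, tuples, ... anything with `<`.
    return sorted(keys)[:k]


print(smallest_k([3.5, 1.0, 2.25], 2))          # [1.0, 2.25]
print(smallest_k(["pear", "apple", "fig"], 2))  # ['apple', 'fig']
```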
The other thing is the role of self.sign, and the way it is used to manipulate the sorting keys.
The reason for this is that the heapq module only implements a “min heap”, that is, a heap whose heappop() function returns the smallest element. For kind == 'max', this is exactly what we need: heappushpop() kicks out the smallest element of the current top $k$.
However, if we want to keep track of the smallest elements, we need heappushpop() to kick out the largest element instead.
Therefore, if kind == 'min', we simply use self.sign to flip the sign of the sorting keys, which reverses their order.
Of course, when we return the (sorted) top $k$, we need to undo this again.
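The sketch below makes the sign trick concrete. It reduces the same logic to plain heapq calls for the kind == 'min' case, uses made-up keys, and sets $k = 3$.

```python
from heapq import heappush, heappushpop

k = 3
heap = []  # stores (-key, value), so the min-heap root holds the LARGEST kept key
for key, value in [(0.40, "a"), (0.15, "b"), (0.90, "c"), (0.05, "d"), (0.30, "e")]:
    item = (-key, value)
    if len(heap) < k:
        heappush(heap, item)
    else:
        # Pops the item with the smallest -key, i.e. the largest original key.
        heappushpop(heap, item)

# Undo the flip: sorting the negated keys in reverse yields ascending keys.
smallest = [(-neg_key, v) for neg_key, v in sorted(heap, reverse=True)]
print(smallest)  # [(0.05, 'd'), (0.15, 'b'), (0.3, 'e')]
```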
To show you that this actually works, here is a practical example:
>>> import uuid
>>> import numpy as np
>>> top_3 = TopK(k=3, kind='max')
>>> numbers = np.random.randint(0, 100, 10)
>>> print(numbers)
[ 2 19 67 37 23 85 57 82 88 44]
>>> sorted(numbers)
[2, 19, 23, 37, 44, 57, 67, 82, 85, 88]
>>> for number in numbers:
...     top_3.add(key=number, value=str(uuid.uuid4())[:6])
...
>>> top_3
[(88, 'ca7844'), (85, 'de117e'), (82, 'd165ca')]
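Heap items are (key, value) tuples, so two equal keys make Python fall back to comparing the payloads. With the uuid strings above, this is harmless. Payloads that do not support `<` (dicts, most custom objects) raise a TypeError instead. A common remedy, also used in the priority-queue notes of the heapq documentation, is to insert a unique, increasing counter as a tie-breaker. The sketch below shows this with plain heapq calls, made-up payloads and $k = 2$. Adapting TopK this way would also mean dropping the counter in its sorted property.

```python
from heapq import heappush, heappushpop
from itertools import count

heap = []
try:
    heappush(heap, (1.0, {"id": "a"}))
    heappush(heap, (1.0, {"id": "b"}))  # equal keys -> Python compares the dicts
except TypeError as err:
    print("TypeError:", err)

# Fix: a unique counter between key and payload resolves ties by insertion
# order, so the payloads themselves are never compared.
tiebreak = count()
heap = []
for key, payload in [(1.0, {"id": "a"}), (1.0, {"id": "b"}), (2.0, {"id": "c"})]:
    item = (key, next(tiebreak), payload)
    if len(heap) < 2:
        heappush(heap, item)
    else:
        heappushpop(heap, item)

kept = [(key, payload) for key, _, payload in sorted(heap, reverse=True)]
print(kept)  # [(2.0, {'id': 'c'}), (1.0, {'id': 'b'})]
```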
If you have found any errors in my solution or have suggestions for improvements, please feel free to reach out to me! :-)