秒杀经典TopK问题
Top K问题是解决如何在数组中或者数据中获得前K个最大或者最小的元素,是面试时的高频问点。
问题的具体化形式包括但不限于以下几个:
解决这一类问题,我们通常有如下几种解题方法:
排序法
堆
快速查找法
下面对于每种方法给出解题思路,对于这一类问题可以直接套用。
一. 排序法
排序法是最简单的也是最容易想到的方法。这种方法可以依赖编程语言的API排序函数进行排序,然后选择对应元素即可。
解题的模板如下:
class Solution:
def findKthLargest(self, nums: List[int], k: int) -> int:
# 对集合进行排序
nums.sort()
# 返回符合要求的元素
return nums[k]
对于215. 数组中的第K个最大元素,这里可以使用此方法直接返回排序后的第K个索引的元素,即是算法的结果。
class Solution:
def findKthLargest(self, nums: List[int], k: int) -> int:
nums.sort(reverse=True)
return nums[k-1]
对于剑指 Offer 40. 最小的k个数,这里返回的不是一个数字,而是返回一个列表,其中包含最小的K个数字。也可以使用这个模板进行作答。
class Solution:
def getLeastNumbers(self, arr: List[int], k: int) -> List[int]:
arr.sort()
return arr[:k]
对于973. 最接近原点的 K 个点,别管是求什么距离最近,脱去外衣还是求最小的K个数字,和上一题不同的是,这里的排序条件要改变一下,自定义个求距离的函数。
class Solution:
def kClosest(self, points: List[List[int]], K: int) -> List[List[int]]:
points.sort(key=lambda x: (x[0] ** 2 + x[1] ** 2))
return points[:K]
通过上述三个体可以看出 ,这种方法最简单也是最直接的,运行效果也不错。
二. 堆
要获得前K个元素,而且是最大或者最小。我们第一反应可以想到堆排序。我们用一个大根堆实时维护数组的前 k
小值。首先将前 k
个数插入大根堆中,随后从第 k+1
个数开始遍历,如果当前遍历到的数比大根堆的堆顶的数要小,就把堆顶的数弹出,再插入当前遍历到的数。最后将大根堆里的数存入数组返回即可。由于 C++ 语言中的堆(即优先队列)为大根堆,我们可以这么做。而 Python 语言中的对为小根堆,因此我们要对数组中所有的数取其相反数,才能使用小根堆维护前 k
小值。
class Solution:
def findKthLargest(self, nums: List[int], k: int) -> int:
heap = [x for x in nums[:k]]
heapq.heapify(heap)
for num in range(k, len(nums)):
if heap[0] < nums[num]:
heapq.heappushpop(heap, nums[num])
return heap[0]
class Solution:
def getLeastNumbers(self, arr: List[int], k: int) -> List[int]:
if k == 0:
return []
heap = [-num for num in arr[:k]]
heapq.heapify(heap)
for num in range(k, len(arr)):
if -heap[0] > arr[num]:
heapq.heappushpop(heap, -arr[num])
res = [-num for num in heap]
return res
import heapq
class Solution:
def kClosest(self, points: List[List[int]], K: int) -> List[List[int]]:
res = []
q = [(x[0] ** 2 + x[1] ** 2, i) for i, x in enumerate(points)]
heapq.heapify(q)
for point in range(K):
res.append(points[heapq.heappop(q)[1]])
return res
总结一点:求最大要用小根堆,求最小要用大根堆
为什么呢?请独立思考!
对于海量数据,我们不需要一次性将全部数据取出来,可以一次只取一部分,因为我们只需要将数据一个个拿来与堆顶比较。
复杂度分析
- 时间复杂度:
O(nlogk)
,其中n
是数组arr
的长度。由于大根堆实时维护前k
小值,所以插入删除都是O(logk)
的时间复杂度,最坏情况下数组里n
个数都会插入,所以一共需要O(nlogk)
的时间复杂度。
- 空间复杂度:
O(k)
,因为大根堆里最多k
个数。
三. 快速查找法
我们可以借鉴快速排序的思想。我们知道快排的划分函数每次执行完后都能将数组分成两个部分,小于等于分界值 pivot
的元素的都会被放到数组的左边,大于的都会被放到数组的右边,然后返回分界值的下标。与快速排序不同的是,快速排序会根据分界值的下标递归处理划分的两侧,而这里我们只处理划分的一边。
Top K 问题的这个解法就比较难想到,需要在平时有算法的积累。找第 k 大的数,或者找前 k 大的数,有一个经典的 quick select(快速选择)算法。这个名字和 quick sort(快速排序)看起来很像,算法的思想也和快速排序类似,都是分治法的思想。
partition
操作是原地进行的,需要 O(n)
的时间,接下来,快速排序会递归地排序左右两侧的数组。而快速选择(quick select)算法的不同之处在于,接下来只需要递归地选择一侧的数组。快速选择算法想当于一个“不完全”的快速排序,因为我们只需要知道最小的 k
个数是哪些,并不需要知道它们的顺序。
这种方法需要多加领会思想,如果你对快速排序掌握得很好,那么稍加推导应该不难掌握 quick select 的要领。
class Solution:
def partition(self, nums, low, high):
pivot = nums[high]
i = low - 1
for j in range(low, high):
if nums[j] < pivot:
i += 1
nums[i], nums[j] = nums[j], nums[i]
nums[high], nums[i+1] = nums[i+1], nums[high]
return i + 1
def quickSelection(self, nums, left, right, k):
pivot = self.partition(nums, left, right)
if pivot == k:
return nums[pivot]
elif pivot < k:
return self.quickSelection(nums, pivot+1, right, k)
else:
return self.quickSelection(nums, left, pivot-1, k)
def findKthLargest(self, nums, k):
return self.quickSelection(nums, 0, len(nums) - 1, len(nums) - k)
class Solution:
def partition(self, nums, l, r):
pivot = nums[r]
i = l - 1
for j in range(l, r):
if nums[j] <= pivot:
i += 1
nums[i], nums[j] = nums[j], nums[i]
nums[i + 1], nums[r] = nums[r], nums[i + 1]
return i + 1
def randomized_partition(self, nums, l, r):
i = random.randint(l, r)
nums[r], nums[i] = nums[i], nums[r]
return self.partition(nums, l, r)
def randomized_selected(self, arr, l, r, k):
pos = self.randomized_partition(arr, l, r)
num = pos - l + 1
if k < num:
self.randomized_selected(arr, l, pos - 1, k)
elif k > num:
self.randomized_selected(arr, pos + 1, r, k - num)
def getLeastNumbers(self, arr: List[int], k: int) -> List[int]:
if k == 0:
return list()
self.randomized_selected(arr, 0, len(arr) - 1, k)
return arr[:k]
我们的目的是寻找最小的 k
个数。假设经过一次 partition
操作,枢纽元素位于下标 m
,也就是说,左侧的数组有 m
个元素,是原数组中最小的 m
个数。那么:
- 若
k = m
,我们就找到了最小的k
个数,就是左侧的数组; - 若
k<m
,则最小的k
个数一定都在左侧数组中,我们只需要对左侧数组递归地partition
即可; - 若
k>m
,则左侧数组中的m
个数都属于最小的k
个数,我们还需要在右侧数组中寻找最小的k-m
个数,对右侧数组递归地partition
即可。
class Solution:
def kClosest(self, points: List[List[int]], K: int) -> List[List[int]]:
def random_select(left: int, right: int, K: int):
pivot_id = random.randint(left, right)
pivot = points[pivot_id][0] ** 2 + points[pivot_id][1] ** 2
points[right], points[pivot_id] = points[pivot_id], points[right]
i = left - 1
for j in range(left, right):
if points[j][0] ** 2 + points[j][1] ** 2 <= pivot:
i += 1
points[i], points[j] = points[j], points[i]
i += 1
points[i], points[right] = points[right], points[i]
# [left, i-1] 都小于等于 pivot, [i+1, right] 都大于 pivot
if K < i - left + 1:
random_select(left, i - 1, K)
elif K > i - left + 1:
random_select(i + 1, right, K - (i - left + 1))
n = len(points)
random_select(0, n - 1, K)
return points[:K]
我们定义函数 random_select(left, right, K)
表示划分数组 points
的 [left,right]
区间,并且需要找到其中第 K
个距离最小的点。在一次划分操作完成后,设 pivot
的下标为 i
,即区间 [left,i−1]
中的点的距离都小于等于pivot
,而区间 [i+1,right]
的点的距离都大于pivot
。此时会有三种情况:
- 如果
K=i−left+1
,那么说明pivot
就是第K
个距离最小的点,我们可以结束整个过程; - 如果
K<i−left+1
,那么说明第K
个距离最小的点在pivot
左侧,因此递归调用random_select(left, i - 1, K);
- 如果
K>i−left+1
,那么说明第K
个距离最小的点在pivot
右侧,因此递归调用random_select(i + 1, right, K - (i - left + 1))
。
在整个过程结束之后,第 K
个距离最小的点恰好就在数组 points
中的第 K
个位置,并且其左侧的所有点的距离都小于它。此时,我们就找到了前 K
个距离最小的点。
复杂度分析
- 时间复杂度:期望为
O(n)
- 空间复杂度:期望为
O(logn)
参考文章
本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!