秒杀经典TopK问题

Top K问题是解决如何在数组中或者数据中获得前K个最大或者最小的元素,是面试时的高频问点。

问题的具体化形式包括但不限于以下几个:

解决这一类问题,我们通常有如下几种解题方法:

  • 排序法

  • 快速查找法

下面对于每种方法给出解题思路,对于这一类问题可以直接套用。

一. 排序法

排序法是最简单的也是最容易想到的方法。这种方法可以依赖编程语言的API排序函数进行排序,然后选择对应元素即可。

解题的模板如下:

class Solution:
    def findKthLargest(self, nums: List[int], k: int) -> int:
        # 对集合进行排序
        nums.sort()
        # 返回符合要求的元素
        return nums[k]

对于215. 数组中的第K个最大元素,这里可以使用此方法直接返回排序后的第K个索引的元素,即是算法的结果。

class Solution:
    def findKthLargest(self, nums: List[int], k: int) -> int:
        nums.sort(reverse=True)
        return nums[k-1]

对于剑指 Offer 40. 最小的k个数,这里返回的不是一个数字,而是返回一个列表,其中包含最小的K个数字。也可以使用这个模板进行作答。

class Solution:
    def getLeastNumbers(self, arr: List[int], k: int) -> List[int]:
        arr.sort()
        return arr[:k]

对于973. 最接近原点的 K 个点,别管是求什么距离最近,脱去外衣还是求最小的K个数字,和上一题不同的是,这里的排序条件要改变一下,自定义个求距离的函数。

class Solution:
    def kClosest(self, points: List[List[int]], K: int) -> List[List[int]]:
        points.sort(key=lambda x: (x[0] ** 2 + x[1] ** 2))
        return points[:K]

通过上述三个体可以看出 ,这种方法最简单也是最直接的,运行效果也不错。

二. 堆

要获得前K个元素,而且是最大或者最小。我们第一反应可以想到堆排序。我们用一个大根堆实时维护数组的前 k小值。首先将前 k 个数插入大根堆中,随后从第 k+1 个数开始遍历,如果当前遍历到的数比大根堆的堆顶的数要小,就把堆顶的数弹出,再插入当前遍历到的数。最后将大根堆里的数存入数组返回即可。由于 C++ 语言中的堆(即优先队列)为大根堆,我们可以这么做。而 Python 语言中的对为小根堆,因此我们要对数组中所有的数取其相反数,才能使用小根堆维护前 k 小值。

对于215. 数组中的第K个最大元素

class Solution:
    def findKthLargest(self, nums: List[int], k: int) -> int:
        heap = [x for x in nums[:k]]
        heapq.heapify(heap)

        for num in range(k, len(nums)):
            if heap[0] < nums[num]:
                heapq.heappushpop(heap, nums[num])
        return heap[0]

对于剑指 Offer 40. 最小的k个数

class Solution:
    def getLeastNumbers(self, arr: List[int], k: int) -> List[int]:
        if k == 0:
            return []
        heap = [-num for num in arr[:k]]
        heapq.heapify(heap)
        for num in range(k, len(arr)):
            if -heap[0] > arr[num]:
                heapq.heappushpop(heap, -arr[num])
        res = [-num for num in heap]
        return res

对于973. 最接近原点的 K 个点

import heapq
class Solution:
    def kClosest(self, points: List[List[int]], K: int) -> List[List[int]]:
        res = []
        q = [(x[0] ** 2 + x[1] ** 2, i) for i, x in enumerate(points)]
        heapq.heapify(q)
        for point in range(K):
            res.append(points[heapq.heappop(q)[1]])
        return res

总结一点:求最大要用小根堆,求最小要用大根堆

为什么呢?请独立思考!

对于海量数据,我们不需要一次性将全部数据取出来,可以一次只取一部分,因为我们只需要将数据一个个拿来与堆顶比较。

复杂度分析

  • 时间复杂度:O(nlogk),其中 n 是数组 arr 的长度。由于大根堆实时维护前 k 小值,所以插入删除都是 O(logk) 的时间复杂度,最坏情况下数组里 n 个数都会插入,所以一共需要 O(nlogk) 的时间复杂度。
  • 空间复杂度:O(k),因为大根堆里最多 k 个数。

三. 快速查找法

我们可以借鉴快速排序的思想。我们知道快排的划分函数每次执行完后都能将数组分成两个部分,小于等于分界值 pivot 的元素的都会被放到数组的左边,大于的都会被放到数组的右边,然后返回分界值的下标。与快速排序不同的是,快速排序会根据分界值的下标递归处理划分的两侧,而这里我们只处理划分的一边。

Top K 问题的这个解法就比较难想到,需要在平时有算法的积累。找第 k 大的数,或者找前 k 大的数,有一个经典的 quick select(快速选择)算法。这个名字和 quick sort(快速排序)看起来很像,算法的思想也和快速排序类似,都是分治法的思想。

partition 操作是原地进行的,需要 O(n) 的时间,接下来,快速排序会递归地排序左右两侧的数组。而快速选择(quick select)算法的不同之处在于,接下来只需要递归地选择一侧的数组。快速选择算法想当于一个“不完全”的快速排序,因为我们只需要知道最小的 k 个数是哪些,并不需要知道它们的顺序。

这种方法需要多加领会思想,如果你对快速排序掌握得很好,那么稍加推导应该不难掌握 quick select 的要领。

对于215. 数组中的第K个最大元素

class Solution:
    def partition(self, nums, low, high):
        pivot = nums[high]
        i = low - 1
        for j in range(low, high):
            if nums[j] < pivot:
                i += 1
                nums[i], nums[j] = nums[j], nums[i]
        nums[high], nums[i+1] = nums[i+1], nums[high]
        return i + 1

    def quickSelection(self, nums, left, right, k):
        pivot = self.partition(nums, left, right)
        if pivot == k:
            return nums[pivot]
        elif pivot < k:
            return self.quickSelection(nums, pivot+1, right, k)
        else:
            return self.quickSelection(nums, left, pivot-1, k)

    def findKthLargest(self, nums, k):
        return self.quickSelection(nums, 0, len(nums) - 1, len(nums) - k)

对于剑指 Offer 40. 最小的k个数

class Solution:
    def partition(self, nums, l, r):
        pivot = nums[r]
        i = l - 1
        for j in range(l, r):
            if nums[j] <= pivot:
                i += 1
                nums[i], nums[j] = nums[j], nums[i]
        nums[i + 1], nums[r] = nums[r], nums[i + 1]
        return i + 1

    def randomized_partition(self, nums, l, r):
        i = random.randint(l, r)
        nums[r], nums[i] = nums[i], nums[r]
        return self.partition(nums, l, r)

    def randomized_selected(self, arr, l, r, k):
        pos = self.randomized_partition(arr, l, r)
        num = pos - l + 1
        if k < num:
            self.randomized_selected(arr, l, pos - 1, k)
        elif k > num:
            self.randomized_selected(arr, pos + 1, r, k - num)

    def getLeastNumbers(self, arr: List[int], k: int) -> List[int]:
        if k == 0:
            return list()
        self.randomized_selected(arr, 0, len(arr) - 1, k)
        return arr[:k]

我们的目的是寻找最小的 k个数。假设经过一次 partition 操作,枢纽元素位于下标 m,也就是说,左侧的数组有 m 个元素,是原数组中最小的 m 个数。那么:

  • k = m,我们就找到了最小的 k 个数,就是左侧的数组;
  • k<m ,则最小的 k 个数一定都在左侧数组中,我们只需要对左侧数组递归地 partition即可;
  • k>m,则左侧数组中的 m 个数都属于最小的 k 个数,我们还需要在右侧数组中寻找最小的 k-m 个数,对右侧数组递归地 partition 即可。

对于973. 最接近原点的 K 个点

class Solution:
    def kClosest(self, points: List[List[int]], K: int) -> List[List[int]]:
        def random_select(left: int, right: int, K: int):
            pivot_id = random.randint(left, right)
            pivot = points[pivot_id][0] ** 2 + points[pivot_id][1] ** 2
            points[right], points[pivot_id] = points[pivot_id], points[right]
            i = left - 1
            for j in range(left, right):
                if points[j][0] ** 2 + points[j][1] ** 2 <= pivot:
                    i += 1
                    points[i], points[j] = points[j], points[i]
            i += 1
            points[i], points[right] = points[right], points[i]
            # [left, i-1] 都小于等于 pivot, [i+1, right] 都大于 pivot
            if K < i - left + 1:
                random_select(left, i - 1, K)
            elif K > i - left + 1:
                random_select(i + 1, right, K - (i - left + 1))

        n = len(points)
        random_select(0, n - 1, K)
        return points[:K]

我们定义函数 random_select(left, right, K) 表示划分数组 points[left,right] 区间,并且需要找到其中第 K 个距离最小的点。在一次划分操作完成后,设 pivot 的下标为 i,即区间 [left,i−1] 中的点的距离都小于等于pivot,而区间 [i+1,right] 的点的距离都大于pivot。此时会有三种情况:

  • 如果 K=i−left+1,那么说明 pivot 就是第 K 个距离最小的点,我们可以结束整个过程;
  • 如果K<i−left+1,那么说明第 K 个距离最小的点在 pivot 左侧,因此递归调用 random_select(left, i - 1, K);
  • 如果 K>i−left+1,那么说明第 K 个距离最小的点在 pivot 右侧,因此递归调用 random_select(i + 1, right, K - (i - left + 1))

在整个过程结束之后,第 K 个距离最小的点恰好就在数组 points 中的第 K 个位置,并且其左侧的所有点的距离都小于它。此时,我们就找到了前 K 个距离最小的点。

复杂度分析

  • 时间复杂度:期望为 O(n)
  • 空间复杂度:期望为 O(logn)

参考文章

快排亲兄弟:快速选择算法详解


本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!