Median of Two Sorted Arrays

https://leetcode.com/problems/median-of-two-sorted-arrays/

There are two sorted arrays nums1 and nums2 of size m and n respectively.

Find the median of the two sorted arrays. The overall run time complexity should be O(log (m+n)).

Example 1:

nums1 = [1, 3]
nums2 = [2]

The median is 2.0

Example 2:

nums1 = [1, 2]
nums2 = [3, 4]

The median is (2 + 3)/2 = 2.5

思路:这道题可以用一个通用解法quick select的衍生解法来解。 我们先介绍一下Quick select, 这是一种在平均O(n)时间复杂度下在unsorted array中找到kth largest element的算法,最差的情况时间复杂度是O(n^2)。

这种算法的思想跟quick sort很像。我们首先随便选一个pivot,然后把比他小的放到左边,比他小的放到右边,跟他相同的放中间。然后判断k在三块中的哪一块,不停地减小搜索范围。

class Solution(object):
    def quickSelect(self, nums, k):

        if k>len(nums) or k<=0:
            return

        pivot = nums[0]
        left = [i for i in nums if i<pivot]
        mid = [i for i in nums if i==pivot]
        right = [i for i in nums if i>pivot]

        if len(left)>=k:
            return self.quickSelect(mid+left, k)
        elif len(mid)>=k-len(left):
            return pivot
        else:
            return self.quickSelect(right, k-len(left)-len(mid))

nums = [1,2,3,4]
so = Solution()
ans = so.quickSelect(nums, 0)
print(ans)

了解了quick select下面这道题目就变成了find kth element in 2 sorted array。主要思想是通过比较meduim,每次至少砍掉一个数组的一半的数。然后update k, 直到把一个数组砍没。

怎么实现每次把一个数组砍半呢?我们每次取a的中间的index叫做amid 和b中间的index叫做bmid,然后a和b就根据amid和bmid被切成了4个如下小块,a的前半部分和a的后半部分,b的前半部分和b的后半部分。

a[:amid], a[amid:]
b[:bmid], b[bmid:]

然后我们要每次要丢掉一个小块。这样每一个循环我们就都在缩小范围。然后我们要保证每次丢掉一小块的时候kth element都在剩下的块里面。

所以我们先用amid+bmid和k比较,如果amid+bmid比k大说明,k在左边三块中,我们可以把最大的那一小块给删掉。所以通过比较a[amid]和b[bmid]就可以找出较大的那一小块。如果amid+bmid比k小说明,k在右边三块中,我们可以把最小的那一小块给找出来删掉。所以通过比较a[amid]和b[bmid]中小的那个的mid value的左边就是最小的那一块。

这里要注意的是,等于情况怎么处理。因为我们的index是向下取整的。所以如果碰到amid+bmid=k的话,应该算mide在左边的三块里面,所以删出右边的那一块。

这样就保证了,每次都在减少,每次解都在剩下的数里面。感觉还是没讲清楚,我们还是来看代码吧。

class Solution(object):
    def findMedianSortedArrays(self, nums1, nums2):
        """
        :type nums1: List[int]
        :type nums2: List[int]
        :rtype: float
        """
        l = len(nums1) + len(nums2)
        if l%2==1:
            return self.kth(nums1, nums2, l//2)
        else:
            return (self.kth(nums1, nums2, l//2-1) + self.kth(nums1,nums2, l//2))//2


    def kth(self, a, b, k):
        if not a:
            return b[k]
        if not b:
            return a[k]

        # mid index of a and b
        ma, mb = len(a)//2, len(b)//2

        if ma+mb>=k: # mid is on the left
            if a[ma]>b[mb]:
                return self.kth(a[:ma], b, k)
            else:
                return self.kth(a, b[:mb], k)
        else:   # mid is on the right
            if a[ma]>b[mb]:
                return self.kth(a, b[mb+1:], k-mb-1)
            else:
                return self.kth(a[ma+1:], b, k-ma-1)

results matching ""

    No results matching ""