LeetCode 4. Median of Two Sorted Arrays题解

题目

4. Median of Two Sorted Arrays

There are two sorted arrays nums1 and nums2 of size m and n respectively.

Find the median of the two sorted arrays. The overall run time complexity should be O(log (m+n)).

You may assume nums1 and nums2 cannot be both empty.

有两个有序数组nums1和nums2,大小分别为m和n。
请找到两个有序数组的中值。总的运行时间复杂度应该是O(log (m+n))。
假定nums1和nums2不能同时为空。

Example 1:

nums1 = [1, 3]
nums2 = [2]
The median is 2.0

Example 2:

nums1 = [1, 2]
nums2 = [3, 4]

The median is (2 + 3)/2 = 2.5

分析

笔者水平有限,刚看到这道题的时候甚至没看明白(不知道什么叫“两个有序数组的中值”),但是研究网上多种解答之后,找到了笔者认为讲的最好的一种,贴在这里。
而后附上笔者自己归纳得出的思路图解

描述:给定两个已经升序排序过的数组,求这两个数组的中位数;中位数的定义为把两个数组合并过后进行升序排序后,处于数组中间的那个数,此时如果合并后的数组元素个数为偶数,则为中间两个数的平均值。
初看起来这就是个寻找第k小数的问题,解决方案有很多,最简单的就是采用归并排序的思想把两个数组进行合并,然后取中间的数就可以了。但问题在于,这个题目限定了时间复杂度为O(log(m+n)),而合并算法的时间复杂度为O(nlogn),显然不合题意。另外一个方法是设置一个双指针,一开始都指向两个数组的开头,不停地比较两个指针指向的元素的大小,指向小元素的指针的往前移一个元素去追指向大元素的指针,一直移动(len1+len2)/2次后就能得到中位数,但是这个算法的时间复杂度仍然不符合题意,为O(n)

但是注意到这个题目给定的数组已经是排过序的了,算法导论中对order statistic问题进行过讨论,因此,在有序又要求log级的时间复杂度,可以考虑分治策略,采用二分法

大方向定好了,但是并不清楚具体要怎么去完成这个二分法,我们应该对什么去做二分?其实这个题目需要找的就是第k小的元素问题,假设我们的第k小的数是在第一个数组中找了p次,然后在第二个数组中找了q次,那么满足关系:p+q=k。进一步的,寻找第k小的数的过程就是寻找p和q的过程,k我们是知道的,但是p和q是不知道的,因此事实上我们的目标就是去搜索p(找到了p就等于找到了q),因此我们二分法的目标,事实上就是二分k来找p。

我们先定义以下形式的函数用来寻找第k小的数:

findKth(nums1, nums2, start1, len1, start2, len2, k)

nums1和nums2表示原始的两个数组,start1、len1表示nums1数组中以start1位置开始、len1长度的一个子数组;start2、lens2表示nums2数组中以start2位置开始、len2长度的一个子数组;k表示从这两个子数组中找到第k小的数。之所以提供start1、len1、start2、len2,是因为经验告诉我们分治法解决问题都是递归的,我们在二分的时候就需要记录这些相关的数据。

我们的外层入口就应该是这样子的:

if(len1+len2是偶数) {
	return (findKth(nums1, nums2, 0, len1, 0, len2, (len1+len2)/2) + 
	findKth(nums1, nums2, 0, len1, 0, len2, (len1+len2)/2 + 1)) / 2; 
} else {
	return findKth(nums1, nums2, 0, len1, 0, len2 (len1+len2)/2); 
	}

下面就是具体对于findKth的实现了。

首先,我们知道,我们需要对k进行二分:

findKth(nums1, nums2, start1, len1, start2, len2, k) {
	p = k / 2;
}

这个p怎么用呢?假设我们在nums1中取nums1[start1+p-1],就表示我们在nums1中“前进”了p个元素,且这p个元素是有序的,相应的,q = k - p,我们需要在nums2中前进q个元素:

findKth(nums1, nums2, start1, len1, start2, len2, k) {
  p = k / 2;
  q = k - p;
}

假如这个时候nums1[start1 + p - 1]等于nums2[start + q - 1],这说明kth = nums1[start1 + p - 1] = nums2[start + q - 1] = 第k小的数,为什么呢?这里要注意nums1和nums2都是有序的,因此它们的子数组也是有序的;假设把两个子数组数组合并,那么在nums1子数组中排在kth前面的数在合并后的数组中一定还是排在kth前面,同理在nums2子数组中排在kth前面的数在合并后的数组中也一定还是排在kth前面,它们的具体顺序我们不关心也不必关心,我们只需要知道这样一来在合并后的数组中就有p-1+q-1=k-2个数在kth前面,因此kth一定就是第k小的那个数

如果nums1[start1 + p - 1]大于nums2[start + q - 1],这里就出现了一个需要注意的情况,这意味着nums2子数组中的前q个数一定都是小于nums1[start1 + p - 1]的(再次注意,nums1和nums2都是有序的),而q<k,这也就意味着第k小的数一定不会出现在nums2的子数组的前q个数中。这启发我们在这个时候就可以抛弃掉前q个数,重新用一个新的子数组进行搜索,注意,进一步的搜索中由于抛弃掉了q(p)个数,因此下一步在子数组中的搜索中,事实上就是在搜索第k-q(k-p)小的元素了:

findKth(nums1, nums2, start1, len1, start2, len2, k) {
	 p = k / 2; 
	 q = k - p; 
	 if(nums1[start1 + p - 1] == nums2[start2 + q - 1]) { 
	 return nums1[start1 + p - 1]; 
	 } 
	 else if(nums1[start1 + p - 1] > nums2[start2 + q - 1]) { 
	 return findKth(nums1, nums2, start1, len1, start2 + q, len2 - q, k - q); 
	 } 
	 else if(nums1[start1 + p - 1] < nums2[start2 + q - 1]) { 
	 return findKth(nums1, nums2, start1 + p, len1 - p, start2, len2, k - p); 
	 } 
}

上面的框架大致已经描述清楚了我们的二分搜索算法,下一步就需要考虑退出条件。

退出条件有这么一些:

  • 在某一步搜索中子数组的长度为0了,这表示有一个数组中的元素完全被抛弃掉,此时另外一个子数组的第k个元素就是我们要求的第k小的元素;
  • 在不满足1的情况下,出现k=1的情况,这表示需要在两个子数组中找第1小的元素,此时简单地比较一下两个子数组的第一个元素就行了;
  • nums1[start1 + p - 1] == nums2[start2 + q - 1]

因此可以进一步写成:

findKth(nums1, nums2, start1, len1, start2, len2, k) { 
	if(某个子数组的长度为零) { 
		return 另外一个子数组的第k个元素; 
	} 
	if(k == 1) { 
	return min(nums1[start1], nums2[start2]); 
	} 

	p = k / 2; 
	q = k - p; 
	if(nums1[start1 + p - 1] == nums2[start2 + q - 1]) { 
		return nums1[start1 + p - 1]; 
	} 
	else if(nums1[start1 + p - 1] > nums2[start2 + q - 1]) { 
		return findKth(nums1, nums2, start1, len1, start2 + q, len2 - q, k - q); 
	} 
	else if(nums1[start1 + p - 1] < nums2[start2 + q - 1]) { 
		return findKth(nums1, nums2, start1 + p, len1 - p, start2, len2, k - p); 
	} 
}

为了方便考虑问题,不失一般性,我们要求nums1永远是那个长度较短的数组:

findKth(nums1, nums2, start1, len1, start2, len2, k) { 
	if(len1 > len2) { 
		return findKth(nums2, nums1, start2, len2, start2, start1); 
	} 
	
	if(len1 == 0) { 
		return nums2[start2 + k - 1]; 
	} 

	if(k == 1) { 
		return min(nums1[start1], nums2[start2]); 
	} 
	p = k / 2; 
	q = k - p; 
	if(nums1[start1 + p - 1] == nums2[start2 + q - 1]) { 
		return nums1[start1 + p - 1]; 
	} 
	else if(nums1[start1 + p - 1] > nums2[start2 + q - 1]) { 
		return findKth(nums1, nums2, start1, len1, start2 + q, len2 - q, k - q); 
	} 
	else if(nums1[start1 + p - 1] < nums2[start2 + q - 1]) { 
		return findKth(nums1, nums2, start1 + p, len1 - p, start2, len2, k - p); 
	} 
}

此外还有一个容易被忽略的边界问题,那就是p=k/2这一句,如果p大于len1的话,就会出现越界访问的问题,这个时候需要对其进行控制:

findKth(nums1, nums2, start1, len1, start2, len2, k) { 
	if(len1 > len2) { 
		return findKth(nums2, nums1, start2, len2, start2, start1); 
	} 

	if(len1 == 0) { 
		return nums2[start2 + k - 1]; 
	} 

	if(k == 1) { 
		return min(nums1[start1], nums2[start2]); 
	} 
	p = min(k / 2, len1); 
	q = k - p; 
	if(nums1[start1 + p - 1] == nums2[start2 + q - 1]) { 
		return nums1[start1 + p - 1]; 
	} 
	else if(nums1[start1 + p - 1] > nums2[start2 + q - 1]) { 
		return findKth(nums1, start1, len1, start2 + q, len2 - q, k - q); 
	} 
	else if(nums1[start1 + p - 1] > nums2[start2 + q - 1]) { 
		return findKth(nums1, start1 + p, len1 - p, start2, len2, k - p); 
	} 
}

分析到这里,基本上这个问题就解决了,不过需要说的是,对于p=min(k/2,len1)这一句,这里看起来应该就是二分法的比较关键的一个地方了,事实上我们把2换成3、4、5、6……都是可以的,因为二分法搜索事实上就是个碰运气的过程,不过需要注意的是,这里p不能为0,否则在nums1中等于是没有做“前进”的动作,这是不允许的,因此更加健壮的描述应该为:

p = min(max(k/2, 1), len1);

即二分过程中,每一次迭代至少要在nums1中“前进”一步。

整个程序的C++代码如下:

#include <vector> 
class Solution { 
private: 
	double findKth(vector<int>& nums1, vector<int>& nums2, int start1, int len1, int start2, int len2, int k) { 
		if (len1 > len2) { 
			return findKth(nums2, nums1, start2, len2, start1, len1, k); 
		} 

		if (len1 == 0) { 
			return nums2[start2 + k - 1]; 
		} 

		if (k == 1) { 
			return min(nums1[start1], nums2[start2]); 
		} 

		int p1 = min(k / 2, len1); 
		int p2 = k - p1; 
		if (nums1[start1 + p1 - 1] > nums2[start2 + p2 - 1]) { 
			return findKth(nums1, nums2, start1, len1, start2 + p2, len2 - p2, k - p2); 
		} 
		else if(nums1[start1 + p1 - 1] < nums2[start2 + p2 - 1]){ 
			return findKth(nums1, nums2, start1 + p1, len1 - p1, start2, len2, k - p1); 
		} 
		else { 
			return nums1[start1 + p1 - 1]; 
		} 
	} 
public: 
	double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) { 
		int len = nums1.size() + nums2.size(); 
		if (!(len & 0x01)) { 
			return (findKth(nums1, nums2, 0, nums1.size(), 0, nums2.size(), len / 2) + 
				findKth(nums1, nums2, 0, nums1.size(), 0, nums2.size(), len / 2 + 1) ) / 2.0f; 
		} 
		else { 
			return findKth(nums1, nums2, 0, nums1.size(), 0, nums2.size(), len / 2 + 1); 
		} 
	} 
};

参考链接:https://www.jianshu.com/p/9bd57fd52062

下附思路图解
我们考虑如下两个有序数组:
nums1:{1,3,7,8,10}
nums2:{2,4,6,9}
如图:
LeetCode 4. Median of Two Sorted Arrays题解
首先把短的数组置前,方便我们代码统一:
不妨就假定现在nums1是{2,4,6,9},nums2是{1,3,7,8,10}
LeetCode 4. Median of Two Sorted Arrays题解
计算下两数组长度和为9,是奇数,所以取k = len/2 + 1,也就是k=5
意即我们需要找到两数组合并排序后的第5个数。
5/2=2,我们取nums1的前2个和nums2的前5-2=3个:LeetCode 4. Median of Two Sorted Arrays题解
比较nums1和nums2中取的数的最后一个:也就是比较4和7。
发现4比7小,所以我们先取出nums1中取的数2和4,把1,3,7重新放回我们的考虑范围中。
这时因为已经取出2个,所以k=5-2=3,也就是我们只要再取三个数就好了。
回到原来的方法,根据3/2=1,所以nums1中取1个,nums2中取3-1=2个:
LeetCode 4. Median of Two Sorted Arrays题解
此时发现6>3,所以将nums2中的1,3取出。k=3-2=1,也就是我们再取一个即可。
显然就取 nums1的第一项 与 nums2的第一项 中较小的那一个:
LeetCode 4. Median of Two Sorted Arrays题解
最终得到答案:6。


题解

c++实现
#include <vector> 
class Solution { 
private: 
	double findKth(vector<int>& nums1, vector<int>& nums2, int start1, int len1, int start2, int len2, int k) { 
		if (len1 > len2) { 
			return findKth(nums2, nums1, start2, len2, start1, len1, k); 
		} 

		if (len1 == 0) { 
			return nums2[start2 + k - 1]; 
		} 

		if (k == 1) { 
			return min(nums1[start1], nums2[start2]); 
		} 

		int p1 = min(k / 2, len1); 
		int p2 = k - p1; 
		if (nums1[start1 + p1 - 1] > nums2[start2 + p2 - 1]) { 
			return findKth(nums1, nums2, start1, len1, start2 + p2, len2 - p2, k - p2); 
		} 
		else if(nums1[start1 + p1 - 1] < nums2[start2 + p2 - 1]){ 
			return findKth(nums1, nums2, start1 + p1, len1 - p1, start2, len2, k - p1); 
		} 
		else { 
			return nums1[start1 + p1 - 1]; 
		} 
	} 
public: 
	double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) { 
		int len = nums1.size() + nums2.size(); 
		if (!(len & 0x01)) { 
			return (findKth(nums1, nums2, 0, nums1.size(), 0, nums2.size(), len / 2) + 
				findKth(nums1, nums2, 0, nums1.size(), 0, nums2.size(), len / 2 + 1) ) / 2.0f; 
		} 
		else { 
			return findKth(nums1, nums2, 0, nums1.size(), 0, nums2.size(), len / 2 + 1); 
		} 
	} 
};
java实现
public class Solution {
    public double findMedianSortedArrays(int[] nums1, int[] nums2) {
        int m = nums1.length, n = nums2.length, left = (m + n + 1) / 2, right = (m + n + 2) / 2;
        return (findKth(nums1, 0, nums2, 0, left) + findKth(nums1, 0, nums2, 0, right)) / 2.0;
    }
    int findKth(int[] nums1, int i, int[] nums2, int j, int k) {
        if (i >= nums1.length) return nums2[j + k - 1];
        if (j >= nums2.length) return nums1[i + k - 1];
        if (k == 1) return Math.min(nums1[i], nums2[j]);
        int midVal1 = (i + k / 2 - 1 < nums1.length) ? nums1[i + k / 2 - 1] : Integer.MAX_VALUE;
        int midVal2 = (j + k / 2 - 1 < nums2.length) ? nums2[j + k / 2 - 1] : Integer.MAX_VALUE;
        if (midVal1 < midVal2) {
            return findKth(nums1, i + k / 2, nums2, j, k - k / 2);
        } else {
            return findKth(nums1, i, nums2, j + k / 2, k - k / 2);
        }
    }
}