LeetCode 解题报告 (23)-- 合并 k 个有序数组

原题如下:
>Merge k sorted linked lists and return it as one sorted list. Analyze and describe its complexity.

题目不难理解,就是将 k 个排好序的链表合并成一个,可以说是 merge two sorted lists 的升级版。一开始想的方法超时,后来参考了网上的两种方法并通过 python 实现后能够 AC,下面分别讲述这三种方法。

方法一:线性合并(TLE)

一开始想到的方法就是基于 merge two sorted lists 逐个合并 list,就是先将两个 list 合成一个,然后将这个合并好的 list 再和一个未合并的 list 进行 merge 操作,这样总共会合并 k-1 次,时间复杂度为 \(O(2n+3n+....+kn) = O(nk^2)\)(设 n 为链表的平均长度)。

实现代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class Solution(object):  
def mergeKLists(self, lists):
"""
:type lists: List[ListNode]
:rtype: ListNode
"""
k = len(lists)
if k == 0:
return None

mergedList = lists[0]
for i in range(1,k):
mergedList = self.mergeTwoLists(mergedList,lists[i])
return mergedList


def mergeTwoLists(self, l1, l2):
"""
:type l1: ListNode
:type l2: ListNode
:rtype: ListNode
"""
if l1 == None:
return l2
if l2 == None:
return l1

dummy = ListNode(0)
nextNode = dummy
while l1 and l2:
if l1.val > l2.val:
nextNode.next = l2
nextNode = nextNode.next
l2 = l2.next
else:
nextNode.next = l1
nextNode = nextNode.next
l1 = l1.next
if l1:
nextNode.next = l1
if l2:
nextNode.next = l2
return dummy.next

方法二:归并合并(AC)

这种方法类似于归并排序,先将当前需要排序的 list 对半分,重复这个步骤直到对半分出的 list 的数量为 1,在进行 merge,这时实际进行的是 merge two sorted list。

这种方法采用的思想跟第一种一样,都是分治法,先处理好局部,再合并成一个整体。但是与第一种方法在于这种方法在给出的 lists 的数量很大时需要进行 merge 的操作小于方法一。

这里有两个需要注意是当 lists 的数量很大(也就是 k 很大)是 merge 的操作才会比方法一要少。方法一无论 k 的大小 merge 的次数为 k-1,而方法二 merge 的次数为

\[\begin{align} \sum_{i=0}^m 2^i (m=log_2k-1) \end{align}\]

可以通过下面的程序验证当 k 很大时,这两种方法 merge 的次数不同

1
2
3
4
5
6
7
8
9
import math

k = 100000
m = int(math.log(k,2))
sum = 0
for i in range(m):
sum+=math.pow(2,i)
print 'merge times for method 1 when k=%s: %s'%(k,k-1)
print 'merge times for method 2 when k=%s: %s'%(k,sum)

根据主定理分析,这种方法的时间复杂度是 O(nklog(nk)), 下面是方法二实现的具体代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class Solution(object):  
def mergeKLists(self, lists):
"""
:type lists: List[ListNode]
:rtype: ListNode
"""
k = len(lists)
if k == 0:
return None
return self.helper(lists,0,len(lists)-1)


def helper(self,lists,l,r):
if l<r:
m = (r-l)/2
return self.mergeTwoLists(self.helper(lists,l,l+m),self.helper(lists,l+m+1,r))
else:
return lists[l]

def mergeTwoLists(self,l1,l2):
if l1 == None:
return l2
if l2 == None:
return l1

dummy = ListNode(0)
curr = dummy
while l1 and l2:
if l1.val < l2.val:
curr.next = l1
curr = curr.next
l1 = l1.next
else:
curr.next = l2
curr = curr.next
l2 = l2.next
if l1:
curr.next = l1
if l2:
curr.next = l2
return dummy.next

方法三:基于堆排序的归并(AC)

第三种方法非常巧妙,先建立一个大小为 k 的堆(k 就是链表数量),
堆中的一个元素代表一个链表当前的最小元素,每次取堆顶的最小元素放到结果中,然后读取该元素的下一个元素放入堆中,重新维护好。

因为每个链表是有序的,每次又是去当前 k 个元素中最小的,所以当所有链表都读完时结束,这个时候所有元素按从小到大放在结果链表中。
时间复杂度是 O (nklogk)。

实现代码如下,建堆的方法有两种,下面一并给出:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
class Solution(object):  
def mergeKLists(self, lists):
"""
:type lists: List[ListNode]
:rtype: ListNode
"""
k = len(lists)
if k == 0:
return None

listsHeap = []
listsHeap.append(0) # 使堆中的元素从1开始

for i in range(k):
if lists[i] == None: # avoid empty list
continue
listsHeap.append(lists[i])

dummy = ListNode(0)
curr = dummy

# 初始化堆有两种方法
'''方法一
j = len(listsHeap) - 1
while j > 1:
if listsHeap[j].val<listsHeap[j/2].val:
listsHeap[j],listsHeap[j/2] = listsHeap[j/2],listsHeap[j]
self.siftDown(listsHeap,j) #必须,否则初始建的堆会有问题
j-=1
'''
# 方法二
leafParent = (len(listsHeap)-1)/2
for i in range(leafParent,0,-1):
siftDown(listsHeap,i)

# 取堆顶元素并调整堆
while len(listsHeap) > 1:
curr.next = listsHeap[1]
curr = curr.next
if listsHeap[1].next == None: # 将空的列表移到最后并删除
listsHeap[1] = listsHeap[len(listsHeap)-1]
del(listsHeap[len(listsHeap)-1])
else:
listsHeap[1] = listsHeap[1].next
self.siftDown(listsHeap,1)
return dummy.next

def siftDown(self,listsHeap,i):
while i*2+1 <= len(listsHeap):
if i*2+1 < len(listsHeap):
if listsHeap[i].val > min(listsHeap[i*2].val,listsHeap[i*2+1].val):
if listsHeap[i*2].val < listsHeap[i*2+1].val:
listsHeap[i],listsHeap[i*2] = listsHeap[i*2],listsHeap[i]
i = i*2
else:
listsHeap[i],listsHeap[i*2+1] = listsHeap[i*2+1],listsHeap[i]
i = i*2+1
else:
return
elif i*2+1 == len(listsHeap):
if listsHeap[i*2].val < listsHeap[i].val:
listsHeap[i],listsHeap[i*2] = listsHeap[i*2],listsHeap[i]
i = i*2
return


参考:http://blog.csdn.net/linhuanmars/article/details/19899259