LeetCode 解题报告 (382, 398)-- 随机采样算法 Reservoir Sampling

Reservoir sampling 是一个随机采样算法,简单来说就是从 \(n\) 个 items 中随机选择 \(k\) 个 items,并且每个 item 被选择的概率应该都一样。这个算法的优点在于时空复杂度都不高,其中时间复杂度为 \(O(n)\), 空间复杂度为 \(O(1)\)下面介绍该算法的过程,并且以 leetcode 上的两道题目为例讲解。

假设现在要从 \(n\) 个数里面随机选择一个数,那么通过 Reservoir sampling 选择的流程如下

  1. 记最终选择的数为 result
  2. 遍历数组,对于数组第 i 个数,以 \(1/i\) 的概率将其赋给 result(i 从 1 开始,所以第一个数肯定会赋给 result)
  3. 遍历完数组后得到的 result 即为产生的随机数

假设现在有数组 [1, 2, 3], 随机产生一个数,那么按照上面的流程有 1. 遍历第一个数时,result = 1 2. 遍历第二个数时,result = 2 的概率为 1/2, 即 result = 1 的概率也是 1/2 3. 遍历第三个数时,result = 3 的概率为 1/3, result = 1 的概率为 (1 - 1/3) * 1/2 = 1/3, 同理 result = 2 的概率也是 1/3

上面的 (1 - 1/3) * 1/2 指的是这一次没有选择第三个数且之前 result 的值为 1 的概率,通过数学归纳法可以很容易的证明遍历完整个数组后每个数被选择的概率是 1/n (n 为数组的长度)

而假如要从 \(n\) 个数里面随机选择 \(k\) 个数时,Reservoir sampling 的过程类似上面的

  1. 选择前 \(k\) 个数作为 result
  2. 从第 \(k+1\) 个数开始遍历数组,对于数组第 \(k+i (i = 1,2,.....)\) 个数,以 \(\frac{k}{k+i}\) 的概率选择这个数加入 result 并替换掉 result 中的任意一个数
  3. 遍历完数组后得到的 result 即为产生的 \(k\) 个随机样本

下面通过数学归纳法证明通过上面的算法过程最终每个数被选择的概率为 \(k/n\)

  1. \(i = 1\) 时,选择第 \(k+1\) 个数的概率为 \(\frac{k}{k+1}\),而在 result 中 \(k\) 个数里面的一个 (记为 \(x\)) 能够保留下来的概率为是 \(x\) 原来在 result 中且这一次没有被替换的概率,而这一次没有被替换掉又可分为两种情况,一种是根本没有选择到第 \(k+i\) 个数,一种是选择了第 \(k+i\) 个数,但是替换 \(k\) 个数中的一个时没有替换掉 \(x\)。公式表示为

\[\begin {align} p ( x 上一次在 result 中) \* p ( x 没有被替换掉) = 1 \*(\frac {k}{k+1} \* (1-\frac {1}{k}) + (1 - \frac {k}{k+1}))= \frac {k}{k+1} \end {align}\]

即每个数被选择的概率为 \(\frac{k}{k+1}\)

  1. 因此当 \(i = m\) 时,每个数被选择的概率为 \(k/(k+m)\)

  2. 则当 \(i = m+1\) 时,选择第 \(k+m+1\) 个数的概率为 \(\frac{k}{k+m+1}\), 在 result 中 \(k\) 个数里面的一个 (记为 \(x\)) 能够保留下来的概率为:

\[\begin {align} p ( x 上一次在 result 中) \* p ( x 没有被替换掉) = \frac {k}{k+m} \*(\frac {k}{k+m+1} \* (1- \frac {1}{k}) + (1 - \frac {k}{k+m+1}))= \frac {k}{k+m+1} \end {align}\]

从上可知,遍历到第 \(i\) 个数的时候,前 \(k+i\) 每个数被选择的概率为 \(k/(k+i)\), 则遍历完 \(n\) 个数后,每个数被选择的概率为 \(k/n\)

LeetCode 上的题目 382. Linked List Random Node398. Random Pick Index 均用到了 Reservoir Sampling 的技巧,上面的依概率选择可以通过产生随机数并与概率值比较来实现,下面分别是 解决 382. Linked List Random Node 和 398. Random Pick Index 的 Java 代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
// 382. Linked List Random Node
public class Solution {

private ListNode dummy = new ListNode(0);

public Solution(ListNode head)
{
dummy.next = head;
}


public int getRandom()
{
ListNode curr = dummy.next;
int count = 0, result = 0;
while (curr != null)
{
count ++;
if (Math.random() < 1.0/count) result = curr.val;
curr = curr.next;
}
return result;
}
}

// 398. Random Pick Index
public class Solution
{
private int[] numbers;

public Solution(int[] nums)
{
numbers = nums;
}

public int pick(int target)
{
int index = 0, count = 0;
for(int i = 0; i < numbers.length; i++)
{
if (numbers[i] == target)
{
count++;
if(Math.random() < 1.0/count) index = i;
}
}
return index;
}
}