吴良超的学习笔记

Binary Indexed Trees 简介

Binary Indexed Trees(中文名为树状数组,下文简称为BIT)是一种特殊的数据结构,可多用于高效计算数列的前缀和, 区间和。对于长度为n的数组,它可以以$O(logn)$的时间得到任意前缀和 $ {\displaystyle \sum _{i=1}^{j}a[i],1<=j<=N}$,并同时支持在 $ O(log n)$时间内支持动态单点值的修改。空间复杂度 $O(n)$

虽然BIT名称中带有tree这个词,但是实际存储时是利用两个数组进行存储,记这两个数组为numsBIT。假设我们现在需要对原始数组 arr 进行前缀求和和区间求和,那么可以按照以下步骤进行。

1.初始化

$nums[i] = arr[i]$
$BIT[i] = {\displaystyle \sum _{k=i-lowestbit(i)+1}^{i}arr[k]}$

上面的lowestbit(i)指将i转为二进制后,最后一个1的位置所代表的数值。如lowestbit(1)=1、lowestbit(6)=2,具体的实现可通过(i&-i)获取。

下图就是初始化后的情况,横轴为数组的下标(记为i),纵轴为下标数值对应的lowestbit(i&-i),长方形表示BIT[i]涵盖的求和的范围

可以看到每个数组下标的lowestbit(也就是图中描黑的部分)在形态上构成了一棵树的形状,这也是名称中tree的来源。并且对于每个下标的lowestbit表示成的tree node有以下特性。

(1)假如i是左子节点,那么其父节点下标为i+(lowestbit(i))
(2)假如i是右子节点,那么其父节点下标为i-(lowestbit(i))

上面这两个特性非常重要,也是我们进行后文分析的重要基础。

2. 更新一个数值
假如要修改原始数组 arr 中的下标为i的值,那么需要修改nums数组中对应下标的值。除此之外还需要修改BIT数组中涵盖了arr[i]的值。结合上图可以知道,BIT数组中涵盖了arr[i]的值为下标i及其所有父节点,伪代码如下

1
2
3
while i < n:
BIT[i] += new_value
i += (i&-i)

3. 区间求和

假如要求arr数组下标区间为[i,j]的数值之和,那么可以先求下标为[0,i-1]的数值之和,再求下标为[0,j]的数值之和,然后用后者减去前者即可。

通过观察上面初始化后的图可以知道求[0, i]可以通过下面的方法:

1
2
3
4
count = 0
while i>0:
count += BIT[i]
i -= (i&-i)

通过上面的操作,通过利用额外的两个数数组,将原来的区间求和的操作从时间复杂度$O(n)$变为了$O(logn)$,但是更新数组的值的操作的时间复杂度也从原来的$O(1)$变为了$O(logn)$,所以这种数据结构更适合用于区间求和频繁的应用场景。

下面是LeetCode上的一道利用了BIT的题目,有兴趣的读者可以尝试做一下,验证刚刚学的理论知识。

Given an integer array nums, find the sum of the elements between indices i and j (i ≤ j), inclusive.

The update(i, val) function modifies nums by updating the element at index i to val.

实现的python代码如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class NumArray(object):
def __init__(self, nums):
"""
initialize your data structure here.
:type nums: List[int]
"""

self.nums = nums[:]
self.count = [0 for i in xrange(len(nums)+1)]
for i in xrange(len(nums)):
self.initialize(i, nums[i])

def initialize(self, i, val):
i += 1
while i < len(self.nums)+1:
self.count[i] += val
i += (i & -i)


def update(self, i, val):
"""
:type i: int
:type val: int
:rtype: int
"""

diff = val - self.nums[i]
self.nums[i] = val
self.initialize(i, diff)

def left_sum(self, i):
i += 1
total = 0
while i>0:
total += self.count[i]
i -= (i & -i)
return total

def sumRange(self, i, j):
"""
sum of elements nums[i..j], inclusive.
:type i: int
:type j: int
:rtype: int
"""

return self.left_sum(j) - self.left_sum(i-1)


# Your NumArray object will be instantiated and called as such:
# numArray = NumArray(nums)
# numArray.sumRange(0, 1)
# numArray.update(1, 10)
# numArray.sumRange(1, 2)