Binary Indexed Trees 简介
Binary Indexed Trees(中文名为树状数组,下文简称为 BIT)是一种特殊的数据结构,可多用于高效计算数列的前缀和, 区间和。对于长度为 n 的数组,它可以以 \(O(logn)\) 的时间得到任意前缀和 $ {_{i=1}^{j} a [i],1<=j<=N}$,并同时支持在 $ O (log n)$ 时间内支持动态单点值的修改。空间复杂度 \(O(n)\)
虽然 BIT 名称中带有 tree 这个词,但是实际存储时是利用两个数组进行存储,记这两个数组为 nums
和 BIT
。假设我们现在需要对原始数组 arr
进行前缀求和和区间求和,那么可以按照以下步骤进行。
1. 初始化
\(nums[i] = arr[i]\) \(BIT[i] = {\displaystyle \sum _{k=i-lowestbit(i)+1}^{i}arr[k]}\)
上面的 lowestbit(i)
指将 i 转为二进制后,最后一个 1 的位置所代表的数值。如 lowestbit(1)=1、lowestbit(6)=2
,具体的实现可通过 (i&-i)
获取。
下图就是初始化后的情况,横轴为数组的下标 (记为 i),纵轴为下标数值对应的 lowestbit(i&-i),长方形表示 BIT [i] 涵盖的求和的范围
[][1]
可以看到每个数组下标的 lowestbit(也就是图中描黑的部分)在形态上构成了一棵树的形状,这也是名称中 tree
的来源。并且对于每个下标的 lowestbit 表示成的 tree node 有以下特性。
(1) 假如 i 是左子节点,那么其父节点下标为 i+(lowestbit (i)) (2) 假如 i 是右子节点,那么其父节点下标为 i-(lowestbit (i))
上面这两个特性非常重要,也是我们进行后文分析的重要基础。
2. 更新一个数值 假如要修改原始数组 arr
中的下标为 i 的值,那么需要修改 nums
数组中对应下标的值。除此之外还需要修改 BIT 数组中涵盖了 arr[i]
的值。结合上图可以知道,BIT 数组中涵盖了 arr[i]
的值为下标 i 及其所有父节点,伪代码如下1
2
3while i < n:
BIT[i] += new_value
i += (i&-i)
3. 区间求和
假如要求 arr 数组下标区间为 [i,j] 的数值之和,那么可以先求下标为 [0,i-1] 的数值之和,再求下标为 [0,j] 的数值之和,然后用后者减去前者即可。
通过观察上面初始化后的图可以知道求 [0, i] 可以通过下面的方法:1
2
3
4count = 0
while i>0:
count += BIT[i]
i -= (i&-i)
通过上面的操作,通过利用额外的两个数数组,将原来的区间求和的操作从时间复杂度 \(O(n)\) 变为了 \(O(logn)\), 但是更新数组的值的操作的时间复杂度也从原来的 \(O(1)\) 变为了 \(O(logn)\), 所以这种数据结构更适合用于区间求和频繁的应用场景。
下面是 [LeetCode][2] 上的一道利用了 BIT 的题目,有兴趣的读者可以尝试做一下,验证刚刚学的理论知识。 >Given an integer array nums, find the sum of the elements between indices i and j (i ≤ j), inclusive.
The update(i, val) function modifies nums by updating the element at index i to val.
实现的 python 代码如下所示:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51class NumArray(object):
def __init__(self, nums):
"""
initialize your data structure here.
:type nums: List[int]
"""
self.nums = nums[:]
self.count = [0 for i in xrange(len(nums)+1)]
for i in xrange(len(nums)):
self.initialize(i, nums[i])
def initialize(self, i, val):
i += 1
while i < len(self.nums)+1:
self.count[i] += val
i += (i & -i)
def update(self, i, val):
"""
:type i: int
:type val: int
:rtype: int
"""
diff = val - self.nums[i]
self.nums[i] = val
self.initialize(i, diff)
def left_sum(self, i):
i += 1
total = 0
while i>0:
total += self.count[i]
i -= (i & -i)
return total
def sumRange(self, i, j):
"""
sum of elements nums[i..j], inclusive.
:type i: int
:type j: int
:rtype: int
"""
return self.left_sum(j) - self.left_sum(i-1)
# Your NumArray object will be instantiated and called as such:
# numArray = NumArray(nums)
# numArray.sumRange(0, 1)
# numArray.update(1, 10)
# numArray.sumRange(1, 2)