我们从一道简单的LeetCode题开始:

303

这题相当简单,我们只需先计算出numspartial_sum

1
2
3
4
5
6
7
8
9
10
11
class NumArray {
private:
vector<long long> partial_sums;

public:
NumArray(vector<int> nums) : partial_sums({0}) {
partial_sum(nums.begin(), nums.end(), back_inserter(partial_sums));
}

int sumRange(int i, int j) { return partial_sums[j + 1] - partial_sums[i]; }
};

partial_sums[i]代表nums[0,i)之和,显然partial_sums中的两个值相见就能求得之间的数的和了。这个方法被称为前缀和0

代码很简单,效率也很好(启动开销O(n),sumRangeO(1)),然而……

307

数组可以修改了。

容易想到的方法是,在修改nums中的值的同时也修改partial_sums中的值,如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class NumArray {
private:
vector<long long> partial_sums;
vector<int> m_nums;

public:
NumArray(vector<int> nums) : partial_sums({0}), m_nums(nums) {
std::partial_sum(nums.begin(), nums.end(), back_inserter(partial_sums));
}

void update(int i, int val) {
int delta = val - m_nums[i];
for (int j = i+1; j < partial_sums.size(); ++j) {
partial_sums[j] += delta;
}
m_nums[i] = val;
}

int sumRange(int i, int j) { return partial_sums[j + 1] - partial_sums[i]; }
};

确实AC了1,然而,我们可以看到,每次update的复杂度是O(n),虽然sumRange是O(1),但在updatesumRange调用次数均匀分布的情况下,平均复杂度仍是O(n)。

难道没有更好的方法了吗?

我们知道,对于大部分在数组中O(n)的操作,如果将其放在合适的二叉树中,其时间复杂度往往可以降到O(logn)。

再看一下我们原来的解决方案:partial_sums中储存的是nums[0,i)之和,每当nums中某一个值修改的时候,我们都要修改partial_sums中其后的所有值。

那么我们要做的就是把partial_sumsnums放在一棵树里,这棵树:

  • 叶子结点的值是nums数组的值
  • 其他结点的值是其左右子树的值之和

这样每次修改只需更新从这个节点到树根共计约log(n)个节点即可,同时sumRange操作却变得复杂了一些,需要从上往下查询,要花费O(log(n))的时间。

updatesumRange调用次数均匀分布的情况下,使用这棵树解这道题的平均复杂度就是O(log(n))了。

而这棵树就是传说中的线段树。

线段树的构造

线段树的构造和普通的二叉查找树想法相似,即将每个区间对半分成两个字区间,分别放置于左右子树,然后递归地进行,直至分无可分。

例如,从数组[7,23,7,8,22,26,4]中构造线段树。

数组长度为7,故根节点代表的索引范围是[0,6]

root

节点中的值应当放置数组[0,6]范围内所有元素的和,为了节约计算时间,我们先构造出左右两个子节点,计算他们的值,然后把的值之和赋给父节点。

layer1

在这个例子里,两个子节点也要等待自己的子节点构造完成。

layer2

到这里之后,左边分出了[0,0]——代表原数组的第0个元素和[1,1]——代表原数组的第1个元素。

layer3

到了这个地步,我们就能开始回朔了:

back

按照这样的套路构建线段树,我们最终能得到:

tree

这样的一棵线段树。

线段树的单点修改

线段树的单点修改和二叉搜索树的搜索方法类似,即从上往下搜索到要修改的索引,一路上同时更新节点数据的值。

比如要把这数组的第3个元素(8)改成9。

首先,我们知道3≤(0+6)/2在[0,6]的左半边,故在左子树里找,同时更新这个节点的数据。

changed1

3>(0+3)/2,故在[0,3]的右半边,故在右子树里找,同时更新节点数据。

changed2

3>(2+3)/2,故在[2,3]的右半边,故在右子树里找,同时更新数据。

changed3

最后3=3,找到了,更新即可。

changed_end

线段树的区间求和

这部分比较有意思。

算法大概来说是这样的:

从根节点开始考察:

  • 若这个节点和要求和的区间完全不沾边,直接返回0
  • 若这个节点完全在要求和的区间之内,返回这个节点的值
  • 否则对两个子节点递归地进行这一算法,并返回其和

比如上面那个线段树,我们求其[1,5]的和。

首先考察根节点,显然不符合第一和第二条,对子节点递归进行(此处以左节点为例)。

left

仍然不符合第一和第二条,递归进行:

left-left

仍然不符合,递归进行:

7

这个节点代表的范围是[0,0],和要找的范围不沾边,就返回0。

23

范围是[1,1],在要找的范围内,返回23。

返回到

left-left

此时返回的值是0+23=23。

考察右边的:

left-right

[2,3]完全在[1,5]之间,故返回节点的值16。

返回到:

left

返回23+16=39。

以此方法做完,即可得到结果。

C++实现

实现时要注意的地方是要为线段树预留4n(n为原数组长度)的空间,构造时的判空,还有一些地方的off-by-one错误。

由于线段树的大小是确定不变的,故采用了二叉树的顺序表表示。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84

template<typename T>
class segment_tree {
static_assert(is_integral<T>::value,
"Only integral types can be stored in a segment_tree!");

private:
vector<T> tree;
vector<T> origin;

static inline size_t left_child_index(size_t parent_index) {
return parent_index * 2 + 1;
}

static inline size_t right_child_index(size_t parent_index) {
return parent_index * 2 + 2;
}

void construct(size_t node_index, size_t start_index, size_t end_index) {
if (start_index == end_index) {
tree[node_index] = origin[start_index];
} else {
size_t mid_index = (start_index + end_index) / 2;
construct(left_child_index(node_index), start_index, mid_index);
construct(right_child_index(node_index), mid_index + 1, end_index);
tree[node_index] = tree[left_child_index(node_index)] +
tree[right_child_index(node_index)];
}
}

T getSumForRange(size_t current_start, size_t current_end,
size_t target_start, size_t target_end,
size_t current_node_index) {
if (current_start > target_end || current_end < target_start) {
return T(0);
} else if (target_start <= current_start && current_end <= target_end) {
return tree[current_node_index];
} else {
size_t current_mid = (current_start + current_end) / 2;
return getSumForRange(current_start, current_mid, target_start,
target_end, left_child_index(current_node_index)) +
getSumForRange(current_mid + 1, current_end, target_start,
target_end, right_child_index(current_node_index));
}
}

public:
template<typename InputIter>
segment_tree(InputIter it1, InputIter it2)
: origin(it1, it2), tree(distance(it1, it2) * 4) {
if (!origin.empty())
construct(0, 0, origin.size() - 1);
}

explicit segment_tree(vector<int> &&nums) : origin(nums), tree(nums.size() * 4) {
if (!origin.empty())
construct(0, 0, origin.size() - 1);
}

void update(size_t index_in_origin, T new_num) {
assert(index_in_origin < origin.size());
T delta = new_num - origin[index_in_origin];
origin[index_in_origin] = new_num;
size_t current_index = 0;
size_t current_range_from = 0;
size_t current_range_to = origin.size() - 1;
while (current_range_from != current_range_to && left_child_index(current_index) < tree.size()) {
tree[current_index] += delta;
size_t current_range_mid = (current_range_from + current_range_to) / 2;
if (index_in_origin <= current_range_mid) {
current_range_to = current_range_mid;
current_index = left_child_index(current_index);
} else {
current_range_from = current_range_mid + 1;
current_index = right_child_index(current_index);
}
}
tree[current_index] += delta;
}

T getSum(size_t target_start, size_t target_end) {
return getSumForRange(0, origin.size() - 1, target_start, target_end, 0);
}
};
0. 二维前缀和在OpenCV 的级联分类器中也有应用。
1. LeetCode居然不卡掉我?果然这和ACM不是一个境界的东西。