【面试高频题】热门数据结构面试题合集(树状数组)
本文正在参加「金石计划」
树状数组
本文将对常见的关于「树状数组」的题型进行整理总结,将学习 444 道与「树状数组」相关的题目。
1395. 统计作战单位数
n
名士兵站成一排。每个士兵都有一个 独一无二 的评分 rating
。
每 333 个士兵可以组成一个作战单位,分组规则如下:
-
从队伍中选出下标分别为
i
、j
、k
的 333 名士兵,他们的评分分别为 rating[i]rating[i]rating[i]、rating[j]rating[j]rating[j]、rating[k]rating[k]rating[k] -
作战单位需满足: rating[i]<rating[j]<rating[k]rating[i] < rating[j] < rating[k]rating[i]<rating[j]<rating[k] 或者 rating[i]>rating[j]>rating[k]rating[i] > rating[j] > rating[k]rating[i]>rating[j]>rating[k] ,其中 0 <=i< j< k< n0 <= i < j < k < n0 <=i< j< k< n
请你返回按上述条件可以组建的作战单位数量。每个士兵都可以是多个作战单位的一部分。
示例 1:
输入:rating = [2,5,3,4,1]
输出:3
解释:我们可以组建三个作战单位 (2,3,4)、(5,4,1)、(5,3,1) 。
提示:
- n==rating.lengthn == rating.lengthn==rating.length
- 3<=n<=10003 <= n <= 10003<=n<=1000
- 1<=rating[i]<=1051 <= rating[i] <= 10^51<=rating[i]<=105
rating
中的元素都是唯一的
基本分析
为了方便,我们记 rating
为 rs
。
题目本质是要我们统计所有满足「递增」或「递减」的三元组。换句话说,对于每个 t=rs[i]t = rs[i]t=rs[i] 而言,我们需要统计比其 ttt 大或比 ttt 小的数的个数。
问题涉及「单点修改(更新数值 ttt 的出现次数)」以及「区间查询(查询某段范围内数的个数)」,使用「树状数组」求解较为合适。
树状数组 - 枚举两端
一个朴素的想法是,对于三元组 (i,j,k)(i, j, k)(i,j,k),我们枚举其两端 iii 和 kkk,根据 rs[i]rs[i]rs[i] 和 rs[k]rs[k]rs[k] 的大小关系,查询范围 [i+1,k−1][i + 1, k - 1][i+1,k−1] 之间合法的数的个数。
在确定左端点 iii 时,我们从 i+1i + 1i+1 开始「从小到大」枚举右端点 kkk,并将遍历过程中经过的 rs[k]rs[k]rs[k] 添加到树状数组进行计数。
处理过程中根据 a=rs[i]a = rs[i]a=rs[i] 和 b=rs[k]b = rs[k]b=rs[k] 的大小关系进行分情况讨论:
- 当 a<ba < ba<b 时,我们需要在范围 [i+1,k−1][i + 1, k - 1][i+1,k−1] 中找「大于 aaa」同时「小于 bbb」的数的个数,即
query(b - 1) - query(a)
- 当 a>ba > ba>b 时,我们需要在范围 [i+1,k−1][i + 1, k - 1][i+1,k−1] 中找「小于 aaa」同时「大于 bbb」的数的个数,即
query(a - 1) - query(b)
一些细节:显然我们需要在枚举每个左端点 iii 时清空树状数组,但注意不能使用诸如 Arrays.fill(tr, 0)
的方式进行清空。
因为在没有离散化的情况下,树状数组的大小为 m=1e5m = 1e5m=1e5,即执行 Arrays.fill
操作的复杂度为 O(m)O(m)O(m),这会导致我们计算量为至少为 n×m=1e8n \times m = 1e8n×m=1e8,会有 TLE
风险。
因此一个合适做法是:在 [i+1,n−1][i + 1, n - 1][i+1,n−1] 范围内枚举完 kkk 后(进行的是 +1
计数),再枚举一次 [i+1,n−1][i + 1, n - 1][i+1,n−1] 进行一次 -1
的计数进行抵消。
代码:
class Solution {
static int N = (int)1e5 + 10;
static int[] tr = new int[N];
int lowbit(int x) {
return x & -x;
}
void update(int x, int v) {
for (int i = x; i < N; i += lowbit(i)) tr[i] += v;
}
int query(int x) {
int ans = 0;
for (int i = x; i > 0; i -= lowbit(i)) ans += tr[i];
return ans;
}
public int numTeams(int[] rs) {
int n = rs.length, ans = 0;
for (int i = 0; i < n; i++) {
int a = rs[i];
for (int j = i + 1; j < n; j++) {
int b = rs[j];
if (a < b) ans += query(b - 1) - query(a);
else ans += query(a - 1) - query(b);
update(b, 1);
}
for (int j = i + 1; j < n; j++) update(rs[j], -1);
}
return ans;
}
}
- 时间复杂度:令 m=1e5m = 1e5m=1e5 为值域大小,整体复杂度为 O(n2logm)O(n^2\log{m})O(n2logm)
- 空间复杂度:O(m)O(m)O(m)
双树状数组优化 - 枚举中点
我们考虑将 nnn 的数据范围提升到 1e41e41e4 该如何做。
上述解法的瓶颈在于我们枚举三元组中的左右端点,复杂度为 O(n2)O(n^2)O(n2),而实际上利用三元组必然递增或递减的特性,我们可以调整为枚举终点 jjj,从而将「枚举点对」调整为「枚举中点」,复杂度为 O(n)O(n)O(n)。
假设当前枚举到的点为 rs[i]rs[i]rs[i],问题转换为在 [0,i−1][0, i - 1][0,i−1] 有多少比 rs[i]rs[i]rs[i] 小/大 的数,在 [i+1,n−1][i + 1, n - 1][i+1,n−1] 有多少比 rs[i]rs[i]rs[i] 大/小 的数,然后集合「乘法」原理即可知道 rs[i]rs[i]rs[i] 作为三元组中点的合法方案数。
统计 rs[i]rs[i]rs[i] 左边的比 rs[i]rs[i]rs[i] 大/小 的数很好做,只需要在「从小到大」枚举 iii 的过程中,将 rs[i]rs[i]rs[i] 添加到树状数组 tr1
即可。
对于统计 rs[i]rs[i]rs[i] 右边比 rs[i]rs[i]rs[i] 小/大 的数,则需要通过「抵消计数」来做,起始我们先将所有 rs[idx]rs[idx]rs[idx] 加入到另外一个树状数组 tr2
中(进行 +1
计数),然后在从前往后处理每个 rs[i]rs[i]rs[i] 的时候,在 tr2
中进行 -1
抵消,从而确保我们处理每个 rs[i]rs[i]rs[i] 时,tr1
存储左边的数,tr2
存储右边的数。
代码:
class Solution {
static int N = (int)1e5 + 10;
static int[] tr1 = new int[N], tr2 = new int[N];
int lowbit(int x) {
return x & -x;
}
void update(int[] tr, int x, int v) {
for (int i = x; i < N; i += lowbit(i)) tr[i] += v;
}
int query(int[] tr, int x) {
int ans = 0;
for (int i = x; i > 0; i -= lowbit(i)) ans += tr[i];
return ans;
}
public int numTeams(int[] rs) {
int n = rs.length, ans = 0;
Arrays.fill(tr1, 0);
Arrays.fill(tr2, 0);
for (int i : rs) update(tr2, i, 1);
for (int i = 0; i < n; i++) {
int t = rs[i];
update(tr2, t, -1);
ans += query(tr1, t - 1) * (query(tr2, N - 1) - query(tr2, t));
ans += (query(tr1, N - 1) - query(tr1, t)) * query(tr2, t - 1);
update(tr1, t, 1);
}
return ans;
}
}
- 时间复杂度:令 m=1e5m = 1e5m=1e5 为值域大小,整体复杂度为 O(nlogm)O(n\log{m})O(nlogm)
- 空间复杂度:O(m)O(m)O(m)
307. 区域和检索 - 数组可修改
给你一个数组 nums
,请你完成两类查询,其中一类查询要求更新数组下标对应的值,另一类查询要求返回数组中某个范围内元素的总和。
实现 NumArray
类:
NumArray(int[] nums)
用整数数组nums
初始化对象void update(int index, int val)
将nums[index]
的值更新为val
int sumRange(int left, int right)
返回子数组nums[left, right]
的总和(即,nums[left] + nums[left + 1], ..., nums[right]
)
示例:
输入:
["NumArray", "sumRange", "update", "sumRange"]
[[[1, 3, 5]], [0, 2], [1, 2], [0, 2]]
输出:
[null, 9, null, 8]
解释:
NumArray numArray = new NumArray([1, 3, 5]);
numArray.sumRange(0, 2); // 返回 9 ,sum([1,3,5]) = 9
numArray.update(1, 2); // nums = [1,2,5]
numArray.sumRange(0, 2); // 返回 8 ,sum([1,2,5]) = 8
提示:
- 1<=nums.length<=3∗1041 <= nums.length <= 3 * 10^41<=nums.length<=3∗104
- −100<=nums[i]<=100-100 <= nums[i] <= 100−100<=nums[i]<=100
- 0<=index<nums.length0 <= index < nums.length0<=index<nums.length
- −100<=val<=100-100 <= val <= 100−100<=val<=100
- 0<=left<=right<nums.length0 <= left <= right < nums.length0<=left<=right<nums.length
- 最多调用 3∗1043 * 10^43∗104 次
update
和sumRange
方法
解题思路
这是一道很经典的题目,通常还能拓展出一大类问题。
针对不同的题目,我们有不同的方案可以选择(假设我们有一个数组):
- 数组不变,求区间和:「前缀和」、「树状数组」、「线段树」
- 多次修改某个数(单点),求区间和:「树状数组」、「线段树」
- 多次修改某个区间,输出最终结果:「差分」
- 多次修改某个区间,求区间和:「线段树」、「树状数组」(看修改区间范围大小)
- 多次将某个区间变成同一个数,求区间和:「线段树」、「树状数组」(看修改区间范围大小)
这样看来,「线段树」能解决的问题是最多的,那我们是不是无论什么情况都写「线段树」呢?
答案并不是,而且恰好相反,只有在我们遇到第 4/5 类问题,不得不写「线段树」的时候,我们才考虑线段树。
因为「线段树」代码很长,而且常数很大,实际表现不算很好。我们只有在不得不用的时候才考虑「线段树」。
总结一下,我们应该按这样的优先级进行考虑:
- 简单求区间和,用「前缀和」
- 多次将某个区间变成同一个数,用「线段树」
- 其他情况,用「树状数组」
树状数组
本题只涉及「单点修改」和「区间求和」,属于「树状数组」的经典应用。
「树状数组」本身是一个很简单的数据结构,但是要搞懂其为什么可以这样「查询」&「更新」还是比较困难的(特别是为什么可以这样更新),往往需要从「二进制分解」进行出发理解。
树状数组涉及的操作有两个,复杂度均为 O(logn)O(\log{n})O(logn):
void add(int x, int u)
:含义为在 xxx 的位置增加 uuu(注意位置下标从 111 开始);int query(int x)
:含义为查询从 [1,x][1, x][1,x] 区间的和为多少(配合容斥原理,可实现任意区间查询)。
代码:
class NumArray {
int[] tr;
int lowbit(int x) {
return x & -x;
}
void add(int x, int u) {
for (int i = x; i <= n; i += lowbit(i)) tr[i] += u;
}
int query(int x) {
int ans = 0;
for (int i = x; i > 0; i -= lowbit(i)) ans += tr[i];
return ans;
}
int[] nums;
int n;
public NumArray(int[] _nums) {
nums = _nums;
n = nums.length;
tr = new int[n + 10];
for (int i = 0; i < n; i++) add(i + 1, nums[i]);
}
public void update(int index, int val) {
add(index + 1, val - nums[index]);
nums[index] = val;
}
public int sumRange(int left, int right) {
return query(right + 1) - query(left);
}
}
- 时间复杂度:插入和查询复杂度均为 O(logn)O(\log{n})O(logn)
- 空间复杂度:O(n)O(n)O(n)
线段树
相比「树状数组」,另外一个更为进阶且通用的做法是使用「线段树」。
线段树的所有操作同样为 O(logn),O(\log{n}),O(logn),由于本题不涉及「区间修改」操作,因此我们的线段树只需要实现 pushup
操作(子节点往上更新父节点),而不需要实现用于懒标记的 pushdown
操作(父节点往下传递「更新」的操作)。
关于线段树设计的几种操作:
void build(int u, int l, int r)
:含义为从编号为 uuu 的节点开始,构造范围为 [l,r][l,r][l,r] 的树节点;void update(int u, int x, int v)
:含义为从编号为 uuu 的节点开始,在 xxx 位置增加 vvv;- 更具一般性(涉及区间修改)的操作应该为
void update(int u, int l, int r, int v)
,代表在 [l,r][l, r][l,r] 范围增加 vvv;
- 更具一般性(涉及区间修改)的操作应该为
int query(int u, int l, int r)
:含义为从编号为 uuu 的节点开始,查询 [l,r][l, r][l,r] 区间和为多少。
注意:对于编号为
u
的节点而言,其左子节点的编号为u << 1
,其右节点的编号为u << 1 | 1
。
代码(考虑为线段树增加 static
优化的代码见 P2P2P2,样例个数较少,优化不明显):
class NumArray {
Node[] tr;
class Node {
int l, r, v;
Node(int _l, int _r) {
l = _l; r = _r;
}
}
void build(int u, int l, int r) {
tr[u] = new Node(l, r);
if (l == r) return;
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
}
void update(int u, int x, int v) {
if (tr[u].l == x && tr[u].r == x) {
tr[u].v += v;
return ;
}
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) update(u << 1, x, v);
else update(u << 1 | 1, x, v);
pushup(u);
}
int query(int u, int l, int r) {
if (l <= tr[u].l && tr[u].r <= r) return tr[u].v;
int mid = tr[u].l + tr[u].r >> 1;
int ans = 0;
if (l <= mid) ans += query(u << 1, l, r);
if (r > mid) ans += query(u << 1 | 1, l, r);
return ans;
}
void pushup(int u) {
tr[u].v = tr[u << 1].v + tr[u << 1 | 1].v;
}
int[] nums;
public NumArray(int[] _nums) {
nums = _nums;
int n = nums.length;
tr = new Node[n * 4];
build(1, 1, n);
for (int i = 0; i < n; i++) update(1, i + 1, nums[i]);
}
public void update(int index, int val) {
update(1, index + 1, val - nums[index]);
nums[index] = val;
}
public int sumRange(int left, int right) {
return query(1, left + 1, right + 1);
}
}
- 时间复杂度:插入和查询复杂度均为 O(logn)O(\log{n})O(logn)
- 空间复杂度:O(n)O(n)O(n)
327. 区间和的个数
给你一个整数数组 nums
以及两个整数 lower
和 upper
。求数组中,值位于范围 [lower,upper][lower, upper][lower,upper] (包含 lower
和 upper
)之内的 区间和的个数 。
区间和 S(i,j)S(i, j)S(i,j) 表示在 nums
中,位置从 iii 到 jjj 的元素之和,包含 iii 和 jjj (i ≤ j
)。
示例 1:
输入:nums = [-2,5,-1], lower = -2, upper = 2
输出:3
解释:存在三个区间:[0,0]、[2,2] 和 [0,2] ,对应的区间和分别是:-2 、-1 、2 。
示例 2:
输入:nums = [0], lower = 0, upper = 0
输出:1
提示:
- 1<=nums.length<=1051 <= nums.length <= 10^51<=nums.length<=105
- −231<=nums[i]<=231−1-2^{31} <= nums[i] <= 2^{31} - 1−231<=nums[i]<=231−1
- −105<=lower<=upper<=105-10^5 <= lower <= upper <= 10^5−105<=lower<=upper<=105
- 题目数据保证答案是一个 323232 位 的整数
树状数组(离散化)
由于区间和的定义是子数组的元素和,容易想到「前缀和」来快速求解。
对于每个 nums[i]nums[i]nums[i] 而言,我们需要统计以每个 nums[i]nums[i]nums[i] 为右端点的合法子数组个数(合法子数组是指区间和值范围为 [lower,upper][lower, upper][lower,upper] 的子数组)。
我们可以从前往后处理 numsnumsnums,假设当前我们处理到位置 kkk,同时下标 [0,k][0, k][0,k] 的前缀和为 sss,那么以 nums[k]nums[k]nums[k] 为右端点的合法子数组个数,等价于在下标 [0,k−1][0, k - 1][0,k−1] 中前缀和范围在 [s−upper,s−lower][s - upper, s - lower][s−upper,s−lower] 的数的个数。
我们需要使用一个数据结构来维护「遍历过程中的前缀和」,每遍历 nums[i]nums[i]nums[i] 需要往数据结构加一个数,同时每次需要查询值在某个范围内的数的个数。涉及的操作包括「单点修改」和「区间查询」,容易想到使用树状数组进行求解。
但值域的范围是巨大的(同时还有负数域),我们可以利用 numsnumsnums 的长度为 10510^5105 来做离散化。我们需要考虑用到的数组都有哪些:
- 首先前缀和数组中的每一位 sss 都需要被用到(添加到树状数组中);
- 同时对于每一位 nums[i]nums[i]nums[i](假设对应的前缀和为 sss),我们都需要查询以其为右端点的合法子数组个数,即查询前缀和范围在 [s−upper,s−lower][s - upper, s - lower][s−upper,s−lower] 的数的个数。
因此对于前缀和数组中的每一位 sss,我们用到的数有 sss、s−uppers - uppers−upper 和 s−lowers - lowers−lower 三个数字,共有 1e51e51e5 个 sss,即最多共有 3×1053 \times 10^53×105 个不同数字被使用,我们可以对所有用到的数组进行排序编号(离散化),从而将值域大小控制在 3×1053 \times 10^53×105 范围内。
代码:
class Solution {
int m;
int[] tr = new int[100010 * 3];
int lowbit(int x) {
return x & -x;
}
void add(int x, int v) {
for (int i = x; i <= m; i += lowbit(i)) tr[i] += v;
}
int query(int x) {
int ans = 0;
for (int i = x; i > 0; i -= lowbit(i)) ans += tr[i];
return ans;
}
public int countRangeSum(int[] nums, int lower, int upper) {
Set<Long> set = new HashSet<>();
long s = 0;
set.add(s);
for (int i : nums) {
s += i;
set.add(s);
set.add(s - lower);
set.add(s - upper);
}
List<Long> list = new ArrayList<>(set);
Collections.sort(list);
Map<Long, Integer> map = new HashMap<>();
for (long x : list) map.put(x, ++m);
s = 0;
int ans = 0;
add(map.get(s), 1);
for (int i : nums) {
s += i;
int a = map.get(s - lower), b = map.get(s - upper) - 1;
ans += query(a) - query(b);
add(map.get(s), 1);
}
return ans;
}
}
- 时间复杂度:去重离散化的复杂度为 O(nlogn)O(n\log{n})O(nlogn);统计答案的复杂度为 O(nlogn)O(n\log{n})O(nlogn)
- 空间复杂度:O(n)O(n)O(n)
215. 数组中的第K个最大元素
给定整数数组 nums
和整数 k
,请返回数组中第 k
个最大的元素。
请注意,你需要找的是数组排序后的第 k
个最大的元素,而不是第 k
个不同的元素。
你必须设计并实现时间复杂度为 O(n)O(n)O(n) 的算法解决此问题。
示例 1:
输入: [3,2,1,5,6,4], k = 2
输出: 5
示例 2:
输入: [3,2,3,1,2,4,5,5,6], k = 4
输出: 4
提示:
- 1<=k<=nums.length<=1051 <= k <= nums.length <= 10^51<=k<=nums.length<=105
- −104 <=nums[i]<=104-10^4 <= nums[i] <= 10^4−104 <=nums[i]<=104
值域映射 + 树状数组 + 二分
除了直接对数组进行排序,取第 kkk 位的 O(nlogn)O(n\log{n})O(nlogn) 做法以外。
对于值域大小 小于 数组长度本身时,我们还能使用「树状数组 + 二分」的 O(nlogm)O(n\log{m})O(nlogm) 做法,其中 mmm 为值域大小。
首先值域大小为 [−104,104][-10^4, 10^4][−104,104],为了方便,我们为每个 nums[i]nums[i]nums[i] 增加大小为 1e4+101e4 + 101e4+10 的偏移量,将值域映射到 [10,2×104+10][10, 2 \times 10^4 + 10][10,2×104+10] 的空间。
将每个增加偏移量后的 nums[i]nums[i]nums[i] 存入树状数组,考虑在 [0,m)[0, m)[0,m) 范围内进行二分,假设我们真实第 kkk 大的值为 ttt,那么在以 ttt 为分割点的数轴上,具有二段性质:
- 在 [0,t][0, t][0,t] 范围内的数 curcurcur 满足「树状数组中大于等于 curcurcur 的数不低于 kkk 个」
- 在 (t,m)(t, m)(t,m) 范围内的数 curcurcur 不满足「树状数组中大于等于 curcurcur 的数不低于 kkk 个」
二分出结果后再减去刚开始添加的偏移量即是答案。
代码:
class Solution {
int M = 10010, N = 2 * M;
int[] tr = new int[N];
int lowbit(int x) {
return x & -x;
}
int query(int x) {
int ans = 0;
for (int i = x; i > 0; i -= lowbit(i)) ans += tr[i];
return ans;
}
void add(int x) {
for (int i = x; i < N; i += lowbit(i)) tr[i]++;
}
public int findKthLargest(int[] nums, int k) {
for (int x : nums) add(x + M);
int l = 0, r = N - 1;
while (l < r) {
int mid = l + r + 1 >> 1;
if (query(N - 1) - query(mid - 1) >= k) l = mid;
else r = mid - 1;
}
return r - M;
}
}
- 时间复杂度:将所有数字放入树状数组复杂度为 O(nlogm)O(n\log{m})O(nlogm);二分出答案复杂度为 O(log2m)O(\log^2{m})O(log2m),其中 m=2×104m = 2 \times 10^4m=2×104 为值域大小。整体复杂度为 O(nlogm)O(n\log{m})O(nlogm)
- 空间复杂度:O(m)O(m)O(m)
优先队列(堆)
另外一个容易想到的想法是利用优先队列(堆),由于题目要我们求的是第 kkk 大的元素,因此我们建立一个小根堆。
根据当前队列元素个数或当前元素与栈顶元素的大小关系进行分情况讨论:
- 当优先队列元素不足 kkk 个,可将当前元素直接放入队列中;
- 当优先队列元素达到 kkk 个,并且当前元素大于栈顶元素(栈顶元素必然不是答案),可将当前元素放入队列中。
代码:
class Solution {
public int findKthLargest(int[] nums, int k) {
PriorityQueue<Integer> q = new PriorityQueue<>((a,b)->a-b);
for (int x : nums) {
if (q.size() < k || q.peek() < x) q.add(x);
if (q.size() > k) q.poll();
}
return q.peek();
}
}
- 时间复杂度:O(nlogk)O(n\log{k})O(nlogk)
- 空间复杂度:O(k)O(k)O(k)
快速选择
对于给定数组,求解第 kkk 大元素,且要求线性复杂度,正解为使用「快速选择」做法。
基本思路与「快速排序」一致,每次敲定一个基准值 x
,根据当前与 x
的大小关系,将范围在 [l,r][l, r][l,r] 的 nums[i]nums[i]nums[i] 划分为到两边。
同时利用,利用题目只要求输出第 kkk 大的值,而不需要对数组进行整体排序,我们只需要根据划分两边后,第 kkk 大数会落在哪一边,来决定对哪边进行递归处理即可。
快速排序模板为面试向重点内容,需要重要掌握。
代码:
class Solution {
int[] nums;
int qselect(int l, int r, int k) {
if (l == r) return nums[k];
int x = nums[l], i = l - 1, j = r + 1;
while (i < j) {
do i++; while (nums[i] < x);
do j--; while (nums[j] > x);
if (i < j) swap(i, j);
}
if (k <= j) return qselect(l, j, k);
else return qselect(j + 1, r, k);
}
void swap(int i, int j) {
int c = nums[i];
nums[i] = nums[j];
nums[j] = c;
}
public int findKthLargest(int[] _nums, int k) {
nums = _nums;
int n = nums.length;
return qselect(0, n - 1, n - k);
}
}
- 时间复杂度:期望 O(n)O(n)O(n)
- 空间复杂度:忽略递归带来的额外空间开销,复杂度为 O(1)O(1)O(1)
总结
相比于线段树而言,树状数组属于短小精悍类的数据结构,属于区间求和问题中的利器。
与线段树一样,数组数组同样支持 log\loglog 级别复杂度操作,同时有着更低的常数。