likes
comments
collection
share

【综合笔试题】难度 2.5/5 :「树状数组」与「双树状数组优化」

作者站长头像
站长
· 阅读数 21

题目描述

这是 LeetCode 上的 1395. 统计作战单位数 ,难度为 中等

Tag : 「树状数组」、「容斥原理」

 n 名士兵站成一排。每个士兵都有一个 独一无二 的评分 rating

333 个士兵可以组成一个作战单位,分组规则如下:

  • 从队伍中选出下标分别为 ijk333 名士兵,他们的评分分别为 rating[i]rating[i]rating[i]rating[j]rating[j]rating[j]rating[k]rating[k]rating[k]

  • 作战单位需满足: rating[i]<rating[j]<rating[k]rating[i] < rating[j] < rating[k]rating[i]<rating[j]<rating[k] 或者 rating[i]>rating[j]>rating[k]rating[i] > rating[j] > rating[k]rating[i]>rating[j]>rating[k] ,其中  0 <=i< j< k< n0 <= i < j < k < n<=i< j< k< n

请你返回按上述条件可以组建的作战单位数量。每个士兵都可以是多个作战单位的一部分。

示例 1:

输入:rating = [2,5,3,4,1]

输出:3

解释:我们可以组建三个作战单位 (2,3,4)、(5,4,1)、(5,3,1) 。

示例 2:

输入:rating = [2,1,3]

输出:0

解释:根据题目条件,我们无法组建作战单位。

示例 3:

输入:rating = [1,2,3,4]

输出:4

提示:

  • n==rating.lengthn == rating.lengthn==rating.length
  • 3<=n<=10003 <= n <= 10003<=n<=1000
  • 1<=rating[i]<=1051 <= rating[i] <= 10^51<=rating[i]<=105
  • rating 中的元素都是唯一的

基本分析

为了方便,我们记 ratingrs

题目本质是要我们统计所有满足「递增」或「递减」的三元组。换句话说,对于每个 t=rs[i]t = rs[i]t=rs[i] 而言,我们需要统计比其 ttt 大或比 ttt 小的数的个数。

问题涉及「单点修改(更新数值 ttt 的出现次数)」以及「区间查询(查询某段范围内数的个数)」,使用「树状数组」求解较为合适。

树状数组 - 枚举两端

一个朴素的想法是,对于三元组 (i,j,k)(i, j, k)(i,j,k),我们枚举其两端 iiikkk,根据 rs[i]rs[i]rs[i]rs[k]rs[k]rs[k] 的大小关系,查询范围 [i+1,k−1][i + 1, k - 1][i+1,k1] 之间合法的数的个数。

在确定左端点 iii 时,我们从 i+1i + 1i+1 开始「从小到大」枚举右端点 kkk,并将遍历过程中经过的 rs[k]rs[k]rs[k] 添加到树状数组进行计数。

处理过程中根据 a=rs[i]a = rs[i]a=rs[i]b=rs[k]b = rs[k]b=rs[k] 的大小关系进行分情况讨论:

  • a<ba < ba<b 时,我们需要在范围 [i+1,k−1][i + 1, k - 1][i+1,k1] 中找「大于 aaa」同时「小于 bbb」的数的个数,即 query(b - 1) - query(a)
  • a>ba > ba>b 时,我们需要在范围 [i+1,k−1][i + 1, k - 1][i+1,k1] 中找「小于 aaa」同时「大于 bbb」的数的个数,即 query(a - 1) - query(b)

一些细节:显然我们需要在枚举每个左端点 iii 时清空树状数组,但注意不能使用诸如 Arrays.fill(tr, 0) 的方式进行清空。

因为在没有离散化的情况下,树状数组的大小为 m=1e5m = 1e5m=1e5,即执行 Arrays.fill 操作的复杂度为 O(m)O(m)O(m),这会导致我们计算量为至少为 n×m=1e8n \times m = 1e8n×m=1e8,会有 TLE 风险。

因此一个合适做法是:在 [i+1,n−1][i + 1, n - 1][i+1,n1] 范围内枚举完 kkk 后(进行的是 +1 计数),再枚举一次 [i+1,n−1][i + 1, n - 1][i+1,n1] 进行一次 -1 的计数进行抵消。

代码:

class Solution {
    static int N = (int)1e5 + 10;
    static int[] tr = new int[N];
    int lowbit(int x) {
        return x & -x;
    }
    void update(int x, int v) {
        for (int i = x; i < N; i += lowbit(i)) tr[i] += v;
    }
    int query(int x) {
        int ans = 0;
        for (int i = x; i > 0; i -= lowbit(i)) ans += tr[i];
        return ans;
    }
    public int numTeams(int[] rs) {
        int n = rs.length, ans = 0;
        for (int i = 0; i < n; i++) {
            int a = rs[i];
            for (int j = i + 1; j < n; j++) {
                int b = rs[j];
                if (a < b) ans += query(b - 1) - query(a);
                else ans += query(a - 1) - query(b);
                update(b, 1);
            }
            for (int j = i + 1; j < n; j++) update(rs[j], -1);
        }
        return ans;
    }
}
  • 时间复杂度:令 m=1e5m = 1e5m=1e5 为值域大小,整体复杂度为 O(n2log⁡m)O(n^2\log{m})O(n2logm)
  • 空间复杂度:O(m)O(m)O(m)

双树状数组优化 - 枚举中点

我们考虑将 nnn 的数据范围提升到 1e41e41e4 该如何做。

上述解法的瓶颈在于我们枚举三元组中的左右端点,复杂度为 O(n2)O(n^2)O(n2),而实际上利用三元组必然递增或递减的特性,我们可以调整为枚举终点 jjj,从而将「枚举点对」调整为「枚举中点」,复杂度为 O(n)O(n)O(n)

假设当前枚举到的点为 rs[i]rs[i]rs[i],问题转换为在 [0,i−1][0, i - 1][0,i1] 有多少比 rs[i]rs[i]rs[i] 小/大 的数,在 [i+1,n−1][i + 1, n - 1][i+1,n1] 有多少比 rs[i]rs[i]rs[i] 大/小 的数,然后集合「乘法」原理即可知道 rs[i]rs[i]rs[i] 作为三元组中点的合法方案数。

统计 rs[i]rs[i]rs[i] 左边的比 rs[i]rs[i]rs[i] 大/小 的数很好做,只需要在「从小到大」枚举 iii 的过程中,将 rs[i]rs[i]rs[i] 添加到树状数组 tr1 即可。

对于统计 rs[i]rs[i]rs[i] 右边比 rs[i]rs[i]rs[i] 小/大 的数,则需要通过「抵消计数」来做,起始我们先将所有 rs[idx]rs[idx]rs[idx] 加入到另外一个树状数组 tr2 中(进行 +1 计数),然后在从前往后处理每个 rs[i]rs[i]rs[i] 的时候,在 tr2 中进行 -1 抵消,从而确保我们处理每个 rs[i]rs[i]rs[i] 时,tr1 存储左边的数,tr2 存储右边的数。

代码:

class Solution {
    static int N = (int)1e5 + 10;
    static int[] tr1 = new int[N], tr2 = new int[N];
    int lowbit(int x) {
        return x & -x;
    }
    void update(int[] tr, int x, int v) {
        for (int i = x; i < N; i += lowbit(i)) tr[i] += v;
    }
    int query(int[] tr, int x) {
        int ans = 0;
        for (int i = x; i > 0; i -= lowbit(i)) ans += tr[i];
        return ans;
    }
    public int numTeams(int[] rs) {
        int n = rs.length, ans = 0;
        Arrays.fill(tr1, 0);
        Arrays.fill(tr2, 0);
        for (int i : rs) update(tr2, i, 1);
        for (int i = 0; i < n; i++) {
            int t = rs[i];
            update(tr2, t, -1);
            ans += query(tr1, t - 1) * (query(tr2, N - 1) - query(tr2, t));
            ans += (query(tr1, N - 1) - query(tr1, t)) * query(tr2, t - 1);
            update(tr1, t, 1);
        }
        return ans;
    }
}
  • 时间复杂度:令 m=1e5m = 1e5m=1e5 为值域大小,整体复杂度为 O(nlog⁡m)O(n\log{m})O(nlogm)
  • 空间复杂度:O(m)O(m)O(m)

最后

这是我们「刷穿 LeetCode」系列文章的第 No.1395 篇,系列开始于 2021/01/01,截止于起始日 LeetCode 上共有 1916 道题目,部分是有锁题,我们将先把所有不带锁的题目刷完。

在这个系列文章里面,除了讲解解题思路以外,还会尽可能给出最为简洁的代码。如果涉及通解还会相应的代码模板。

为了方便各位同学能够电脑上进行调试和提交代码,我建立了相关的仓库:github.com/SharingSour…

在仓库地址里,你可以看到系列文章的题解链接、系列文章的相应代码、LeetCode 原题链接和其他优选题解。