浅谈二维树状数组

$\mathrm{Main \ Idea}$

本文主要来讲解二维树状数组的一些用法以及例题讲解。

$\mathrm{Problem’s \ Website}$

$\mathrm{T1}$:二维树状数组1:单点修改,区间求和($\mathbb{LOJ}$)

$\mathrm{T2}$:二维树状数组3:区间修改,区间求和($\mathbb{LOJ}$)

$\mathrm{T3}$:上帝造题的七分钟(洛谷)

$\mathrm{Solution}$

我们按照操作的类型来说,必要时再解释例题。

1.基本概念

首先对于普通的一维树状数组,我之前有过粗略的介绍,不懂的同学可以先去学习一下。

一维树状数组区间修改单点查询

一维树状数组区间修改区间查询

那我们来说一下二维树状数组的定义,我们定义$ta[i][j]$为以$(i,j)$为右下角,且高度为$lowbit(i)$,宽度为$lowbit(j)$的矩阵的和。

感觉还是比较好理解的。。。

2.单点修改区间查询

这是二维树状数组最基础的操作。

单点修改类似于一维树状数组,只不过是双重循环而已,而查询类似于二维前缀和,例如,输出左上角为$(x1,y1)$,右下角为$(x2,y2)$矩阵的和,公式为$ans = ask(x2,y2) - ask(x2,y1-1) - ask(x1-1,y2) + ask(x1-1,y1-1)$

不懂的同学可以画个图模拟一下。

这也就是$\mathrm{T1}$的题解。

代码会放在最后面。

3.区间修改单点查询

还是类似于一维树状数组此种操作,差分,仿照着二维前缀和,我们现在令$ta[i][j]$为$a[i][j]$与$a[i-1][j] + a[i][j-1] - a[i-1][j-1]$的差,求值时,我们直接求前缀和即可,那么如何修改?

例如,我们要修改左上角为$(x1,y1)$,右下角为$(x2,y2)$的矩阵统一加上$num$,我们要在$(x1,y1)$加上$num$,因为它比$(x1-1,y1)\ (x1,y1-1)\ (x1-1,y1-1)$多了$num$,要在$(x2+1,y1)$减去$num$,因为它之前多加了一个$num$,同理,要在$(x1,y2+1)$减去$num$,最后要在$(x2+1,y2+1)$加上$num$,因为它多减了一个$num$,如果我这样描述还不懂,建议画图理解。

本种操作暂未例题。。。

代码会放在后面。

4.区间修改区间查询

这就是最终的操作了,根据上文,我们不难推出式子$ans=\sum\limits_{i=1}^{x}{\sum\limits_{j=1}^{y}{\sum\limits_{k=1}^{i}{\sum\limits_{l=1}^{j}{ta[k][l]}}}}$,怎么那么长啊。。。

我们来缩短一下,经过模拟发现,$ta[1][1]$被计算了$x \times y$次,$ta[1][2]$被计算了$x \times (y-1)$次,$ta[2][1]$被计算了$(x-1) \times y$次,所以我们可以推出$ta[i][j]$计算了$(x-i+1)\times(y-j+1)$次,于是式子就被化简成了$ans=\sum\limits_{i=1}^{x}{\sum\limits_{j=1}^{y}{ta[i][j]\times(x-i+1)\times(y-j+1)}}$,我们再把这个式子展开,得到$ans=\sum\limits_{i=1}^{x}{\sum\limits_{j=1}^{y}{ta[i][j]\times(x+1)\times(y+1)}}-\sum\limits_{i=1}^{x}{\sum\limits_{j=1}^{y}{ta[i][j]\times i \times(y+1)}}-\sum\limits_{i=1}^{x}{\sum\limits_{j=1}^{y}{ta[i][j]\times j \times (x+1)}} + \sum\limits_{i=1}^{x}{\sum\limits_{j=1}^{y}{ta[i][j] \times i \times j}}$

那么我们只需要开四个树状数组,即:$ta[i][j],ta[i][j]\times i,ta[i][j] \times j,ta[i][j] \times i \times j$

以上就是$\mathrm{T2}$和$\mathrm{T3}$的题解。

下面就是你们最期待的代码。

$\mathrm{Code}$

  • $\mathrm{T1}$
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
//Coded by Dy.
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define re register
#define lowbit(x) (x & (-x))
typedef long long ll;
const int Maxn = 5010;
int n, m;
ll ta[Maxn][Maxn];
inline void update(int x, int y, ll z) {
for(re int i = x; i <= n; i += lowbit(i))
for(re int j = y; j <= m; j += lowbit(j))
ta[i][j] += z;
}
inline ll ask(int x, int y) {
ll res = 0;
for(re int i = x; i; i -= lowbit(i))
for(re int j = y; j; j -=lowbit(j))
res += ta[i][j];
return res;
}
inline ll query(int x1, int y1, int x2, int y2) {
return ask(x2, y2) + ask(x1 - 1, y1 - 1) - ask(x2, y1 - 1) - ask(x1 - 1, y2);
}
int main() {
scanf("%d%d", &n, &m);
int opt;
while(scanf("%d",&opt) != EOF) {
if(opt == 1) {
int x, y;
ll k;
scanf("%d%d%lld", &x, &y, &k);
update(x, y, k);
}
else {
int a, b, c, d;
scanf("%d%d%d%d", &a, &b, &c, &d);
printf("%lld\n", query(a, b, c, d));
}
}
return 0;
}
  • 区间修改单点查询
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
//Coded by Dy.
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define gc getchar()
#define pc(x) putchar(x)
inline int sc() {
int xx = 0, ff = 1; char cch = gc;
while(!isdigit(cch)) {
if(cch == '-') ff = -1; cch = gc;
}
while(isdigit(cch)) {
xx = (xx << 1) + (xx << 3) + (cch ^ '0'); cch = gc;
}
return xx * ff;
}
inline void out(int x) {
if(x < 0)
pc('-'), x = -x;
if(x >= 10)
out(x / 10);
pc(x % 10 + '0');
}
#define re register
const int Maxn = 2050;
int n, m, k;
int ta[Maxn][Maxn];
#define lowbit(x) (x & (-x))
inline void update(int x, int y, int num) {
for(re int i = x; i <= n; i += lowbit(i))
for(re int j = y; j <= m; j += lowbit(j)) {
ta[i][j] += num;
}
}
inline void add(int x1, int y1, int x2 ,int y2, int num) {
update(x1, y1, num);
update(x2 + 1, y2 + 1, num);
update(x2 + 1, y1, -num);
update(x1, y2 + 1, -num);
}
inline int query(int x, int y) {
int res = 0;
for(re int i = x; i; i -= lowbit(i))
for(re int j = y; j; j -= lowbit(j)) {
res += ta[i][j];
}
return res;
}
int main() {
n = sc(), m = sc(), k = sc();
for(re int i = 1; i <= n; ++i)
for(re int j = 1; j <= m; ++j) {
int x = sc();
add(i, j, i, j, x); // 初始矩阵
}
while(k--) {
int opt = sc();
if(opt == 1) { // 区间修改
int x1 = sc(), y1 = sc(), x2 = sc(), y2 = sc(), num = sc();
add(x1, y1, x2, y2, num);
}
else { // 单点查询
int x = sc(), y = sc();
out(query(x, y)), pc('\n');
}
}
return 0;
}
  • $\mathrm{T2}$
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
//Coded by Dy.
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define gc getchar()
#define pc(x) putchar(x)
typedef long long ll;
inline int sc() {
int xx = 0, ff = 1; char cch = gc;
while(!isdigit(cch)) {
if(cch == '-') ff = -1; cch = gc;
}
while(isdigit(cch)) {
xx = (xx << 1) + (xx << 3) + (cch ^ '0'); cch = gc;
}
return xx * ff;
}
inline void out(ll x) {
if(x < 0)
pc('-'), x = -x;
if(x >= 10)
out(x / 10);
pc(x % 10 + '0');
}
#define re register
const int Maxn = 2050;
int n, m;
ll ta1[Maxn][Maxn], ta2[Maxn][Maxn], ta3[Maxn][Maxn], ta4[Maxn][Maxn];
#define lowbit(x) (x & (-x))
inline void update(int x, int y, ll num) {
for(re int i = x; i <= n; i += lowbit(i))
for(re int j = y; j <= m; j += lowbit(j)) {
ta1[i][j] += num;
ta2[i][j] += num * x;
ta3[i][j] += num * y;
ta4[i][j] += num * x * y;
}
}
inline void add(int x1, int y1, int x2, int y2, ll num) {
update(x1, y1, num);
update(x2 + 1, y2 + 1, num);
update(x2 + 1, y1, -num);
update(x1, y2 + 1, -num);
}
inline ll ask(int x, int y) {
ll res = 0LL;
for(re int i = x; i; i -= lowbit(i))
for(re int j = y; j; j -= lowbit(j)) {
res += ta1[i][j] * (x + 1) * (y + 1) - ta2[i][j] * (y + 1) - ta3[i][j] * (x + 1) + ta4[i][j];
}
return res;
}
inline ll query(int x1, int y1, int x2, int y2) {
return ask(x2, y2) - ask(x2, y1 - 1) - ask(x1 - 1, y2) + ask(x1 - 1, y1 - 1);
}
int main() {
n = sc(), m = sc();
int opt;
while(scanf("%d", &opt) != EOF) {
if(opt == 1) {
int x1 = sc(), y1 = sc(), x2 = sc(), y2 = sc();
ll z;
scanf("%lld", &z);
add(x1, y1, x2, y2, z);
// out(query(1, 1, x2, y2)), pc('\n');
}
else {
int x1 = sc(), y1 = sc(), x2 = sc(), y2 = sc();
out(query(x1, y1, x2, y2)), pc('\n');
}
}
return 0;
}
  • $\mathrm{T3}$
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
//Coded by Dy.
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define gc getchar()
#define pc(x) putchar(x)
typedef long long ll;
#define re register
const int Maxn = 2050;
int n, m;
int ta1[Maxn][Maxn], ta2[Maxn][Maxn], ta3[Maxn][Maxn], ta4[Maxn][Maxn];
#define lowbit(x) (x & (-x))
inline void update(int x, int y, int num) {
for(re int i = x; i <= n; i += lowbit(i))
for(re int j = y; j <= m; j += lowbit(j)) {
ta1[i][j] += num;
ta2[i][j] += num * x;
ta3[i][j] += num * y;
ta4[i][j] += num * x * y;
}
}
inline void add(int x1, int y1, int x2, int y2, int num) {
update(x1, y1, num);
update(x2 + 1, y2 + 1, num);
update(x2 + 1, y1, -num);
update(x1, y2 + 1, -num);
}
inline int ask(int x, int y) {
int res = 0LL;
for(re int i = x; i; i -= lowbit(i))
for(re int j = y; j; j -= lowbit(j)) {
res += ta1[i][j] * (x + 1) * (y + 1) - ta2[i][j] * (y + 1) - ta3[i][j] * (x + 1) + ta4[i][j];
}
return res;
}
inline int query(int x1, int y1, int x2, int y2) {
return ask(x2, y2) - ask(x2, y1 - 1) - ask(x1 - 1, y2) + ask(x1 - 1, y1 - 1);
}
int main() {
scanf("X %d%d", &n, &m);
char c[2];
while(scanf("%s", c) == 1) {
if(c[0] == 'L') {
int x1, x2, y1, y2, z;
scanf("%d%d%d%d%d", &x1, &y1, &x2, &y2, &z);
add(x1, y1, x2, y2, z);
}
else {
int x1, y1, x2, y2;
scanf("%d%d%d%d", &x1, &y1, &x2, &y2);
printf("%d\n", query(x1, y1, x2, y2));
}
}
return 0;
}

$\mathrm{rp++}$