浅谈二维树状数组

Main Idea

本文主要来讲解二维树状数组的一些用法以及例题讲解。

Problems Website

T1:二维树状数组1:单点修改,区间求和(LOJ

T2:二维树状数组3:区间修改,区间求和(LOJ

T3:上帝造题的七分钟(洛谷)

Solution

我们按照操作的类型来说,必要时再解释例题。

1.基本概念

首先对于普通的一维树状数组,我之前有过粗略的介绍,不懂的同学可以先去学习一下。

一维树状数组区间修改单点查询

一维树状数组区间修改区间查询

那我们来说一下二维树状数组的定义,我们定义ta[i][j]为以(i,j)为右下角,且高度为lowbit(i),宽度为lowbit(j)的矩阵的和。

感觉还是比较好理解的。。。

2.单点修改区间查询

这是二维树状数组最基础的操作。

单点修改类似于一维树状数组,只不过是双重循环而已,而查询类似于二维前缀和,例如,输出左上角为(x1,y1),右下角为(x2,y2)矩阵的和,公式为ans=ask(x2,y2)ask(x2,y11)ask(x11,y2)+ask(x11,y11)

不懂的同学可以画个图模拟一下。

这也就是T1的题解。

代码会放在最后面。

3.区间修改单点查询

还是类似于一维树状数组此种操作,差分,仿照着二维前缀和,我们现在令ta[i][j]a[i][j]a[i1][j]+a[i][j1]a[i1][j1]的差,求值时,我们直接求前缀和即可,那么如何修改?

例如,我们要修改左上角为(x1,y1),右下角为(x2,y2)的矩阵统一加上num,我们要在(x1,y1)加上num,因为它比(x11,y1) (x1,y11) (x11,y11)多了num,要在(x2+1,y1)减去num,因为它之前多加了一个num,同理,要在(x1,y2+1)减去num,最后要在(x2+1,y2+1)加上num,因为它多减了一个num,如果我这样描述还不懂,建议画图理解。

本种操作暂未例题。。。

代码会放在后面。

4.区间修改区间查询

这就是最终的操作了,根据上文,我们不难推出式子ans=xi=1yj=1ik=1jl=1ta[k][l],怎么那么长啊。。。

我们来缩短一下,经过模拟发现,ta[1][1]被计算了x×y次,ta[1][2]被计算了x×(y1)次,ta[2][1]被计算了(x1)×y次,所以我们可以推出ta[i][j]计算了(xi+1)×(yj+1)次,于是式子就被化简成了ans=xi=1yj=1ta[i][j]×(xi+1)×(yj+1),我们再把这个式子展开,得到ans=xi=1yj=1ta[i][j]×(x+1)×(y+1)xi=1yj=1ta[i][j]×i×(y+1)xi=1yj=1ta[i][j]×j×(x+1)+xi=1yj=1ta[i][j]×i×j

那么我们只需要开四个树状数组,即:ta[i][j],ta[i][j]×i,ta[i][j]×j,ta[i][j]×i×j

以上就是T2T3的题解。

下面就是你们最期待的代码。

Code

  • T1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
//Coded by Dy.
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define re register
#define lowbit(x) (x & (-x))
typedef long long ll;
const int Maxn = 5010;
int n, m;
ll ta[Maxn][Maxn];
inline void update(int x, int y, ll z) {
for(re int i = x; i <= n; i += lowbit(i))
for(re int j = y; j <= m; j += lowbit(j))
ta[i][j] += z;
}
inline ll ask(int x, int y) {
ll res = 0;
for(re int i = x; i; i -= lowbit(i))
for(re int j = y; j; j -=lowbit(j))
res += ta[i][j];
return res;
}
inline ll query(int x1, int y1, int x2, int y2) {
return ask(x2, y2) + ask(x1 - 1, y1 - 1) - ask(x2, y1 - 1) - ask(x1 - 1, y2);
}
int main() {
scanf("%d%d", &n, &m);
int opt;
while(scanf("%d",&opt) != EOF) {
if(opt == 1) {
int x, y;
ll k;
scanf("%d%d%lld", &x, &y, &k);
update(x, y, k);
}
else {
int a, b, c, d;
scanf("%d%d%d%d", &a, &b, &c, &d);
printf("%lld\n", query(a, b, c, d));
}
}
return 0;
}
  • 区间修改单点查询
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
//Coded by Dy.
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define gc getchar()
#define pc(x) putchar(x)
inline int sc() {
int xx = 0, ff = 1; char cch = gc;
while(!isdigit(cch)) {
if(cch == '-') ff = -1; cch = gc;
}
while(isdigit(cch)) {
xx = (xx << 1) + (xx << 3) + (cch ^ '0'); cch = gc;
}
return xx * ff;
}
inline void out(int x) {
if(x < 0)
pc('-'), x = -x;
if(x >= 10)
out(x / 10);
pc(x % 10 + '0');
}
#define re register
const int Maxn = 2050;
int n, m, k;
int ta[Maxn][Maxn];
#define lowbit(x) (x & (-x))
inline void update(int x, int y, int num) {
for(re int i = x; i <= n; i += lowbit(i))
for(re int j = y; j <= m; j += lowbit(j)) {
ta[i][j] += num;
}
}
inline void add(int x1, int y1, int x2 ,int y2, int num) {
update(x1, y1, num);
update(x2 + 1, y2 + 1, num);
update(x2 + 1, y1, -num);
update(x1, y2 + 1, -num);
}
inline int query(int x, int y) {
int res = 0;
for(re int i = x; i; i -= lowbit(i))
for(re int j = y; j; j -= lowbit(j)) {
res += ta[i][j];
}
return res;
}
int main() {
n = sc(), m = sc(), k = sc();
for(re int i = 1; i <= n; ++i)
for(re int j = 1; j <= m; ++j) {
int x = sc();
add(i, j, i, j, x); // 初始矩阵
}
while(k--) {
int opt = sc();
if(opt == 1) { // 区间修改
int x1 = sc(), y1 = sc(), x2 = sc(), y2 = sc(), num = sc();
add(x1, y1, x2, y2, num);
}
else { // 单点查询
int x = sc(), y = sc();
out(query(x, y)), pc('\n');
}
}
return 0;
}
  • T2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
//Coded by Dy.
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define gc getchar()
#define pc(x) putchar(x)
typedef long long ll;
inline int sc() {
int xx = 0, ff = 1; char cch = gc;
while(!isdigit(cch)) {
if(cch == '-') ff = -1; cch = gc;
}
while(isdigit(cch)) {
xx = (xx << 1) + (xx << 3) + (cch ^ '0'); cch = gc;
}
return xx * ff;
}
inline void out(ll x) {
if(x < 0)
pc('-'), x = -x;
if(x >= 10)
out(x / 10);
pc(x % 10 + '0');
}
#define re register
const int Maxn = 2050;
int n, m;
ll ta1[Maxn][Maxn], ta2[Maxn][Maxn], ta3[Maxn][Maxn], ta4[Maxn][Maxn];
#define lowbit(x) (x & (-x))
inline void update(int x, int y, ll num) {
for(re int i = x; i <= n; i += lowbit(i))
for(re int j = y; j <= m; j += lowbit(j)) {
ta1[i][j] += num;
ta2[i][j] += num * x;
ta3[i][j] += num * y;
ta4[i][j] += num * x * y;
}
}
inline void add(int x1, int y1, int x2, int y2, ll num) {
update(x1, y1, num);
update(x2 + 1, y2 + 1, num);
update(x2 + 1, y1, -num);
update(x1, y2 + 1, -num);
}
inline ll ask(int x, int y) {
ll res = 0LL;
for(re int i = x; i; i -= lowbit(i))
for(re int j = y; j; j -= lowbit(j)) {
res += ta1[i][j] * (x + 1) * (y + 1) - ta2[i][j] * (y + 1) - ta3[i][j] * (x + 1) + ta4[i][j];
}
return res;
}
inline ll query(int x1, int y1, int x2, int y2) {
return ask(x2, y2) - ask(x2, y1 - 1) - ask(x1 - 1, y2) + ask(x1 - 1, y1 - 1);
}
int main() {
n = sc(), m = sc();
int opt;
while(scanf("%d", &opt) != EOF) {
if(opt == 1) {
int x1 = sc(), y1 = sc(), x2 = sc(), y2 = sc();
ll z;
scanf("%lld", &z);
add(x1, y1, x2, y2, z);
// out(query(1, 1, x2, y2)), pc('\n');
}
else {
int x1 = sc(), y1 = sc(), x2 = sc(), y2 = sc();
out(query(x1, y1, x2, y2)), pc('\n');
}
}
return 0;
}
  • T3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
//Coded by Dy.
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define gc getchar()
#define pc(x) putchar(x)
typedef long long ll;
#define re register
const int Maxn = 2050;
int n, m;
int ta1[Maxn][Maxn], ta2[Maxn][Maxn], ta3[Maxn][Maxn], ta4[Maxn][Maxn];
#define lowbit(x) (x & (-x))
inline void update(int x, int y, int num) {
for(re int i = x; i <= n; i += lowbit(i))
for(re int j = y; j <= m; j += lowbit(j)) {
ta1[i][j] += num;
ta2[i][j] += num * x;
ta3[i][j] += num * y;
ta4[i][j] += num * x * y;
}
}
inline void add(int x1, int y1, int x2, int y2, int num) {
update(x1, y1, num);
update(x2 + 1, y2 + 1, num);
update(x2 + 1, y1, -num);
update(x1, y2 + 1, -num);
}
inline int ask(int x, int y) {
int res = 0LL;
for(re int i = x; i; i -= lowbit(i))
for(re int j = y; j; j -= lowbit(j)) {
res += ta1[i][j] * (x + 1) * (y + 1) - ta2[i][j] * (y + 1) - ta3[i][j] * (x + 1) + ta4[i][j];
}
return res;
}
inline int query(int x1, int y1, int x2, int y2) {
return ask(x2, y2) - ask(x2, y1 - 1) - ask(x1 - 1, y2) + ask(x1 - 1, y1 - 1);
}
int main() {
scanf("X %d%d", &n, &m);
char c[2];
while(scanf("%s", c) == 1) {
if(c[0] == 'L') {
int x1, x2, y1, y2, z;
scanf("%d%d%d%d%d", &x1, &y1, &x2, &y2, &z);
add(x1, y1, x2, y2, z);
}
else {
int x1, y1, x2, y2;
scanf("%d%d%d%d", &x1, &y1, &x2, &y2);
printf("%d\n", query(x1, y1, x2, y2));
}
}
return 0;
}

rp++

Related Issues not found

Please contact @dyrisingsunlight to initialize the comment