树状数组

一维树状数组

这是一个单次修改、查询都为\(O(\log n)\)的数据结构,空间复杂度为 \(O(n)\)

当然有时候还是会用一些辅助数组的

树状数组 1 :单点修改,区间查询

有这样一个区间查询,首先想到的就是用左右端点的前缀和做差去维护

那么对于裸的前缀和,一次修改操作的时间复杂度为\(O(n)\)

所以想到用数据结构乱搞一下

那么就是用树状数组去维护,具体就不细说了,直接上代码

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#include <bits/stdc++.h>
#define lowbit(x) (x & -x)

using namespace std;

typedef long long ll;
const ll maxn = 1e6 + 10;
const ll mod = 1e9 + 7;

inline ll __read()
{
ll x(0), t(1);
char o (getchar());
while (o < '0' || o > '9') {
if (o == '-') t = -1;
o = getchar();
}
for (; o >= '0' && o <= '9'; o = getchar()) x = (x << 1) + (x << 3) + (o ^ 48);
return x * t;
}

ll n, m;
ll t[maxn], d[maxn];

inline void Update(ll x, ll val)
{
ll temp(x);
while (x <= n) {
t[x] += val;
d[x] += val * temp;
x += lowbit(x);
}
}

inline ll Query(ll x)
{
ll ans(0), temp(x);
while (x) {
ans += t[x] * (temp + 1) - d[x];
x ^= lowbit(x);
}
return ans;
}

signed main()
{
n = __read(), m = __read();
ll last = 0;
for (ll i = 1; i <= n; ++i) {
ll x = __read();
Update(i, x - last);
last = x;
}
while (m--) {
ll opt = __read(), l = __read(), r = __read();
if (opt == 1) {
ll k = __read();
Update(l, k);
Update(r + 1, -k);
}
else printf ("%lld\n", Query(r) - Query(l - 1));
}
system("pause");
}

树状数组 2 :区间修改,单点查询

考虑维护一个差分序列

那么这个数的差分序列的前缀和就是这个数

如果对区间\([l,r]\)整体修改,即\(a[l]-a[l-1]\)\(a[r+1]-a[r]\)的值发生了改变

即我们只需要去修改\(cf[l]\)\(cf[r+1]\)的权值

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#include <bits/stdc++.h>
#define lowbit(x) (x & -x)

using namespace std;

typedef long long ll;
const int maxn = 1e6 + 10;
const int mod = 1e9 + 7;

inline int __read() {
int x(0), t(1);
char o(getchar());
while (o < '0' || o > '9') {
if (o == '-')
t = -1;
o = getchar();
}
for (; o >= '0' && o <= '9'; o = getchar()) x = (x << 1) + (x << 3) + (o ^ 48);
return x * t;
}

int n, m;
ll t[maxn];

inline void Update(int x, ll val) {
while (x <= n) {
t[x] += val;
x += lowbit(x);
}
}

inline ll Query(int x) {
ll ans(0);
while (x) {
ans += t[x];
x ^= lowbit(x);
}
return ans;
}

signed main() {
n = __read(), m = __read();
int last = 0;
for (int i = 1; i <= n; ++i) {
int x = __read();
Update(i, x - last);
last = x;
}
while (m--) {
int opt = __read(), l = __read();
if (opt == 1) {
int r = __read(), k = __read();
Update(l, k);
Update(r + 1, -k);
} else
printf("%lld\n", Query(l));
}
system("pause");
}

树状数组 3 :区间修改,区间查询

按照上一个操作的思想,现在得到了差分序列\(\{cf\}\)

此时要求的是区间\([l,r]\)的和

按照差分的写法,就应该是 \[ ans= \sum_{i=l}^r\sum_{j=1}^icf[j] \] 为了方便考虑,那么就可以用\(r\)的前缀和前去\(l-1\)的前缀和

那么考虑求\([1,r]\)的和 \[ \begin{align*} ans&=\sum_{i=1}^r\sum_{j=1}^icf[j]\\ &=\sum_{i=1}^r(r-i+1)cf[i]\\ &=(r+1)\sum_{i=1}^rcf[i]-\sum_{i=1}^ri\cdot cf[i] \end{align*} \]

我们发现,这个似乎变成了两个需要维护的序列,那就写两个呗

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#include <bits/stdc++.h>
#define lowbit(x) (x & -x)

using namespace std;

typedef long long ll;
const ll maxn = 1e6 + 10;
const ll mod = 1e9 + 7;

inline ll __read()
{
ll x(0), t(1);
char o (getchar());
while (o < '0' || o > '9') {
if (o == '-') t = -1;
o = getchar();
}
for (; o >= '0' && o <= '9'; o = getchar()) x = (x << 1) + (x << 3) + (o ^ 48);
return x * t;
}

ll n, m;
ll t[maxn], d[maxn];

inline void Update(ll x, ll val)
{
ll temp(x);
while (x <= n) {
t[x] += val;
d[x] += val * temp;
x += lowbit(x);
}
}

inline ll Query(ll x)
{
ll ans(0), temp(x);
while (x) {
ans += t[x] * (temp + 1) - d[x];
x ^= lowbit(x);
}
return ans;
}

signed main()
{
n = __read(), m = __read();
ll last = 0;
for (ll i = 1; i <= n; ++i) {
ll x = __read();
Update(i, x - last);
last = x;
}
while (m--) {
ll opt = __read(), l = __read(), r = __read();
if (opt == 1) {
ll k = __read();
Update(l, k);
Update(r + 1, -k);
}
else printf ("%lld\n", Query(r) - Query(l - 1));
}
system("pause");
}

二维树状数组

这个和一维树状数组的思路大致相同,但是每次修改、查询的复杂度都是\(O(\log_2n\times\log_2m)\)

二维树状数组 1:单点修改,区间查询

直接维护二维前缀和,简单容斥一下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#include <bits/stdc++.h>
#define lowbit(x) (x & -x)

using namespace std;

typedef long long ll;
const int maxn = 5e3 + 10;
const int mod = 1e9 + 7;

inline int __read()
{
int x(0), t(1);
char o (getchar());
while (o < '0' || o > '9') {
if (o == '-') t = -1;
o = getchar();
}
for (; o >= '0' && o <= '9'; o = getchar()) x = (x << 1) + (x << 3) + (o ^ 48);
return x * t;
}

int n, m, opt;
ll t[maxn][maxn];

inline void Update(int x, int y, int val)
{
for (int i = x; i <= n; i += lowbit(i))
for (int j = y; j <= m; j += lowbit(j))
t[i][j] += val;
}

inline ll Query(int x, int y)
{
ll ans(0);
for (int i = x; i; i ^= lowbit(i))
for (int j = y; j; j ^= lowbit(j))
ans += t[i][j];
return ans;
}

signed main()
{
n = __read(), m = __read();
while (~scanf ("%d", &opt)) {
if (opt == 1) {
int a = __read(), b = __read(), x = __read();
Update(a, b, x);
} else {
int a = __read(), b = __read(), c = __read(), d = __read();
printf ("%lld\n", Query(c, d) + Query(a - 1, b - 1) - Query(a - 1, d) - Query(c, b - 1));
}
}
system("pause");
}

二维树状数组 2:区间修改,单点查询

简单地说,维护二维差分序列可以做到区间修改,直接二维前缀和就是一个单点查询

考虑如何去构造差分序列

对于每一维,可以有\(cf[i][j]=a[i][j]-a[i][j-1]\)\(cf[i][j]=a[i][j]-a[i-1][j]\)

确实,这个可以把两维分开,看成\(n\)个一维的,但是下面还有区间修改区间查询的

所以这里要说的显然不是一维的做法

考虑合并行和列的差分序列后,与二维前缀和的关系

容易发现: \[ \begin{align*} a[x][y]&=\sum_{i=1}^x\sum_{j=1}^ycf[i][j]\\ a[x-1][y]&=\sum_{i=1}^{x-1}\sum_{j=1}^ycf[i][j]\\ a[x][y-1]&=\sum_{i=1}^x\sum_{j=1}^{y-1}cf[i][j]\\ a[x-1][y-1]&=\sum_{i=1}^{x-1}\sum_{j=1}^{y-1}cf[i][j] \end{align*} \] 那么就可以轻松的得到\(cf[x][y]=a[x][y]-a[x-1][y]-a[x][y-1]+a[x-1][y-1]\)

就是说现在已经构造出了二维前缀和等于该点原本的值的差分序列了

考虑如何区间修改呢

这得画图

wXbyct.png

那么蓝色矩形就是需要修改的矩形,按照差分数组的定义,差分序列发生改变了的点就是图中四个绿色的小矩形,就可以直接改了

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#include <bits/stdc++.h>
#define lowbit(x) (x & -x)

using namespace std;

typedef long long ll;
const int maxn = 5e3 + 10;
const int mod = 1e9 + 7;

inline int __read()
{
int x(0), t(1);
char o (getchar());
while (o < '0' || o > '9') {
if (o == '-') t = -1;
o = getchar();
}
for (; o >= '0' && o <= '9'; o = getchar()) x = (x << 1) + (x << 3) + (o ^ 48);
return x * t;
}

int n, m, opt;
ll t[maxn][maxn];

inline void Update(int x, int y, int val)
{
for (int i = x; i <= n; i += lowbit(i))
for (int j = y; j <= m; j += lowbit(j))
t[i][j] += val;
}

inline ll Query(int x, int y)
{
ll ans(0);
for (int i = x; i; i ^= lowbit(i))
for (int j = y; j; j ^= lowbit(j))
ans += t[i][j];
return ans;
}

signed main()
{
n = __read(), m = __read();
while (~scanf ("%d", &opt)) {
if (opt == 1) {
int a = __read(), b = __read(), c = __read(), d = __read(), x = __read();
Update(a, b, x);
Update(a, d + 1, -x);
Update(c + 1, b, -x);
Update(c + 1, d + 1, x);

} else {
int a = __read(), b = __read();
printf ("%lld\n", Query(a, b));
}
}
system("pause");
}

二维树状数组 3:区间修改,区间查询

这个就是继承了二维树状数组2树状数组3的思想了

现在得到了二维差分序列,需要求的确实这样一个东西 \[ ans=\sum_{x=xl}^{xr}\sum_{y=yl}^{yr}\sum_{i=1}^x\sum_{j=1}^ycf[i][j] \] 丑死了。。。

还是直接考虑简单容斥,那就只剩下一个二维前缀和了 \[ \begin{align*} ans&=\sum_{x=1}^{xr}\sum_{y=1}^{yr}\sum_{i=1}^x\sum_{j=1}^ycf[i][j]\\ &=\sum_{i=1}^{xr}\sum_{j=1}^{yr}cf[i][j]*(x-i+1)*(y-j+1)\\ &=\sum_{i=1}^{xr}\sum_{j=1}^{yr}cf[i][j]*(xy-xj+x-iy+ij-i+y-j+1)\\ &=\sum_{i=1}^{xr}\sum_{j=1}^{yr}cf[i][j]*((xy+x+y+1)-(xj-j)-(iy-i)+ij)\\ &=\sum_{i=1}^{xr}\sum_{j=1}^{yr}cf[i][j]*\big((x+1)(y+1)-(x+1)j-(y+1)i+ij\big)\\ &=\sum_{i=1}^{xr}\sum_{j=1}^{yr}cf[i][j]*(x+y)(y+1)-\sum_{i=1}^{xr}\sum_{j=1}^{yr}cf[i][j]*(x+1)*j-\sum_{i=1}^{xr}\sum_{j=1}^{yr}cf[i][j]*(y+1)*i+\sum_{i=1}^{xr}\sum_{j=1}^{yr}cf[i][j]*ij\\ &=(x+y)(y+1)\sum_{i=1}^{xr}\sum_{j=1}^{yr}cf[i][j]-(x+1)\sum_{i=1}^{xr}\sum_{j=1}^{yr}cf[i][j]*j-(y+1)\sum_{i=1}^{xr}\sum_{j=1}^{yr}cf[i][j]*i+\sum_{i=1}^{xr}\sum_{j=1}^{yr}cf[i][j]*ij\\ \end{align*} \] 所以分别要维护的就是\(cf[i][j], cf[i][j]*j, cf[i][j]*i, cf[i][j]*ij\)四个数组了

嗯~ o( ̄▽ ̄)o,就可以愉快的跑路走人了

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include <bits/stdc++.h>
#define lowbit(x) (x & -x)

using namespace std;

typedef long long ll;
const int maxn = 5e3 + 10;
const int mod = 1e9 + 7;

inline int __read()
{
int x(0), t(1);
char o (getchar());
while (o < '0' || o > '9') {
if (o == '-') t = -1;
o = getchar();
}
for (; o >= '0' && o <= '9'; o = getchar()) x = (x << 1) + (x << 3) + (o ^ 48);
return x * t;
}

int n, m, opt;
ll t[4][maxn][maxn];

inline void Update(int x, int y, int val)
{
for (int i = x; i <= n; i += lowbit(i))
for (int j = y; j <= m; j += lowbit(j)) {
t[0][i][j] += val;
t[1][i][j] += val * y;
t[2][i][j] += val * x;
t[3][i][j] += val * x * y;
}
}

inline ll Query(int x, int y)
{
ll ans[4] = {0, 0, 0, 0};
for (int i = x; i; i ^= lowbit(i))
for (int j = y; j; j ^= lowbit(j))
for (int k(0); k < 4; ++k)
ans[k] += t[k][i][j];
return (x + 1) * (y + 1) * ans[0] - (x + 1) * ans[1] - (y + 1) * ans[2] + ans[3];
}

signed main()
{
n = __read(), m = __read();
while (~scanf ("%d", &opt)) {
if (opt == 1) {
int a = __read(), b = __read(), c = __read(), d = __read(), x = __read();
Update(a, b, x);
Update(a, d + 1, -x);
Update(c + 1, b, -x);
Update(c + 1, d + 1, x);

} else {
int a = __read(), b = __read(), c = __read(), d = __read();
printf ("%lld\n", Query(c, d) - Query(a - 1, d) - Query(c, b - 1) + Query(a - 1, b - 1));
}
}
system("pause");
}