[CQOI2017]小Q的表格

[CQOI2017]小Q的表格

题目描述

给定一个表格,满足:

  • \(\forall a,b\;f(a,b)=f(b,a)\)
  • \(\forall a,b\;b\cdot f(a,a+b)=(a+b)\cdot f(a,b)\)

初始时 \(f(a,b)=a\cdot b\)。共有 \(m\) 次操作：

  • 每次操作把 \(f(a,b)\) 修改为 \(x\)；为了让整张表仍然满足上述两个条件，很多其他格子也要随之改变
  • 每次修改后，输出表格前 \(k\) 行、前 \(k\) 列（即 \(k\times k\) 范围内）所有数之和，对 \(10^9+7\) 取模

Solution

对于性质\(\forall a,b\;b\cdot f(a,a+b)=(a+b)\cdot f(a,b)\),观察发现它可以转化一下: \[ \begin{align*} b\cdot f(a,a+b)&=(a+b)\cdot f(a,b)\\ \Leftrightarrow ab\cdot f(a,a+b)&=a(a+b)\cdot f(a,b)\\ \Leftrightarrow\quad\;\frac{f(a,a+b)}{a(a+b)}&=\frac{f(a,b)}{ab} \end{align*} \] 也就是说 \(\frac{f(a,b)}{ab}\) 在 \((a,b)\to(a,a+b)\) 下保持不变；再结合对称性 \(f(a,b)=f(b,a)\)，像辗转相减那样反复使用它，就能把 \((a,b)\) 一直变到 \((\gcd(a,b),\gcd(a,b))\)： \[ \frac{f(a,b)}{ab}=\frac{f(\gcd(a,b),\gcd(a,b))}{\gcd^2(a,b)} \] 所以可以得到一个结论（记 \(d=\gcd(a,b)\)，\(g(d)=f(d,d)\)）: \[ \begin{align*} f(a,b)&=f(\gcd(a,b),\gcd(a,b))\times\frac{ab}{\gcd^2(a,b)}\\ &=g(d)\times\frac{ab}{d^2} \end{align*} \] 那么就得到了一个线性的表了（整张表只由对角线上的 \(g(1),g(2),\dots\) 决定）

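那么再回头看看题目要求的问题（下面把询问中的 \(k\) 记作 \(n\)）: \[ \begin{align*} ans&=\sum_{d=1}^ng(d)\sum_{i=1}^n\sum_{j=1}^n\frac{ij}{\gcd^2(i,j)}[\gcd(i,j)=d]\\ &=\sum_{d=1}^ng(d)\sum_{i=1}^{\left\lfloor\frac nd\right\rfloor}\sum_{j=1}^{\left\lfloor\frac nd\right\rfloor}ij[\gcd(i,j)=1]\\ &=\sum_{d=1}^ng(d)\sum_{i=1}^{\left\lfloor\frac nd\right\rfloor}i\sum_{j=1}^{\left\lfloor\frac nd\right\rfloor}j[\gcd(i,j)=1]\\ \end{align*} \] 考虑（与 \(n\) 互质的 \(i\) 和 \(n-i\) 成对出现，\(n=1\) 时需要单独修正）: \[ \sum_{i=1}^ni[\gcd(i,n)=1]=\frac{n\times\varphi(n)+[n=1]}{2} \] 把 \(\sum_{i=1}^N\sum_{j=1}^N ij[\gcd(i,j)=1]\) 按 \(j\le i\) 和 \(j\ge i\) 对称地拆开，再减去被重复计算的 \(i=j\) 部分（其中只有 \(i=j=1\) 互质）: \[ \sum_{i=1}^N\sum_{j=1}^N ij[\gcd(i,j)=1]=2\sum_{i=1}^N i\cdot\frac{i\varphi(i)+[i=1]}{2}-1=\sum_{i=1}^N i^2\varphi(i) \] 所以原式可以化为: \[ ans= \sum_{d=1}^ng(d)\sum_{i=1}^{\left\lfloor\frac nd\right\rfloor}i^2\varphi(i) \] 那么这个时候我们发现,后面的是可以直接\(O(1)\)求解的(先预处理 \(i^2\varphi(i)\) 的前缀和,就可以直接查询)，而 \(\left\lfloor\frac nd\right\rfloor\) 只有 \(O(\sqrt n)\) 种取值，可以整除分块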

然而前面那个\(g(d)\)是会发生改变的

这个可以用树状数组或者是分块来维护 \(g\) 的前缀和（单点修改、区间求和），再配合整除分块回答询问

这两种做法理论上都是可以过的，实测在 LOJ 上可以通过

洛谷似乎有点卡常，会 \(TLE\) 一两个点

下面分别给出树状数组实现和分块实现。

Code(树状数组)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
const ll maxn = 4e6 + 10;   // array bound (4e6 + 10 slots)
const ll mod = 1e9 + 7;     // prime modulus for every answer

// Lowest set bit of x: the Fenwick-tree step size.
// A function instead of the former `#define lowbit(x) (x & -x)`, whose
// unparenthesized argument expanded `lowbit(i + 1)` to `i + 1 & -i + 1`.
constexpr ll lowbit(ll x) { return x & -x; }

// Reads the next non-negative decimal integer from stdin and returns it
// reduced modulo `mod`. The reduction is applied digit by digit, so
// arbitrarily long digit strings never overflow.
// Non-digit characters (including any '-') before the number are skipped.
// The character is kept in an int so EOF is recognised: with a char the skip
// loop spun forever on truncated input. Returns 0 if EOF precedes any digit.
inline ll __read() {
    ll x = 0;
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) c = getchar();
    for (; c >= '0' && c <= '9'; c = getchar()) x = (x * 10 + (c - '0')) % mod;
    return x;
}

// Global state (zero-initialised). The unused size/len/id/st/ed/Sum of the
// sqrt-decomposition variant were dropped: they cost four 4e6-element arrays,
// and a global `size` clashes with std::size under C++17 once referenced.
ll n, q;                // n: sieve / Fenwick bound; q: operations left
ll pr[maxn], phi[maxn]; // linear-sieve primes and Euler's phi
ll f[maxn], cnt;        // f[i] = sum_{j<=i} j^2 * phi(j) mod p; cnt = #primes
ll val[maxn];           // val[d]: current g(d) = f(d, d) mod p
ll sum[maxn];           // Fenwick tree over val[1..n]
bool vis[maxn];         // vis[i]: i was crossed out (composite) by the sieve

inline void inc(ll &x, ll y) {
x += y;
if (x >= mod)
x -= mod;
else if (x < 0)
x += mod;
}

// Modular sum of x and y with one correction step in either direction;
// negative y (a subtraction) is allowed.
inline ll add(ll x, ll y) {
    const ll s = x + y;
    if (s >= mod) return s - mod;
    if (s < 0) return s + mod;
    return s;
}

// Binary exponentiation: base^exp modulo `mod`.
inline ll Pow(ll base, ll exp) {
    ll result = 1;
    for (; exp != 0; exp >>= 1, base = base * base % mod)
        if (exp & 1) result = result * base % mod;
    return result % mod;
}

// Greatest common divisor by the Euclidean algorithm (iterative form).
ll Gcd(ll x, ll y) {
    while (y != 0) {
        const ll rem = x % y;
        x = y;
        y = rem;
    }
    return x;
}

// Fenwick-tree point update: adds delta to position x and to every node that
// covers it, up to n. Entries stay in [0, mod) via inc().
inline void Update(ll x, ll delta) {
    for (ll pos = x; pos <= n; pos += lowbit(pos)) inc(sum[pos], delta);
}

// Fenwick-tree prefix query: sum of positions 1..x, modulo mod.
inline ll Query(ll x) {
    ll acc = 0;
    for (ll pos = x; pos != 0; pos -= lowbit(pos)) inc(acc, sum[pos]);
    return acc % mod;
}

// Range sum of positions l..r modulo mod, as a difference of two prefixes.
inline ll Query(ll l, ll r) {
    const ll upto_r = Query(r);
    const ll before_l = Query(l - 1);
    return add(upto_r, -before_l);
}

inline void init() {
phi[1] = 1;
for (ll i = 2; i <= n; ++i) {
if (!vis[i])
pr[++cnt] = i, phi[i] = i - 1;
for (ll j = 1; j <= cnt && i * pr[j] <= n; ++j) {
vis[i * pr[j]] = 1;
if (i % pr[j]) {
phi[i * pr[j]] = phi[i] * phi[pr[j]];
} else {
phi[i * pr[j]] = phi[i] * pr[j];
break;
}
}
}

for (ll i = 1; i <= n; ++i) f[i] = (f[i - 1] + i * i % mod * phi[i] % mod) % mod;
}

signed main() {
q = __read(), n = __read();
init();

for (ll i = 1; i <= n; ++i) {
val[i] = 1ll * i * i % mod;
Update(i, val[i]);
}

while (q--) {
ll a = __read(), b = __read(), x = __read(), k = __read(), d = Gcd(a, b);

ll upt = x * d % mod * d % mod * Pow(1ll * a * b % mod, mod - 2) % mod;
Update(d, add(upt, -val[d]));
val[d] = upt;

ll ans(0);
for (ll l = 1, r; l <= k; l = r + 1) {
r = k / (k / l);
inc(ans, f[k / l] * Query(l, r) % mod);
}

printf("%lld\n", ans);
}
}

Code(分块)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#include <bits/stdc++.h>

using namespace std;

using ll = long long;
constexpr ll maxn = 4000000 + 10;  // array bound (4e6 + 10 slots)
constexpr ll mod = 1000000007;     // prime modulus for every answer

// Reads the next non-negative decimal integer from stdin and returns it
// reduced modulo `mod`. The reduction is applied digit by digit, so
// arbitrarily long digit strings never overflow.
// Non-digit characters (including any '-') before the number are skipped.
// The character is kept in an int so EOF is recognised: with a char the skip
// loop spun forever on truncated input. Returns 0 if EOF precedes any digit.
inline ll __read() {
    ll x = 0;
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) c = getchar();
    for (; c >= '0' && c <= '9'; c = getchar()) x = (x * 10 + (c - '0')) % mod;
    return x;
}

// Global state; every object is zero-initialised.
// NOTE(review): under C++17 with `using namespace std;`, unqualified uses of
// the global `size` are ambiguous with std::size.
ll size;         // sqrt-decomposition block length (assigned in main)
ll n;            // bound for the sieve and the block structure
ll q;            // operations left to process
ll len;          // number of blocks
ll cnt;          // number of primes found by the sieve
ll pr[maxn];     // primes from the linear sieve
ll phi[maxn];    // Euler's totient
ll f[maxn];      // f[i] = sum_{j<=i} j^2 * phi(j) mod p
ll id[maxn];     // id[i]: block that contains position i
ll st[maxn];     // st[b]: first position of block b (as set up in main)
ll ed[maxn];     // ed[b]: last position of block b
ll val[maxn];    // val[d]: current g(d) mod p
ll sum[maxn];    // sum[i]: prefix of val inside i's block, up to i
ll Sum[maxn];    // Sum[b]: total of val over blocks 1..b
bool vis[maxn];  // sieve marker: i is composite

inline void inc(ll &x, ll y) {
x += y;
if (x >= mod)
x -= mod;
}

// Modular sum of x and y with a single downward correction.
inline ll add(ll x, ll y) {
    const ll s = x + y;
    return s >= mod ? s - mod : s;
}

// Binary exponentiation: base^exp modulo `mod`.
inline ll Pow(ll base, ll exp) {
    ll result = 1;
    for (; exp != 0; exp >>= 1, base = base * base % mod)
        if (exp & 1) result = result * base % mod;
    return result % mod;
}

// Greatest common divisor by the Euclidean algorithm (iterative form).
ll Gcd(ll x, ll y) {
    while (y != 0) {
        const ll rem = x % y;
        x = y;
        y = rem;
    }
    return x;
}

// Prefix total of val[1..x] (congruent mod p): all whole blocks before x's
// block from Sum[], plus the in-block prefix sum[x]. Position 0 yields 0.
inline ll Get_sum(ll x) {
    return x == 0 ? 0 : add(Sum[id[x] - 1], sum[x]);
}

// Range total of val[l..r], normalised into [0, mod).
inline ll Get_sum(ll l, ll r) {
    const ll diff = (Get_sum(r) - Get_sum(l - 1)) % mod;
    return (diff + mod) % mod;
}

inline void init() {
phi[1] = 1;
for (ll i = 2; i <= n; ++i) {
if (!vis[i])
pr[++cnt] = i, phi[i] = i - 1;
for (ll j = 1; j <= cnt && i * pr[j] <= n; ++j) {
vis[i * pr[j]] = 1;
if (i % pr[j]) {
phi[i * pr[j]] = phi[i] * phi[pr[j]];
} else {
phi[i * pr[j]] = phi[i] * pr[j];
break;
}
}
}

for (ll i = 1; i <= n; ++i) f[i] = (f[i - 1] + i * i % mod * phi[i] % mod) % mod;
}

// Refreshes the sqrt decomposition after val[x] changed: recomputes the
// in-block prefixes sum[x..ed] of x's block, then the block prefixes
// Sum[id[x]..len]. Cost: O(block length + number of blocks).
inline void Update(ll x) {
    const ll blockOfX = id[x];
    // The in-block prefix restarts at the block's first position. add() keeps
    // sum[x] reduced modulo p like every other entry; the former
    // `sum[x - 1] + val[x]` skipped that reduction.
    sum[x] = (x == st[blockOfX]) ? val[x] : add(sum[x - 1], val[x]);
    for (ll i = x + 1; i <= ed[blockOfX]; ++i) sum[i] = add(sum[i - 1], val[i]);
    for (ll i = blockOfX; i <= len; ++i) Sum[i] = add(Sum[i - 1], sum[ed[i]]);
}

signed main() {
q = __read(), n = __read();
size = sqrt(n);
init();

for (ll i = 1; i <= n; ++i) {
id[i] = i / size + 1;
if (i % size)
continue;
st[id[i]] = i, ed[id[i] - 1] = i - 1;
}
len = id[n], ed[id[n]] = n;

for (ll i = 1; i <= n; ++i) val[i] = 1ll * i * i % mod;

for (ll i = 1; i <= len; ++i) {
sum[st[i]] = val[st[i]];
for (ll j = st[i] + 1; j <= ed[i]; ++j) sum[j] = add(sum[j - 1], val[j]);
Sum[i] = add(Sum[i - 1], sum[ed[i]]);
}

while (q--) {
ll a = __read(), b = __read(), x = __read(), k = __read(), d = Gcd(a, b);

val[d] = x * d % mod * d % mod * Pow(1ll * a * b % mod, mod - 2) % mod;

Update(d);

ll ans(0);
for (ll l = 1, r; l <= k; l = r + 1) {
r = k / (k / l);
inc(ans, f[k / l] * Get_sum(l, r) % mod);
}

printf("%lld\n", ans);
}
}

进一步考虑

我们每次修改的其实只是 \(g(\gcd(a,b))=f(\gcd(a,b),\gcd(a,b))\) 这一个值，所以真正被改动过的 \(g\) 值并不多

  • 那么我们可以先求出原表（\(f(a,b)=ab\)，即 \(g(d)=d^2\)）中 \(k\times k\) 范围内的和，即 \(\left(\frac{k(k+1)}{2}\right)^2\)

  • 再求出被修改过的 \(g(d)\) 相对原值 \(d^2\) 对答案贡献的偏移量

\[ \Delta ans=\sum_{d\in S}\left(g(d)-d^2\right)\sum_{j=1}^{\left\lfloor\frac kd\right\rfloor}j^2\varphi(j) \] 其中 \(S\) 是被修改过的不同 \(d=\gcd(a,b)\) 的集合，\(cnt=|S|\)，\(g(d)\) 是修改后的值，\(d^2\) 是原值。

所以每次询问就是 \(O(cnt)\) 的复杂度

\(cnt\) 不会超过操作次数 \(m\)，也不会超过 \(n\)。原文认为不同 \(\gcd\) 的个数大概处于 \(\log n\sim\sqrt n\) 这个级别，但这并不是严格的上界：例如每次都修改不同的 \(f(i,i)\)，\(cnt\) 就会达到 \(m\)

所以一般数据下总复杂度大概是 \(O(m\log n)\sim O(m\sqrt n)\)，最坏情况为 \(O(m\cdot\min(m,n))\)

然而实际跑下来要比这快得多

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#include <bits/stdc++.h>

using namespace std;

using ll = long long;
constexpr ll maxn = 4000000 + 10;  // array bound (4e6 + 10 slots)
constexpr ll mod = 1000000007;     // prime modulus for every answer

// Reads the next non-negative decimal integer from stdin without any modular
// reduction: callers use the exact values (e.g. for gcd(a, b)), so the number
// must fit in long long. Non-digit characters (including any '-') before it
// are skipped. The character is kept in an int so EOF is recognised: with a
// char the skip loop spun forever on truncated input. Returns 0 if EOF comes
// before any digit.
inline ll __read() {
    ll x = 0;
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) c = getchar();
    for (; c >= '0' && c <= '9'; c = getchar()) x = x * 10 + (c - '0');
    return x;
}

// Global state (zero-initialised). The unused `size` and `len` were removed;
// `size` would also clash with std::size under C++17 once referenced.
ll n, q;                              // n: sieve bound; q: operations left
ll pr[maxn], phi[maxn], f[maxn], cnt; // primes, phi, f[i] = sum_{j<=i} j^2 phi(j) mod p, #primes
ll stk[maxn], upt[maxn], top;         // stk[1..top]: distinct modified d; upt[d]: current g(d) mod p
bool vis[maxn], che[maxn];            // vis: sieve composite marker; che[d]: d already on stk

inline void init() {
phi[1] = 1;
for (ll i = 2; i <= n; ++i) {
if (!vis[i])
pr[++cnt] = i, phi[i] = i - 1;
for (ll j = 1; j <= cnt && i * pr[j] <= n; ++j) {
vis[i * pr[j]] = 1;
if (i % pr[j]) {
phi[i * pr[j]] = phi[i] * phi[pr[j]];
} else {
phi[i * pr[j]] = phi[i] * pr[j];
break;
}
}
}
for (ll i = 1; i <= n; ++i) f[i] = (f[i - 1] + i * i % mod * phi[i] % mod) % mod;
}

// Triangular number 1 + 2 + ... + x = x(x + 1) / 2, reduced mod p. The product
// is formed exactly first, so x(x + 1) must fit in long long.
inline ll sum(ll x) {
    const ll triangular = x * (x + 1) / 2;
    return triangular % mod;
}

// Greatest common divisor by the Euclidean algorithm (iterative form).
ll gcd(ll x, ll y) {
    while (y != 0) {
        const ll rem = x % y;
        x = y;
        y = rem;
    }
    return x;
}

signed main() {
q = __read(), n = __read();
init();
while (q--) {
ll a = __read(), b = __read(), x = __read(), k = __read(), d = gcd(a, b);

upt[d] = x / (a / d) / (b / d) % mod;
if (!che[d]) {
stk[++top] = d;
che[d] = 1;
}

ll ans(sum(k));
ans = ans * ans % mod;

for (ll i = 1; i <= top; ++i) {
d = stk[i];
ans = (ans + ((upt[d] - d * d % mod) % mod + mod) % mod * f[k / d] % mod) % mod;
}
printf("%lld\n", ans);
}
}