题目描述
给定一个 \(n\times n\) 的表格,初始时 \(f(a,b)=a\cdot b\),并且始终满足:
\(\forall a,b\;f(a,b)=f(b,a)\)
\(\forall a,b\;b\cdot f(a,a+b)=(a+b)\cdot
f(a,b)\)
有\(m\) 次操作
每次操作给出 \(a,b,x,k\),把 \(f(a,b)\) 的值改为 \(x\);为了使整个表仍然满足上述条件,其他很多格子也要跟着改变
每次修改后,输出表格前 \(k\) 行、前 \(k\) 列(即 \(k\times k\) 范围内)所有数的和,对 \(10^9+7\) 取模
Solution
对于性质\(\forall a,b\;b\cdot
f(a,a+b)=(a+b)\cdot f(a,b)\) ,观察发现它可以转化一下: \[
\begin{align*}
b\cdot f(a,a+b)&=(a+b)\cdot f(a,b)\\
\Leftrightarrow ab\cdot f(a,a+b)&=a(a+b)\cdot f(a,b)\\
\Leftrightarrow\quad\;\frac{f(a,a+b)}{a(a+b)}&=\frac{f(a,b)}{ab}\\
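\Rightarrow\qquad\;\;\;\frac{f(a,b)}{ab}&=\frac{f(\gcd(a,b),
\gcd(a,b))}{\gcd^2(a,b)}
\end{align*}
\] (最后一步:反复使用上一行的等式(结合对称性 \(f(a,b)=f(b,a)\)),相当于对 \((a,b)\) 做辗转相减,最终一定会走到 \((\gcd(a,b),\gcd(a,b))\),而 \(\frac{f(a,b)}{ab}\) 在每一步都保持不变。)所以可以得到一个结论: \[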
\begin{align*}
f(a,b)&=f(\gcd(a,b),\gcd(a,b))\times\frac{ab}{\gcd^2(a,b)}\\
&=g(d)\times\frac{ab}{d^2}
\end{align*}
\] 其中 \(d=\gcd(a,b)\),\(g(d)=f(d,d)\)。也就是说,整张表完全由对角线上的值 \(g(d)\) 决定,修改 \(f(a,b)\) 实际上只是修改了 \(g(\gcd(a,b))\)
那么再回头看看题目要求的问题(为了书写方便,下面用 \(n\) 表示询问中的 \(k\)): \[
\begin{align*}
ans&=\sum_{d=1}^ng(d)\sum_{i=1}^n\sum_{j=1}^n\frac{ij}{\gcd^2(i,j)}[\gcd(i,j)=d]\\
&=\sum_{d=1}^ng(d)\sum_{i=1}^{\left\lfloor\frac
nd\right\rfloor}\sum_{j=1}^{\left\lfloor\frac
nd\right\rfloor}ij[\gcd(i,j)=1]\\
&=\sum_{d=1}^ng(d)\sum_{i=1}^{\left\lfloor\frac
nd\right\rfloor}i\sum_{j=1}^{\left\lfloor\frac
nd\right\rfloor}j[\gcd(i,j)=1]\\
\end{align*}
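\] 考虑: \[
\sum_{i=1}^ni[\gcd(i,n)=1]=\frac{n\varphi(n)+[n=1]}{2}
\] 记 \(N=\left\lfloor\frac nd\right\rfloor\),利用 \(i,j\) 的对称性,把 \(j\le i\) 和 \(j\ge i\) 两部分相加(对角线上只有 \(i=j=1\) 满足 \(\gcd(i,j)=1\),它被算了两次): \[
\sum_{i=1}^N\sum_{j=1}^N ij[\gcd(i,j)=1]=2\sum_{i=1}^N i\sum_{j=1}^{i}j[\gcd(i,j)=1]-1=\sum_{i=1}^N i^2\varphi(i)
\] 所以原式可以化为: \[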
ans= \sum_{d=1}^ng(d)\sum_{i=1}^{\left\lfloor\frac
nd\right\rfloor}i^2\varphi(i)
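\] 那么这个时候我们发现,后面的 \(\sum_{i=1}^{\lfloor n/d\rfloor}i^2\varphi(i)\) 用线性筛预处理出前缀和之后就可以 \(O(1)\) 查询;又因为 \(\left\lfloor\frac nd\right\rfloor\) 只有 \(O(\sqrt n)\) 种取值,可以用整除分块把 \(d\) 分成 \(O(\sqrt n)\) 段,每段只需要 \(g(d)\) 的一个区间和
然而前面那个\(g(d)\) 是会发生改变的
所以需要一个支持单点修改、区间求和的结构:可以用树状数组(修改、查询都是 \(O(\log n)\))或者分块(修改 \(O(\sqrt n)\)、查询 \(O(1)\))来维护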
这两种做法理论上都是可以过的,实测在 loj 上是可以过的
洛谷似乎有点卡常,要\(TLE\) 一两个点
树状数组实现 ,分块实现
Code(树状数组)
// Fenwick-tree (BIT) solution.
//
// f(a,b) = g(d) * (a/d) * (b/d) with d = gcd(a,b) and g(d) = f(d,d), so the
// answer for a query k is  sum_{d=1..k} g(d) * F[floor(k/d)],  where
// F[t] = sum_{i=1..t} i^2 * phi(i).  F is precomputed with a linear sieve,
// g lives in a Fenwick tree (point update / prefix sum), and the sum over d is
// split into O(sqrt k) ranges on which floor(k/d) is constant.
#include <cstdio>
#include <vector>

typedef long long ll;
const ll mod = 1e9 + 7;

// Reads the next non-negative decimal integer from stdin, reduced modulo mod.
// By the problem constraints a, b and k are at most n (< mod), so in practice
// only x (up to about 1e18) is reduced; the solution only needs x mod p.
// Returns 0 at end of input instead of spinning forever.
static ll read_mod() {
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) c = getchar();
    ll x = 0;
    for (; c >= '0' && c <= '9'; c = getchar()) x = (x * 10 + (c - '0')) % mod;
    return x;
}

// base^exp mod p by binary exponentiation.
static ll power_mod(ll base, ll exp) {
    ll result = 1;
    base %= mod;
    while (exp > 0) {
        if (exp & 1) result = result * base % mod;
        base = base * base % mod;
        exp >>= 1;
    }
    return result;
}

// Greatest common divisor (iterative Euclid).
static ll gcd_ll(ll x, ll y) {
    while (y != 0) {
        const ll r = x % y;
        x = y;
        y = r;
    }
    return x;
}

// Returns F with F[t] = sum_{i=1..t} i^2 * phi(i) mod p for 0 <= t <= n.
// phi and the prime list are locals, so their memory is released before the
// queries are processed.
static std::vector<ll> prefix_sq_phi(int n) {
    std::vector<int> phi(n + 1, 0);
    std::vector<int> primes;
    std::vector<char> composite(n + 1, 0);
    if (n >= 1) phi[1] = 1;
    for (int i = 2; i <= n; ++i) {
        if (!composite[i]) {
            primes.push_back(i);
            phi[i] = i - 1;
        }
        for (const int p : primes) {
            if ((ll)i * p > n) break;
            composite[i * p] = 1;
            if (i % p == 0) {
                phi[i * p] = phi[i] * p;  // p already divides i
                break;                    // every composite is marked exactly once
            }
            phi[i * p] = phi[i] * (p - 1);  // gcd(i, p) = 1: phi is multiplicative
        }
    }
    std::vector<ll> F(n + 1, 0);
    for (int i = 1; i <= n; ++i)
        F[i] = (F[i - 1] + (ll)i * i % mod * phi[i]) % mod;
    return F;
}

// Fenwick tree over positions 1..n holding values mod p (t[0] is unused).
struct Fenwick {
    std::vector<ll> t;

    // O(n) build from v[0..n]: each node pushes its finished total into its
    // parent exactly once.
    explicit Fenwick(const std::vector<ll>& v) : t(v) {
        const ll n = (ll)t.size() - 1;
        for (ll i = 1; i <= n; ++i) {
            const ll parent = i + (i & -i);
            if (parent <= n) t[parent] = (t[parent] + t[i]) % mod;
        }
    }

    // Adds delta (in [0, mod)) at position pos.
    void add(ll pos, ll delta) {
        for (; pos < (ll)t.size(); pos += pos & -pos) t[pos] = (t[pos] + delta) % mod;
    }

    // Sum of positions 1..pos mod p.
    ll prefix(ll pos) const {
        ll s = 0;  // at most ~22 terms below mod: no overflow
        for (; pos > 0; pos -= pos & -pos) s += t[pos];
        return s % mod;
    }

    // Sum of positions l..r mod p.
    ll range(ll l, ll r) const { return (prefix(r) - prefix(l - 1) + mod) % mod; }
};

int main() {
    const ll m = read_mod();
    const int n = (int)read_mod();
    const std::vector<ll> F = prefix_sq_phi(n);

    // Initially f(a,b) = a*b, i.e. g(d) = f(d,d) = d^2.
    std::vector<ll> g(n + 1, 0);
    for (int d = 1; d <= n; ++d) g[d] = (ll)d * d % mod;
    Fenwick tree(g);

    for (ll op = 0; op < m; ++op) {
        const ll a = read_mod();
        const ll b = read_mod();
        const ll x = read_mod();
        const ll k = read_mod();
        const ll d = gcd_ll(a, b);

        // f(a,b) = g(d)*(a/d)*(b/d)  =>  g(d) = x * d^2 / (a*b)  (mod p).
        const ll updated = x * d % mod * d % mod * power_mod(a * b % mod, mod - 2) % mod;
        tree.add(d, (updated - g[d] + mod) % mod);
        g[d] = updated;

        // floor(k/d) is constant for d in [l, r]: one range sum per group.
        ll ans = 0;
        for (ll l = 1, r; l <= k; l = r + 1) {
            r = k / (k / l);
            ans = (ans + F[k / l] * tree.range(l, r)) % mod;
        }
        printf("%lld\n", ans);
    }
    return 0;
}
Code(分块)
// Square-root-decomposition solution.
//
// f(a,b) = g(d) * (a/d) * (b/d) with d = gcd(a,b) and g(d) = f(d,d), so the
// answer for a query k is  sum_{d=1..k} g(d) * F[floor(k/d)],  where
// F[t] = sum_{i=1..t} i^2 * phi(i).  g is kept in blocks of about sqrt(n)
// positions: a point update costs O(sqrt n) and a prefix sum costs O(1), which
// suits the O(sqrt k) range sums needed per query.
//
// Note: the original stored the block length in a global named `size`; with
// `using namespace std;` that name is ambiguous with std::size under C++17 and
// the program fails to compile. This version uses no using-directive.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

typedef long long ll;
const ll mod = 1e9 + 7;

// Reads the next non-negative decimal integer from stdin, reduced modulo mod.
// By the problem constraints a, b and k are at most n (< mod), so in practice
// only x (up to about 1e18) is reduced; the solution only needs x mod p.
// Returns 0 at end of input instead of spinning forever.
static ll read_mod() {
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) c = getchar();
    ll x = 0;
    for (; c >= '0' && c <= '9'; c = getchar()) x = (x * 10 + (c - '0')) % mod;
    return x;
}

// base^exp mod p by binary exponentiation.
static ll power_mod(ll base, ll exp) {
    ll result = 1;
    base %= mod;
    while (exp > 0) {
        if (exp & 1) result = result * base % mod;
        base = base * base % mod;
        exp >>= 1;
    }
    return result;
}

// Greatest common divisor (iterative Euclid).
static ll gcd_ll(ll x, ll y) {
    while (y != 0) {
        const ll r = x % y;
        x = y;
        y = r;
    }
    return x;
}

// Returns F with F[t] = sum_{i=1..t} i^2 * phi(i) mod p for 0 <= t <= n.
static std::vector<ll> prefix_sq_phi(int n) {
    std::vector<int> phi(n + 1, 0);
    std::vector<int> primes;
    std::vector<char> composite(n + 1, 0);
    if (n >= 1) phi[1] = 1;
    for (int i = 2; i <= n; ++i) {
        if (!composite[i]) {
            primes.push_back(i);
            phi[i] = i - 1;
        }
        for (const int p : primes) {
            if ((ll)i * p > n) break;
            composite[i * p] = 1;
            if (i % p == 0) {
                phi[i * p] = phi[i] * p;  // p already divides i
                break;                    // every composite is marked exactly once
            }
            phi[i * p] = phi[i] * (p - 1);  // gcd(i, p) = 1: phi is multiplicative
        }
    }
    std::vector<ll> F(n + 1, 0);
    for (int i = 1; i <= n; ++i)
        F[i] = (F[i - 1] + (ll)i * i % mod * phi[i]) % mod;
    return F;
}

// Prefix sums of g(1..n) mod p under point updates. Block b (0-based) covers
// positions [b*len + 1, min((b+1)*len, n)]. Every stored value stays in [0, mod).
struct BlockPrefix {
    ll n = 0;
    ll len = 1;
    ll blocks = 0;
    std::vector<ll> inner;  // inner[i] = g(first index of i's block .. i)
    std::vector<ll> outer;  // outer[b] = g(1 .. last index of block b)

    explicit BlockPrefix(const std::vector<ll>& g) {
        n = (ll)g.size() - 1;
        len = std::max<ll>(1, (ll)std::sqrt((double)n));
        blocks = (n + len - 1) / len;
        inner.assign(n + 1, 0);
        outer.assign(blocks, 0);
        for (ll i = 1; i <= n; ++i) {
            const bool block_start = (i - 1) % len == 0;
            inner[i] = ((block_start ? 0 : inner[i - 1]) + g[i]) % mod;
        }
        for (ll b = 0; b < blocks; ++b) {
            const ll last = std::min(n, (b + 1) * len);
            outer[b] = ((b > 0 ? outer[b - 1] : 0) + inner[last]) % mod;
        }
    }

    // Adds delta (in [0, mod)) at position pos: O(len + blocks).
    void add(ll pos, ll delta) {
        const ll b = (pos - 1) / len;
        const ll last = std::min(n, (b + 1) * len);
        for (ll i = pos; i <= last; ++i) inner[i] = (inner[i] + delta) % mod;
        for (ll j = b; j < blocks; ++j) outer[j] = (outer[j] + delta) % mod;
    }

    // Sum of positions 1..x mod p: O(1).
    ll prefix(ll x) const {
        if (x <= 0) return 0;
        const ll b = (x - 1) / len;
        return ((b > 0 ? outer[b - 1] : 0) + inner[x]) % mod;
    }

    // Sum of positions l..r mod p.
    ll range(ll l, ll r) const { return (prefix(r) - prefix(l - 1) + mod) % mod; }
};

int main() {
    const ll m = read_mod();
    const int n = (int)read_mod();
    const std::vector<ll> F = prefix_sq_phi(n);

    // Initially f(a,b) = a*b, i.e. g(d) = f(d,d) = d^2.
    std::vector<ll> g(n + 1, 0);
    for (int d = 1; d <= n; ++d) g[d] = (ll)d * d % mod;
    BlockPrefix sums(g);

    for (ll op = 0; op < m; ++op) {
        const ll a = read_mod();
        const ll b = read_mod();
        const ll x = read_mod();
        const ll k = read_mod();
        const ll d = gcd_ll(a, b);

        // f(a,b) = g(d)*(a/d)*(b/d)  =>  g(d) = x * d^2 / (a*b)  (mod p).
        const ll updated = x * d % mod * d % mod * power_mod(a * b % mod, mod - 2) % mod;
        sums.add(d, (updated - g[d] + mod) % mod);
        g[d] = updated;

        // floor(k/d) is constant for d in [l, r]: one range sum per group.
        ll ans = 0;
        for (ll l = 1, r; l <= k; l = r + 1) {
            r = k / (k / l);
            ans = (ans + F[k / l] * sums.range(l, r)) % mod;
        }
        printf("%lld\n", ans);
    }
    return 0;
}
进一步考虑
我们每次修改的其实只是对角线上的一个值 \(g(\gcd(a,b))=f(\gcd(a,b),\gcd(a,b))\),所以真正被改过的值并不多
那么我们可以先求出原表(\(f(a,b)=ab\))中 \(k\times k\) 范围内所有数的和,即 \(\left(\frac{k(k+1)}{2}\right)^2\)
再求出修改的值对答案贡献相对原来的偏移量
\[
\Delta ans=\sum_{d\in S}\left(g'(d)-d^2\right)\sum_{j=1}^{\left\lfloor\frac
kd\right\rfloor}j^2\varphi(j)
\]
其中 \(S\) 是被修改过的 \(d\) 的集合,\(cnt=|S|\),\(g'(d)\) 是 \(g(d)\) 的当前值,\(d^2\) 是它在原表中的值。所以每次询问的复杂度是 \(O(cnt)\)
而 \(cnt\) 不会超过修改次数 \(m\),也不会超过 \(n\)(每个 \(d\) 最多计入一次)
所以单次询问最坏是 \(O(\min(m,n))\),总复杂度最坏为 \(O(m\cdot\min(m,n))\)
然而实际跑下来要比这快得多
Code
// Incremental solution: answer = (sum over the untouched table f(a,b) = a*b)
// + a correction for every diagonal value g(d) = f(d,d) overwritten so far.
// Each query costs O(number of distinct d that have been overwritten).
#include <cstdio>
#include <vector>

typedef long long ll;
const ll mod = 1e9 + 7;

// Reads the next non-negative decimal integer from stdin without any modular
// reduction (x may be as large as about 1e18 and is divided exactly below).
static ll read_int() {
    int ch = getchar();
    while (ch < '0' || ch > '9') ch = getchar();
    ll value = 0;
    do {
        value = value * 10 + (ch - '0');
        ch = getchar();
    } while (ch >= '0' && ch <= '9');
    return value;
}

// Greatest common divisor (iterative Euclid).
static ll gcd_of(ll x, ll y) {
    while (y) {
        const ll rest = x % y;
        x = y;
        y = rest;
    }
    return x;
}

// k(k+1)/2 mod p.
static ll triangle(ll k) { return k * (k + 1) / 2 % mod; }

// Returns S with S[t] = sum_{i=1..t} i^2 * phi(i) mod p, phi via a linear sieve.
static std::vector<ll> prefix_sq_phi(ll n) {
    std::vector<ll> phi(n + 1, 0);
    std::vector<ll> primes;
    std::vector<char> composite(n + 1, 0);
    if (n >= 1) phi[1] = 1;
    for (ll i = 2; i <= n; ++i) {
        if (!composite[i]) {
            primes.push_back(i);
            phi[i] = i - 1;
        }
        for (const ll p : primes) {
            if (i * p > n) break;
            composite[i * p] = 1;
            if (i % p == 0) {
                phi[i * p] = phi[i] * p;
                break;
            }
            phi[i * p] = phi[i] * (p - 1);
        }
    }
    std::vector<ll> S(n + 1, 0);
    for (ll i = 1; i <= n; ++i) S[i] = (S[i - 1] + i * i % mod * phi[i] % mod) % mod;
    return S;
}

int main() {
    ll remaining = read_int();
    const ll n = read_int();
    const std::vector<ll> S = prefix_sq_phi(n);

    std::vector<ll> current(n + 1, 0);    // current g(d) mod p, meaningful once touched
    std::vector<char> touched(n + 1, 0);
    std::vector<ll> changed;              // distinct d overwritten so far, first-touch order

    for (; remaining > 0; --remaining) {
        const ll a = read_int();
        const ll b = read_int();
        const ll x = read_int();
        const ll k = read_int();
        const ll d = gcd_of(a, b);

        // f(a,b) = g(d)*(a/d)*(b/d), so g(d) = x/(a/d)/(b/d).
        // NOTE(review): integer division assumes the input guarantees an
        // integer g(d); confirm against the problem statement.
        current[d] = x / (a / d) / (b / d) % mod;
        if (!touched[d]) {
            touched[d] = 1;
            changed.push_back(d);
        }

        const ll base = triangle(k);
        ll ans = base * base % mod;  // sum of i*j over the untouched k x k table
        for (const ll e : changed) {
            const ll diff = ((current[e] - e * e % mod) % mod + mod) % mod;
            ans = (ans + diff * S[k / e] % mod) % mod;
        }
        printf("%lld\n", ans);
    }
    return 0;
}