CF997C

CF997C Sky Full of Stars

这道题的第一想法就是用总方案数减去不合法的方案,显然不合法的方案就是没有一行 / 一列 是全部相同的。

那么这就可以把行列分开讨论:

  • 没有一列合法的方案总数为:\((3^n-3)^n\)
  • 至少有一行合法且没有列合法的方案总数为:\(-\sum\limits_{i=1}^n(-1)^i\binom ni\left(3\times\left(3^{n-i}-1\right)^n+\left(3^i-3\right)\times3^{n(n-i)}\right)\)

简单说明一下,就是枚举有几行是相同的,然后又需要分类讨论了:

  • 若:这几行的颜色都相同,那么所有列的都没有一个合法的
  • 若:这几行颜色不同,那么剩下的就可以乱选了

总结一下,答案就是: \[ ans = 3^{n^2}-\left((3^n-3)^n+\sum\limits_{i=1}^n(-1)^i\binom ni\left(3\times\left(3^{n-i}-1\right)^n+\left(3^i-3\right)\times3^{n(n-i)}\right)\right) \] 所以预处理一下,整体时间复杂度就是 \(O(n\log n)\) 的了,只可惜考场没有推出来。。。

Code

Code (点击展开)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include <bits/stdc++.h>
#define per(i, l, r) for (ll i = l; i <= r; ++i)
#define rep(i, l, r) for (ll i = r; i >= l; --i)

using namespace std;

typedef long long ll;
const ll mod = 998244353;
const ll maxn = 1e6 + 10;
const char infile[] = ".in";
const char outfile[] = ".out";

inline ll read()
{
ll x(0), t(0);
char o(getchar());
while (o < '0' || o > '9') {
if (o == '-') t = 1;
o = getchar();
}
for (; o >= '0' && o <= '9'; o = getchar())
x = (x << 1) + (x << 3) + (o ^ 48);
return t ? -x : x;
}

inline void file ()
{
freopen (infile, "r", stdin);
freopen (outfile, "w", stdout);
}

inline ll Pow(ll x, ll y)
{
ll res(1);
while (y) {
if (y & 1) res = res * x % mod;
x = x * x % mod;
y >>= 1;
}
return res % mod;
}

ll ans, fac[maxn], inv[maxn], bas[maxn];

inline ll C(ll x, ll y)
{
return fac[x] * inv[y] % mod * inv[x - y] % mod;
}

int main()
{
ll n = read();
ans = (Pow(3, n * n) - Pow((Pow(3, n) - 3 + mod) % mod, n) + mod) % mod;
fac[0] = inv[0] = bas[0] = 1;
per (i, 1, n) {
fac[i] = fac[i - 1] * i % mod;
bas[i] = 3 * bas[i - 1] % mod;
}

inv[n] = Pow(fac[n], mod - 2);

rep (i, 2, n)
inv[i - 1] = inv[i] * i % mod;

per (i, 1, n) {
if (i & 1)
ans = (ans + C(n, i) * (3 * Pow(bas[n - i] - 1, n) % mod + (bas[i] - 3) * Pow(3, n * (n - i)) % mod) % mod) % mod;
else
ans = (ans - C(n, i) * (3 * Pow(bas[n - i] - 1, n) % mod + (bas[i] - 3) * Pow(3, n * (n - i)) % mod) % mod + mod) % mod;
}
cout << ans << endl;
}