#include <bits/stdc++.h>

// Loop helpers. Beware the naming: per(i, l, r) counts UP from l to r
// (inclusive), rep(i, l, r) counts DOWN from r to l (inclusive).
// The loop variable is declared as `ll`, so `ll` must be visible at the
// point where either macro is expanded.
#define per(i, l, r) for (ll i = l; i <= r; ++i)
#define rep(i, l, r) for (ll i = r; i >= l; --i)

using namespace std;
using ll = long long;

// Prime modulus for all arithmetic; main() relies on primality when it
// computes a modular inverse as a^(mod - 2).
constexpr ll mod = 998244353;

// Capacity of the fac/inv/bas tables (1e6 + 10 entries).
constexpr ll maxn = 1000010;

// File names used by file() when redirecting stdin/stdout.
constexpr char infile[] = ".in";
constexpr char outfile[] = ".out";
/// Reads the next decimal integer from stdin.
///
/// Every non-digit character before the number is skipped; a '-' anywhere in
/// that skipped run makes the result negative. The character that ends the
/// digit run is consumed. If EOF is reached before any digit, returns 0
/// (the previous version stored getchar() in a `char` and never tested for
/// EOF, so it spun forever in that case).
inline long long read() {
    long long value = 0;
    bool negative = false;
    int ch = getchar();  // int, not char: must be able to represent EOF
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') negative = true;
        ch = getchar();
    }
    for (; ch >= '0' && ch <= '9'; ch = getchar())
        value = value * 10 + (ch - '0');
    return negative ? -value : value;
}
/// Redirects stdin to read from `infile` and stdout to write to `outfile`.
/// main() does not call this as written; presumably it is switched on by hand
/// when file I/O is wanted. freopen's return values are not inspected.
inline void file() {
    freopen(infile, "r", stdin);    // input first ...
    freopen(outfile, "w", stdout);  // ... then output
}
/// Computes x^y modulo `mod` by binary (square-and-multiply) exponentiation.
///
/// @param x base; any value — it is reduced into [0, mod) first
/// @param y exponent; must be >= 0 (a negative y now yields 1 instead of
///          looping forever on the arithmetic right shift)
/// @return x^y mod `mod`, always in [0, mod)
inline ll Pow(ll x, ll y) {
    // Reduce the base up front: an unreduced base could overflow x * x, and a
    // negative base would otherwise leak a negative residue into the result.
    x %= mod;
    if (x < 0) x += mod;
    ll res = 1;
    for (; y > 0; y >>= 1) {
        if (y & 1) res = res * x % mod;
        x = x * x % mod;
    }
    return res;
}
// Accumulated answer, kept reduced modulo `mod` by main().
ll ans;
// Tables filled by main() for indices 0..n, all modulo `mod`:
//   fac[i] = i!,  inv[i] = (i!)^{-1},  bas[i] = 3^i.
ll fac[maxn];
ll inv[maxn];
ll bas[maxn];
/// Binomial coefficient C(x, y) modulo `mod`, read from the fac/inv tables.
///
/// Requires fac[] and inv[] to be filled at least up to index x.
/// Returns 0 when y < 0 or y > x, instead of indexing inv[] out of bounds.
inline ll C(ll x, ll y) {
    if (y < 0 || y > x) return 0;
    return fac[x] * inv[y] % mod * inv[x - y] % mod;
}
int main() { ll n = read(); ans = (Pow(3, n * n) - Pow((Pow(3, n) - 3 + mod) % mod, n) + mod) % mod; fac[0] = inv[0] = bas[0] = 1; per (i, 1, n) { fac[i] = fac[i - 1] * i % mod; bas[i] = 3 * bas[i - 1] % mod; } inv[n] = Pow(fac[n], mod - 2);
rep (i, 2, n) inv[i - 1] = inv[i] * i % mod;
per (i, 1, n) { if (i & 1) ans = (ans + C(n, i) * (3 * Pow(bas[n - i] - 1, n) % mod + (bas[i] - 3) * Pow(3, n * (n - i)) % mod) % mod) % mod; else ans = (ans - C(n, i) * (3 * Pow(bas[n - i] - 1, n) % mod + (bas[i] - 3) * Pow(3, n * (n - i)) % mod) % mod + mod) % mod; } cout << ans << endl; }
|