본문 바로가기

Problem Solving

백준 11618번 Frightful Formula

반응형

문제 요약

F라는 N by N 2차원 배열에서 F[N][N]을 구해야 한다. 배열의 1열과 1행은 전체가 주어진다. F[i][j]는 아래와 같이 정의된다.
$$
F[i][j]=aF[i][j-1]+bF[i-1][j]+c
$$

풀이

일일이 다 구하는 것은 시간초과가 나니까 논외다. 주어진 점화식에서 상수항이 좀 귀찮으니 상수항을 뗀 상태로 생각해보자.

1행과 1열이 다 주어지는데 1행과 1열의 각 원소가 상수항을 떼냈을 때 F[N][N]에 미치는 영향이 어느 정도인지를 한 번 보자.

0 1 0 0 0
0 b ab a2b a3b
0 b2 2ab2 3a2b2 4a3b2
0 b3 3ab3 6a2b3 10a3b3
0 b4 4ab4 10a2b4 20a3b4

(1,2)에 1이 있을 때 그 1이 각 셀에 끼치는 영향을 나타낸 것이다.

우리가 관심있는 것은 각 셀에 대한 영향보다는 F[N][N]에 대한 영향이다. 위에서 한 일을 몇 번 하다보면 $(1,k)$에 있는 숫자가 $x$라고 했을 때 $(N,N)$에 끼치는 영향을 아래와 같은 수식으로 정리가 가능하다.
$$
\binom{2N-k-2}{N-k}a^{N-k}b^{N-1}F(1,k)
$$

그리고 똑같은 과정을 $F(k,1)$에 있는 숫자들로 반복하면 $F(k,1)$이 $F(N,N)$에 미치는 영향도 수식으로 정리할 수 있다.
$$
\binom{2N-k-2}{N-k}a^{N-1}B^{N-k}F(k,1)
$$

따라서, 점화식에서 상수항을 제거 했을 때 $F(N,N)$은 아래와 같이 표현할 수 있다.

$$
F(N,N) = \sum_{k=2}^N{\binom{2N-k-2}{N-k}a^{N-k}b^{N-1}F(1,k)+\binom{2N-k-2}{N-k}a^{N-1}B^{N-k}F(k,1)}
$$

이제 상수항 $c$를 고려해보자. $F(i,j)$를 구할 때 $c$가 더해지는 것은 $F(i,j)$에 원래부터 $c$가 위치해 있었고 상수항을 없앤 점화식을 적용하는 것과 같은 과정이라고 볼 수 있다.

0 0 0 0 0
0 c ac a2c a3c
0 bc 2abc 3a2bc 4a3bc
0 b2c 3ab2c 6a2b2c 10a3b2c
0 b3c 4ab3 10a2b3c 20a3b3c

이 과정도 몇 번 반복하다보면 $F(i,j)$에 있는 $c$가 $F(N,N)$에 얼마나 영향을 끼치는지를 수식으로 표현이 가능하다.
$$
\binom{2N-i-j}{N-i}a^{N-j}b^{N-i}c
$$
결과적으로 $F(N,N)$은 아래와 같이 표현할 수 있다.
$$
\small{F(N,N) = \sum_{k=2}^N{\binom{2N-k-2}{N-k}a^{N-k}b^{N-1}F(1,k)+\binom{2N-k-2}{N-k}a^{N-1}B^{N-k}F(k,1)} + \sum_{i=2}^{N}\sum_{j=2}^{N}\binom{2N-i-j}{N-i}a^{N-j}b^{N-i}c}
$$

앞의 항은 팩토리얼과 그 역원, 그리고 a,b의 거듭제곱들을 구하는 것으로 충분하므로 $O(N)$에 구하는 것이 가능하다. 뒤 항이 나이브하게 구하면 O(N^2^)이 되버리는데 식을 조금 변형하면 Convolution 형태로 바뀐다.
$$
\begin{aligned}
c\sum_{i=2}^{N}\sum_{j=2}^{N}\binom{2N-i-j}{N-i}a^{N-j}b^{N-i} &= c\sum_{i=2}^{N}\sum_{j=2}^{N}\frac{(2N-i-j)!}{(N-i)!(N-j)!}a^{N-j}b^{N-i} \\
&=c\sum_{i=2}^{N}\sum_{j=2}^{N}(2N-i-j)!\frac{a^{N-j}}{(N-j)!}\frac{b^{N-i}}{(N-i)!}
\end{aligned}
$$
$F_i = (2N-i)!$, $g_i = \frac{a^{N-i}}{(N-i)!}$ , $h_i=\frac{b^{N-i}}{(N-i)!}$, $G=g*h$로 $g$와 $h$의 convolution이라고 설정하자. 그러면 위의 식이 아래로 바뀐다.
$$
c\sum_{i=4}^{2N}F_iG_i
$$

$g, h$를 만드는 데에 $O(N)$만큼 걸리고 $G$를 구하는 데에 $O(NlogN)$이 걸린다. 이걸로 문제를 해결이 가능하다.

여담이지만 문제 이름을 참 문제에 맞게 잘지었다.

#include<bits/stdc++.h>
using namespace std;
using ll = long long;
using cdbl = complex<double>;
const int MAXN = 200000 * 2 + 5;
const ll mod = 1000003;
ll A,B,C,N;
ll fact[MAXN], inv_fact[MAXN];
ll L[200005], T[200005];
ll ans;

ll modpow(ll a, ll n) {
    ll ret = 1;
    while(n) {
        if(n&1) ret = (ret * a) % mod;
        n >>= 1; a = (a * a) % mod;
    }
    return ret;
}

void pre() {
    fact[0] = 1; inv_fact[0] = 1;
    for(ll i=1;i<MAXN;++i) {
        fact[i] = (fact[i-1] * i) % mod;
        inv_fact[i] = modpow(fact[i], mod-2);
    }
}

ll bino(int n, int r) {
    return (fact[n] * ((inv_fact[r] * inv_fact[n-r]) % mod)) % mod;
}

void fft(vector<cdbl> &a, bool inv) {
    int n = a.size();
        vector<cdbl> w(n/2), aux(n);
        for(int i=0; i<n/2; i++){
            int k = i&-i;
            if(i == k){
                double ang = 2 * M_PI * i / n;
                if(inv) ang *= -1;
                w[i] = cdbl(cos(ang), sin(ang));
            }
            else w[i] = w[i-k] * w[k];
        }
        for(int i=n/2; i; i>>=1){
            aux = a;
            for(int k=0; 2*k<n; k+=i){
                for(int j=0; j<i; j++){
                    cdbl u = aux[2*k + j], v = aux[2*k + j + i] * w[k];
                    a[k + j] = u + v;
                    a[k + j + n/2] = u - v;
                }
            }
        }
        if(inv){
            for(int i=0; i<n; i++){
                a[i] /= n;
            }
        }
}

vector<ll> multiply(vector<ll> &f, vector<ll> &g, ll mod) {
    int n = 1;
    while(n < max(f.size(), g.size())) n <<= 1;
    n <<= 1;
    int shift = 15, mask = (1 << shift) - 1;
    vector<cdbl> F(n,cdbl(0,0)), G(n,cdbl(0,0));
    for(int i=0;i<f.size();++i) F[i] = cdbl(f[i] >> shift, f[i] & mask);
    for(int i=0;i<g.size();++i) G[i] = cdbl(g[i] >> shift, g[i] & mask);
    fft(F, false); fft(G, false);
    vector<cdbl> f1g1_f1g2(n), f2g1_f2g2(n);
    for(int i=0;i<n;++i) {
        cdbl f1 = (F[i] + conj(F[(n-i)%n])) * cdbl(0.5, 0);
        cdbl f2 = (F[i] - conj(F[(n-i)%n])) * cdbl(0, -0.5);
        cdbl g1 = (G[i] + conj(G[(n-i)%n])) * cdbl(0.5, 0);
        cdbl g2 = (G[i] - conj(G[(n-i)%n])) * cdbl(0, -0.5);
        f1g1_f1g2[i] = f1*g1 + f1*g2*cdbl(0,1);
        f2g1_f2g2[i] = f2*g1 + f2*g2*cdbl(0,1);
    }
    fft(f1g1_f1g2, true); fft(f2g1_f2g2, true);
    vector<ll> ret(n);
    for(int i=0;i<n;++i) {
        ll f1g1 = (ll)round(f1g1_f1g2[i].real());
        ll f1g2 = (ll)round(f1g1_f1g2[i].imag());
        ll f2g1 = (ll)round(f2g1_f2g2[i].real());
        ll f2g2 = (ll)round(f2g1_f2g2[i].imag());
        if(mod) {
            f1g1 %= mod; f1g2 %= mod; f2g1 %= mod; f2g2 %= mod;
        }
        ret[i] = (f1g1 << (2*shift)) + ((f1g2 + f2g1) << shift) + f2g2;
        if(mod) {
            ret[i] %= mod;
            ret[i] = (ret[i] + mod) % mod;
        }
    }
    return ret;
}

int main() {
    cin.tie(nullptr); ios::sync_with_stdio(false);
    cin >> N >> A >> B >> C;
    for(int i=1;i<=N;++i) cin >> L[i];
    for(int i=1;i<=N;++i) cin >> T[i];
    pre();
    for(int i=2;i<=N;++i) { // row
        ll tmp = (bino(N-1 + N-1-i, N-i) * ((modpow(A,N-i) * modpow(B,N-1)) % mod)) % mod;
        ans = (ans + (tmp * T[i]) % mod) % mod;
    }
    for(int i=2;i<=N;++i) { // col
        ll tmp = (bino(N-1 + N-1-i, N-i) * ((modpow(A,N-1) * modpow(B,N-i)) % mod)) % mod;
        ans = (ans + (tmp * L[i]) % mod) % mod;
    }
    vector<ll> A_inv(N-1), B_inv(N-1);
    for(int i=0;i<N-1;++i) {
        A_inv[i] = (modpow(A,i) * inv_fact[i]) % mod;
        B_inv[i] = (modpow(B,i) * inv_fact[i]) % mod;
    }
    vector<ll> res = multiply(A_inv, B_inv, mod);
    for(int i=0;i<2*N-3;++i) {
        ll tmp = (C*((fact[i] * res[i])%mod)) % mod;
        ans = (ans + tmp) % mod;
    }
    cout << ans << '\n';
    return 0;
}
반응형