본문 바로가기

Problem Solving/문제풀이

백준 14176번 트리와 소수

반응형

데이터추가로 터졌습니다. 

문제 요약

크기 N인 노드가 주어진다. 두 노드를 랜덤으로 골랐을 때 두 노드의 거리가 소수일 확률을 뽑아야 한다.

N은 최대 10만이다.

풀이

모든 경로 중에서 경로의 길이가 소수인 경로의 개수를 찾는 문제다.

 

총 경로의 수는 $\frac{N(N-1)}{2}$이므로 소수인 경로의 개수를 찾고 이것으로 나눠주면 된다.

rooted tree에서 root와의 거리가 d인 정점들의 개수를 구한다고 하자. 거리가 d인 정점들의 개수를 구하는 것은 트리를 순회하는 것만으로도 충분하다.

 

이제 서로 다른 서브트리에 두 노드가 위치해서 root를 지나는 경로를 살펴보려고 하는데 이 경로들을 일일이 전부 확인하려면 $O(N^2)$이 걸린다.

 

그러나, 이는 두 배열에서 수를 하나씩 골라서 합이 특정 값이 되는지를 확인하는 것으로 Convolution으로 해결이 가능하다. 즉, $O(NlogN)$이 가능하고 실제로는 해당 서브트리의 사이즈를 $S$라고 하면 $O(SlogS)$정도가 된다.

 

이제 각 서브트리를 차례대로 순회하면서 해당 서브트리에 대해서 root와의 거리가 d인 정점의 개수를 저장한 배열을 $cur$라 하고 이전에 본 서브트리들에 대해서 위 정보를 기록한 배열을 $prev$라고 하면 $prev$와 $cur$의 Convolution을 FFT를 통해 구하고 길이가 소수인 경로들의 수를 합해주고 이 과정을 반복하면 하나의 rooted tree에 대해서 root를 지나느 경로들을 모두 고려해줄 수 있다.

 

이제 여기에 Centroid Decomposition만 끼얹으면 모든 경로를 고려해줄 수 있게 된다.

시간복잡도는 대략적으로 $O(Nlog^2N)$이 된다. 깊이 배열을 서브트리의 사이즈에 맞춰서 잡아주지 않으면 시간초과를 받을 수도 있다. 조심하자.

#include<bits/stdc++.h>
using namespace std;
using ll = long long;
using cdbl = complex<double>;
const int MAXN = 100005;
int sub_sz[MAXN], parent[MAXN];
vector<vector<int>> G(MAXN);
ll N;
int mx_depth;
vector<int> prev_dep, cur_dep, primes;
bool vis[MAXN], not_prime[MAXN];
ll ans;

void fft(vector<cdbl> &a, bool inv) {
    int n = a.size();
    // bit reversal
    for(int i=1,j=0;i<n;++i) {
        int bit = n>>1;
        while(!((j^=bit) & bit)) bit >>=1;
        if(i<j) swap(a[i],a[j]);
    }
    for(int i=1;i<n;i<<=1) {
        double x = inv? M_PI / i : -M_PI / i;
        cdbl w = cdbl(cos(x),sin(x));
        for(int j=0;j<n;j+=i<<1) {
            cdbl p = cdbl(1,0);
            for(int k=0;k<i;++k) {
                cdbl tmp = a[i+j+k] * p;
                a[i+j+k] = a[j+k] - tmp;
                a[j+k] += tmp;
                p *= w;
            }
        }
    }
    if(inv) {
        for(int i=0;i<n;++i) a[i] /= n;
    }
}

vector<int> multiply(vector<int> &f, vector<int> &g) {
    vector<cdbl> pf(f.begin(), f.end()), pg(g.begin(), g.end());
    int n = 1; while (n < f.size() + g.size()) n <<= 1;
    pf.resize(n); pg.resize(n);
    fft(pf, false); fft(pg, false);
    for (int i = 0; i < n; ++i) pf[i] *= pg[i];
    fft(pf, true);
    vector<int> ret(n);
    for (int i = 0; i < n; ++i) {
        ret[i] = (int)round(pf[i].real());
    }
    return ret;
}

void sieve() {
    not_prime[0] = not_prime[1] = true;
    for(ll i=2;i<MAXN;++i) {
        if(!not_prime[i]) {
            primes.push_back(i);
            for(ll j=i*i;j<MAXN;j+=i) not_prime[j] = true;
        }
    }
}

int get_size(int cur, int par) {
    sub_sz[cur] = true;
    for(int nxt:G[cur]) {
        if(nxt == par || vis[nxt]) continue;
        sub_sz[cur] += get_size(nxt, cur);
    }
    return sub_sz[cur];
}

int get_cent(int cur, int par, int thr) {
    for(int nxt:G[cur]) {
        if(nxt == par || vis[nxt]) continue;
        if(sub_sz[nxt] > thr) return get_cent(nxt, cur, thr);
    }
    return cur;
}

void get_depth(int cur, int par, int d) {
    mx_depth = max(mx_depth, d);
    cur_dep[d]++;
    for(int nxt:G[cur]) {
        if(nxt == par || vis[nxt]) continue;
        get_depth(nxt, cur, d+1);
    }
} 

void solve(int cur) {
    int thr = get_size(cur, -1);
    int cent = get_cent(cur, -1, thr/2);
    get_size(cent, -1);
    vis[cent] = true;
    prev_dep.resize(1); prev_dep[0] = 1;
    for(int nxt:G[cent]) {
        if(vis[nxt]) continue;
        mx_depth = 0; cur_dep.clear(); cur_dep.resize(sub_sz[nxt]+1);
        get_depth(nxt, cent, 1);
        auto conv = multiply(prev_dep, cur_dep);
        for(int p:primes) {
            if(p >= conv.size() || p >= sub_sz[cent]) break;
            ans += conv[p];
        }
        if(prev_dep.size() <= mx_depth) prev_dep.resize(mx_depth+1);
        for(int i=0;i<=mx_depth;++i) prev_dep[i] += cur_dep[i];
    }
    for(int nxt:G[cent]) {
        if(vis[nxt]) continue;
        solve(nxt);
    }
}

int main() {
    cin.tie(nullptr); ios::sync_with_stdio(false);
    sieve();
    cin >> N;
    for(int i=1;i<N;++i) {
        int u, v; cin >> u >> v;
        G[u].push_back(v); G[v].push_back(u); 
    }
    prev_dep.reserve(MAXN); cur_dep.reserve(MAXN);
    solve(1);
    cout << fixed << setprecision(15) << (double)(2.0*ans)/(N*(N-1)) << '\n';
    return 0;
}
반응형

'Problem Solving > 문제풀이' 카테고리의 다른 글

백준 11012번 Egg  (0) 2021.02.16
백준 7469번 K번째 수  (0) 2021.02.16
백준 13431번 트리 문제  (0) 2021.02.16
백준 5820번 경주  (0) 2021.02.16
백준 13514번 트리와 쿼리 5  (0) 2021.02.16