문제 요약
크기가 $N$인 트리가 주어진다. 이러한 트리에서 간선들에 방향을 부여하려고 한다.
이 때, $(u, v)$ 쌍이 $M$개 주어지는데 이는 간선들에 방향을 부여했을 때 $u \rightarrow v$인 경로가 존재하거나 $u \leftarrow v$인 경로가 존재해야 한다는 뜻이다.
문제에서 원하는 것은 $M$개의 조건을 만족하면서 간선에 방향을 부여하는 방법의 수를 구하는 것이다.
$ 1 \le N, M \le 3 \cdot 10^5 $
풀이
$(u, v)$가 주어지면 만족해야 하는 조건은 세가지가 생기는 것으로 볼 수 있다. $LCA(u, v)=l$이라고 하자.
- $u \rightarrow l$의 간선들은 방향이 같아야 한다.
- $l \rightarrow v$의 간선들은 방향이 같아야 한다.
- $ u \rightarrow l$의 간선들과 $ l \rightarrow v $의 간선들의 방향은 달라야 한다.
이 조건들을 전부 간선별로 저장하면 같아야 되는 간선들의 집합, 서로 달라야 되는 간선들의 집합이 나오고 상관 없는 집합이 나올 것이다. 2-SAT이다.
그런데 2-SAT을 돌리고 싶어도 위 조건들에 따라서 모든 간선을 분류하는 데에는 $O(NM)$이 걸린다. 조금 다르게 생각하자.
위 간선들의 집합들을 생각해보면 해당 집합의 간선들 중 어느 하나라도 방향이 결정되면 그 집합에 속하는 간선들은 전부 결정된다. 따라서, 하나의 집합당 2가지 경우가 있다는 것이다. 그러면 집합의 갯수 $cnt$만 구하면 답은 $2^{cnt}$가 된다.
집합의 갯수만 구하면 된다는 것을 알았다. 그러나 여전히 나이브하게 구하면 $O(NM)$이 걸린다. 조건을 잘 살펴보면 방향이 같아야 되는 간선들의 집합과 달라야하는 집합들로 나뉜다.
그러면 각 간선을 두 개로 나누자. 하나는 그 간선이 부모에서 자식으로 내려가는 방향을 뜻하는 것이고, 다른 하나는 자식에서 부모로 올라가는 방향을 가진다는 뜻이다. 그러면 $(u, v)$가 들어왔을 때 $u$에서 $l$로 올라가면서 같은 방향을 뜻하는 간선들은 서로 묶어주고 또 $v$에서 $l$로 올라가면서 같은 방향을 뜻하는 간선들을 서로 묶어주자.
그리고 $u \rightarrow l$과 $l \rightarrow v$의 간선들이 서로 다른 방향을 가져야 한다는 것은 $u$의 정방향 간선과 $v$의 역방향 간선을 묶어주고, $u$의 역방향 간선과 $v$의 정방향 간선을 묶어주면 된다. 이렇게 하고 나서 모든 간선에 대해서 서로 다른 방향의 간선이 같이 묶여 있는지 확인해주면 된다. 만약 그런 간선이 있다면 조건을 만족하는게 불가능한 상황이다.
그러한 간선이 없다면 이제 집합의 갯수를 세주면 된다.
여기서 묶어주는 것은 Union Find로 수행하면 빠르게 가능하다. 따라서 총 시간복잡도는 $O((N + M)\log N)$이 된다.
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
const ll mod = 1e9+7;
int parents[600005];
int sp_table[20][300005];
int depth[300005];
int N, M;
vector<vector<int>> G(300005);
int Find(int a) {
if(parents[a] < 0) return a;
return parents[a] = Find(parents[a]);
}
void Union(int a, int b) {
int pa = Find(a);
int pb = Find(b);
if(pa == pb) return ;
parents[pa] = pb;
}
void dfs(int cur, int par, int d) {
depth[cur] = d;
for(int nxt : G[cur]) {
if(nxt == par) continue;
dfs(nxt, cur, d+1);
sp_table[0][nxt] = cur;
}
}
int LCA(int u, int v) {
if(depth[u] > depth[v]) swap(u, v);
int diff = depth[v] - depth[u];
int idx = 0;
while(diff) {
if(diff & 1) v = sp_table[idx][v];
diff >>= 1; idx++;
}
if(u == v) return u;
for(int i=19;i>=0;--i) {
if(sp_table[i][u] != sp_table[i][v]) {
u = sp_table[i][u];
v = sp_table[i][v];
}
}
return sp_table[0][u];
}
int main() {
cin.tie(nullptr); ios::sync_with_stdio(false);
cin >> N >> M;
for(int i=1;i<N;++i) {
int u, v; cin >> u >> v;
G[u].push_back(v);
G[v].push_back(u);
}
dfs(1, 0, 0);
for(int i=1;i<20;++i) for(int j=1;j<=N;++j) sp_table[i][j] = sp_table[i-1][sp_table[i-1][j]];
memset(parents, -1, sizeof(parents));
vector<pair<int, int>> after;
for(int i=0;i<M;++i) {
int u, v; cin >> u >> v;
int lca = LCA(u, v);
int cur = u;
while(depth[sp_table[0][cur]] > depth[lca]) {
int p = sp_table[0][cur];
Union(cur, p);
Union(cur+N, p+N);
cur = Find(p);
}
cur = v;
while(depth[sp_table[0][cur]] > depth[lca]) {
int p = sp_table[0][cur];
Union(cur, p);
Union(cur+N, p+N);
cur = Find(p);
}
if(u != lca && v != lca) {
after.emplace_back(u, v);
}
}
for(pair<int, int> p : after) {
int u = p.first;
int v = p.second;
Union(u+N, v);
Union(v+N, u);
}
for(int i=1;i<=N;++i) {
if(Find(i) == Find(i+N)) {
cout << 0 << '\n';
return 0;
}
}
int cnt = -2;
for(int i=1;i<=2*N;++i) cnt += parents[i] < 0;
int ans = 1;
for(int i=0;i<cnt/2;++i) ans = ans * 2 % mod;
cout << ans << '\n';
return 0;
}
'Problem Solving > 문제풀이' 카테고리의 다른 글
백준 15339번 Counting Cycles (0) | 2021.05.14 |
---|---|
백준 18214번 Reordering the Documents (0) | 2021.05.10 |
백준 20349번 Xortest Path (0) | 2021.04.22 |
백준 20226번 Luggage (0) | 2021.04.20 |
백준 18216번 Ambiguous Encoding (0) | 2021.04.19 |