문제 요약
정점이 n개, 간선이 r개인 유향 그래프가 주어진다. n개의 정점 중에서 1번부터 b번 노드까지를 지사로 정한다. 이 b개의 지사를 s개의 그룹으로 나누고자 한다.
b+1번 노드를 본부라고 칭하자.
이제 각 지사는 본인이 속한 그룹에 메시지를 보내게 되는데 메시지를 보내는 방식은 다음과 같다.
$i$번 지사로부터 $j$번 지사로 메시지를 보낼 때 $i$번 지사로부터 본부까지 보낸 다음에 본부로부터 $j$번 지사로 메시지를 보내게 된다.
이런 식으로 메시지를 보낸다고 할 때, 각 지사를 s개의 그룹으로 잘 나눠서 보내지는 모든 메시지의 이동거리의 최소값을 구하는 것이 목표이다.
N은 최대 5000, r은 50000, b는 최대 N-1, s는 최대 b만큼 커진다.
풀이
$i$번 지사에서 본부까지의 최단 거리를 $dist_i$, 본부에서 $i$번 지사까지의 최단 거리를 $revdist_i$라고 하자. 그리고 이 둘을 합친 값을 $d_i$라고 하자.
만약에 $k$개의 지사가 묶여 있는 그룹에 $i$번지사를 추가로 넣는다고 하자. 그렇게 되면 $i$번 지사에서 본부로 가는 메시지가 $k-1$개, 본부에서 $i$번 지사로 오게되는 메시지가 $k-1$개 새로 생기게 된다. 즉, 그룹에 들어감으로 늘어나는 거리가 $(k-1)dist_i + (k-1)revdist_i$고 이는 $(k-1)d_i$이다.
이를 고려하면 $i$번 지사가 주는 영향은 본인이 들어가 있는 그룹의 크기와 관련이 큽니다. 따라서, $d_i$가 작은 값이라면 $i$번 지사는 가능한 큰 그룹에 넣는 것이 좋고 $d_i$가 큰 값이라면 $i$번 지사는 가능한 작은 그룹에 넣는 것이 좋습니다.
따라서, $d_i$의 오름차순으로 지사들을 정렬한 뒤에 연속저인 구간으로 그룹을 묶어주는 방법을 생각할 수 있습니다.
이제 이렇게 정렬한 순서대로 그룹을 묶어줄 때 다음과 같은 dp식을 세울 수 있습니다.
$$
\begin{gather}
dp(i,j) = \text{1부터 j번 지사까지 i개 그룹으로 묶었을 때의 최솟값} \\
dp(i,j) = \underset{k<j}{min}(dp(i-1, k)+C(k+1, j)) \\
C(i,j) = (j-i-1)(d_i+d_{i+1}+\cdots+d_j)
\end{gather}
$$
나이브하게 구하면 $O(sb^2)$으로 시간초과입니다.
$dp(i,j)$를 최적으로 만들어주는 $k$를 $opt_j$라고 합시다. $j < j'$일 때, $opt_j \le opt_{j'}$임을 알 수 있습니다.
이는 $d_i$가 오름차순으로 정렬되어 있기 때문에 $dp(i,j)$를 계산할 때, $d_j$는 $1...j$에서 $d_i$값이 제일 큽니다. $opt_j > opt_{j'}$라는 말은 제일 큰 $d$값을 더 큰 그룹에 넣었다는 뜻이 됩니다. 그런데도 비용이 작아졌다는 말이죠.
직관적인 설명이지만 수식에 직접 값을 넣어봐도 성립할 것입니다.
따라서, 분할정복 최적화를 적용할 수 있으며 $O(sblogb)$로 문제를 해결할 수 있습니다.
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
const ll INF = 1e18;
vector<vector<pair<int, ll>>> G(5005), rev_G(5005);
vector<ll> D;
ll dp[5005][5005];
ll pre[5005];
int N, B, S, M;
void dijkstra(int src, vector<vector<pair<int,ll>>> &G, vector<ll> &dist) {
priority_queue<pair<ll,int>, vector<pair<ll,int>>, greater<pair<ll,int>>> pq;
dist[src] = 0;
pq.push(make_pair(dist[src], src));
while(!pq.empty()) {
int cur = pq.top().second;
ll d = pq.top().first;
pq.pop();
if(d > dist[cur]) continue;
for(auto nxt:G[cur]) {
int v = nxt.first;
ll w = nxt.second;
if(dist[v] > d + w) {
dist[v] = d + w;
pq.push(make_pair(dist[v], v));
}
}
}
}
void solve(int lev, int s, int e, int optl, int optr) {
if(s > e) return ;
int mid = (s+e) >> 1;
ll &ans = dp[lev][mid];
ans = dp[lev-1][optl] + (mid - optl - 1) * (pre[mid] - pre[optl]);
int opt = optl;
for(int i=optl;i<min(mid, optr);++i) {
ll val = dp[lev-1][i] + (mid - i - 1) * (pre[mid] - pre[i]);
if(ans > val) {
ans = val; opt = i;
}
}
solve(lev, s, mid-1, optl, opt+1);
solve(lev, mid+1, e, opt, optr);
}
int main() {
cin.tie(nullptr); ios::sync_with_stdio(false);
cin >> N >> B >> S >> M;
for(int i=0;i<M;++i) {
int u, v; ll w;
cin >> u >> v >> w;
G[u].push_back({v,w});
rev_G[v].push_back({u,w});
}
vector<ll> dist(N+1), rev_dist(N+1); // B+1 to all, all to B+1
fill(dist.begin(), dist.end(), INF);
fill(rev_dist.begin(), rev_dist.end(), INF);
dijkstra(B+1, G, dist);
dijkstra(B+1, rev_G, rev_dist);
D.resize(B+1, 0);
for(int i=1;i<=B;++i) D[i] = dist[i] + rev_dist[i];
sort(D.begin(), D.end());
for(int i=1;i<=B;++i) pre[i] = pre[i-1] + D[i];
for(int i=1;i<=B;++i) dp[1][i] = (i-1) * pre[i];
for(int i=2;i<=S;++i) solve(i, i, B, i-1, B);
cout << dp[S][B] << '\n';
return 0;
}
'Problem Solving > 문제풀이' 카테고리의 다른 글
백준 20180번 Two Buildings (0) | 2021.02.16 |
---|---|
백준 16138번 수강신청 (2) | 2021.02.16 |
백준 11001번 김치 (0) | 2021.02.16 |
백준 13262번 수열의 OR 점수 (0) | 2021.02.16 |
백준 14636번 Money for Nothing (0) | 2021.02.16 |