문제 요약
길이 n인 배열이 주어진다. 여기서 최대 연속합 k개를 구해야 한다.
n은 최대 25만이고 k는 최대 1만이다.
풀이
여담으로 대회 중에 상수 큰 로그 두개 붙은 풀이 떠올렸다가 짜보지도 못하고 박살났다. ㅜㅜ
배열 a의 누적 합 배열을 $p$라고 하자. [i...j]의 연속합은 $p[j] - p[i-1]$이다.
오른쪽 끝이 고정 됐을 때 가장 큰 연속합은 $p[0...j-1]$중에서 가장 작은 값을 $p[j]$에서 빼준 값이 된다. 그러면 오른쪽 끝이 j인 연속합에서 두번째로 큰 값을 $p[0...j-1]$에서 두 번째로 작은 값을 빼준 값이 된다.
이 점을 고려해서 문제를 풀 수 있다. 먼저 모든 $i$번째 원소에 대해서 $i$번째 원소에서 끝나는 가장 큰 연속합을 뽑느다. 그리고 이것들을 관리하는 우선순위 큐를 만들자.
정확히는 {연속합의 값, 끝나는 위치, 해당 위치에서 끝나는 연속합 중 몇번째로 큰가} 이거를 관리하는 우선순위 큐가 된다.
그러면 제일 큰 연속합을 가지는 큐에서 꺼냈다고 하자. 이 연속합이 해당위치에서 끝나는 연속합 중에서 $i$번째 였다고 하자. 그러면 $p[0...j-1]$에서 $i+1$번째로 큰 원소를 찾은 다음에 $p[j]$에서 그 원소를 빼준 값을 다시 우선순위 큐에 넣어준다.
만약에 $i$가 $j$였다면 더 추가할 연속합이 없는 것이므로 추가하지 않고 끝난다. 위 과정을 K번 반복하면 최대 연속합 K개를 구할 수 있다.
누적합 배열 $p[0...j-1]$에서 $i$번째로 큰 원소를 찾아주는 것은 PST로 가능하다. 모르겠다면 이 링크를 참고하면 좋을 것 같다.
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
int N, M;
ll arr[250005], pre[250005];
int roots[250005];
int vsz, tsz;
struct Node {
int l, r, val;
Node() : l(0), r(0), val(0) {};
Node(int _l, int _r, int _val) : l(_l), r(_r), val(_val) {};
};
vector<Node> tree;
void init() {
int sz = tsz >> 1;
for (int i = 1; i < sz; ++i) {
tree[i].l = i << 1; tree[i].r = i << 1 | 1;
}
}
void update(int node, int s, int e, int val, int idx) {
tree[node].val += val;
if (s != e) {
int mid = (s + e) >> 1;
int n1 = tree[node].l, n2 = tree[node].r;
if (idx <= mid) {
tree[node].l = tree.size();
tree.push_back(tree[n1]);
update(tree[node].l, s, mid, val, idx);
}
else {
tree[node].r = tree.size();
tree.push_back(tree[n2]);
update(tree[node].r, mid + 1, e, val, idx);
}
}
}
int getKth(int s, int e, int k) {
int l = 1, r = vsz;
while (l != r) {
int mid = (l + r) >> 1;
int lsz = tree[tree[e].l].val - tree[tree[s].l].val;
if (lsz >= k) {
s = tree[s].l; e = tree[e].l;
r = mid;
}
else {
s = tree[s].r; e = tree[e].r;
k -= lsz;
l = mid + 1;
}
}
assert(l == r);
return l;
}
struct cmp {
bool operator() (pair<pair<ll, int>, int>& a, pair<pair<ll, int>, int>& b) {
if (a.first == b.first) return a.second < b.second;
if (a.first.first == b.first.first) return a.first.second > b.first.second;
return a.first.first < b.first.first;
}
};
int main() {
cin.tie(nullptr); ios::sync_with_stdio(false);
cin >> N >> M;
vector<ll> v;
for (int i = 1; i <= N; ++i) {
cin >> arr[i]; pre[i] = pre[i - 1] + arr[i]; v.push_back(pre[i]);
}
v.push_back(0);
sort(v.begin(), v.end());
v.erase(unique(v.begin(), v.end()), v.end());
vsz = v.size(); tsz = 1; while (tsz < vsz) tsz <<= 1; tsz <<= 1;
tree.resize(tsz);
init();
roots[0] = 1;
for (int i = 0; i <= N; ++i) {
int idx = lower_bound(v.begin(), v.end(), pre[i]) - v.begin() + 1;
roots[i + 1] = tree.size(); tree.push_back(tree[roots[i]]);
update(roots[i + 1], 1, vsz, 1, idx);
}
priority_queue<pair<pair<ll, int>, int>, vector<pair<pair<ll, int>, int>>, cmp> pq;
for (int i = 1; i <= N; ++i) {
int idx = getKth(roots[0], roots[i], 1) - 1;
pq.push(make_pair(make_pair(pre[i] - v[idx], 1), i));
}
for (int i = 0; i < M; ++i) {
auto p = pq.top();
pq.pop();
int id = p.second;
ll val = p.first.first;
int k = p.first.second;
cout << val << " \n"[i == M - 1];
if (k >= id) continue;
int idx = getKth(roots[0], roots[id], k + 1) - 1;
pq.push(make_pair(make_pair(pre[id] - v[idx], k + 1), id));
}
return 0;
}
'Problem Solving > 문제풀이' 카테고리의 다른 글
백준 11932번 트리와 K번째 수 (0) | 2021.02.17 |
---|---|
백준 13513번 트리와 쿼리 4 (0) | 2021.02.17 |
백준 13538번 XOR 쿼리 (0) | 2021.02.16 |
백준 11012번 Egg (0) | 2021.02.16 |
백준 7469번 K번째 수 (0) | 2021.02.16 |