문제 요약
처음에 길이 N인 배열 A가 주어진다. 이 배열에 대해서 다음 쿼리를 처리해야 한다
- p v가 주어지면 $A_p$앞에 v를 추가합니다.
- p가 주어지면 $A_p$를 제거합니다.
- p v가 주어지면 $A_p$를 v로 바꿉니다.
- l r k가 주어지면 $l \le i \le r$에 대해서 $\sum A_i(k-l+1)^k \mod 2^{32}$를 구합니다.
N과 쿼리 수는 둘 다 최대 10만입니다.
풀이
1, 2, 3번 쿼리는 스플레이 트리를 이용하면 쉽게 처리가 가능합니다.
4번 쿼리에 집중해보죠
스플레이 트리로 배열을 다룰 때, 하나의 서브트리는 하나의 연속 구간을 나타냅니다.
k가 0부터 10까지이므로 서브트리의 루트는 해당 구간에 대해서 모든 k에 대한 답을 가지고 있다고 합시다.
그러면 이런 서브트리 두개를 어떻게 잘 합쳐줄지를 생각해봅시다.
[l, r]을 담당하는 서브트리가 있다고 하죠. 해당 서브트리의 루트는 $t$번재 노드라고 하면 왼족 서브트리는[l, t-1]을 나타내고, 오른쪽 서브트리는 [t+1, r]을 나타내는 상황이 됩니다.
왼쪽 서브트리와 오른쪽 서브트리에서는 그 서브트리들이 나타내는 구간에 대한 쿼리의 답을 들고 있다고 합시다. k가 정해졌을 때의 답을 각각 $ans_{l,k}, ans_{r,k}$이라고 부르겠습니다.
k일때, [l, r]에 대한 답 $ans_k$는 $ans_k = ans_{l,k} + a[t] (t-l+1)^k + \sum_{i=t+1}^k a[i] (i-l+1)^k$입니다.
$ans_{l,k}$은 건드릴게 없고 $ans_{r,k}$을 잘 바꿔서 위 꼴을 만들어줘야 합니다.
$ans_{r,k} = \sum_{i=t+1}^r a[i] (i-t)^k$입니다. 여기서 $x=i-t, y=t-l+1$이라고 합시다.
그러면 $ans_k$의 마지막 항이 이렇게 바뀝니다.
$$
\sum_{i=t+1}^r a[i] (i-l+1)^k = \sum_{i=t+1}^r a[i] (x+y)^k
$$
$(x+y)^k = \sum_{j=0}^kx^jy^{k-j}\binom{k}{j}$를 위 식에 집어 넣습니다.
$$
\begin{aligned}
\sum_{i=t+1}^r a[i] (i-l+1)^k &= \sum_{i=t+1}^r a[i] (x+y)^k \\
&=\sum_{i=t+1}^r a[i] ( \sum_{j=0}^k x^j y^{k-j} \binom{k}{j} ) \\
&=\sum_{j=0}^k \binom{k}{j} j^{k-j} (\sum_{i=t+1}^r a[i]) x^j \\
&= \sum_{j=0}^k \binom{k}{j} j^{k-j} ans_{r,j}
\end{aligned}
$$
k는 0부터 10까지 이므로 이항계수만 전처리 해두면 $ans_k$는 $O(k)$로 구하는 것이 가능합니다.
따라서, 노드 하나의 업데이트를 진행하는 것에 $O(k^2)$가 걸립니다.
그러면 스플레이 트리의 기본 시간복잡도인 amortied $O(logN)$에 $k^2$이 곱해진 $O(k^2logN)$이 되고 총 시간복잡도는 $O(Mk^2logN)$이 됩니다. 사실 이게 안 돌아갈 거 같았는데 돌아갑니다.
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
using uint = unsigned int;
uint comb[11][11];
struct Node {
Node *l, *r, *p;
uint sz;
uint val;
bool inv;
// declare extra variables
uint ans[11];
bool dummy;
Node() : l(nullptr), r(nullptr), p(nullptr) {};
Node(uint _val) {
l = r = p = nullptr;
sz = 1;
val = _val;
inv = false;
dummy = false;
// init extra variables
for(int i=0;i<=10;++i) ans[i] = val;
}
} *root;
// pull from children(size, sum etc)
void pull(Node *cur) {
// pulling process
cur->sz = 1;
if(cur->l) cur->sz += cur->l->sz;
if(cur->r) cur->sz += cur->r->sz;
uint lsz = cur->l? cur->l->sz:0;
uint pw[11];
pw[0] = 1;
for(int i=1;i<=10;++i) pw[i] = pw[i-1] * (lsz + 1);
for(int i=0;i<=10;++i) cur->ans[i] = pw[i] * cur->val;
for(int i=0;i<=10;++i) {
if(cur->l) cur->ans[i] += cur->l->ans[i];
if(cur->r) {
for(int j=0;j<=i;++j) cur->ans[i] += cur->r->ans[j] * pw[i-j] * comb[i][j];
}
}
}
void reverse(Node *cur) {
// reverse process
swap(cur->l, cur->r);
}
// push into children(lazy, inv etc)
void push(Node *cur) {
// pushing process
if(cur->inv) {
reverse(cur);
if(cur->l) cur->l->inv = !cur->l->inv;
if(cur->r) cur->r->inv = !cur->r->inv;
cur->inv = false;
}
}
void rotate(Node *cur) {
Node* p = cur->p;
if(!p) return ; // cur is root
push(p);
push(cur);
Node *tmp;
if(p->l == cur) {
tmp = cur->r;
p->l = cur->r;
cur->r = p;
}
else {
tmp = cur->l;
p->r = cur->l;
cur->l = p;
}
Node *pp = p->p;
cur->p = pp;
p->p = cur;
if(tmp) tmp->p = p;
if(cur->p) {
if(pp->l == p) pp->l = cur;
else if(pp->r == p) pp->r = cur;
}
else root = cur;
pull(p);
pull(cur);
}
void splay(Node *cur) {
while(cur->p) {
Node *p = cur->p;
Node *pp = p->p;
if(pp) {
if((pp->l == p) == (p->l == cur)) rotate(p);
else rotate(cur);
}
rotate(cur);
}
}
// split tree first < cur, second > cur
pair<Node*, Node*> split(Node *cur) {
splay(cur);
Node *left = cur;
Node *right = cur->r;
if(left) left->p = nullptr;
if(right) right->p = nullptr;
return {left, right};
}
Node* merge(Node *left, Node *right) {
if(left == nullptr) return right;
if(right == nullptr) return left;
Node *p = left;
while(p->r) p = p->r;
push(p);
splay(p);
p->r = right;
right->p = p;
pull(p);
return p;
}
Node* kth(Node *cur, int k) {
Node *p = cur;
push(p);
while(1) {
while(p->l && p->l->sz > k) {
p = p->l;
push(p);
}
if(p->l) k -= p->l->sz;
if(!k) break;
else --k;
p = p->r;
push(p);
}
splay(p);
return p;
}
// p->r->l represents [l,r]
Node* interval(Node *cur, int l, int r) {
Node *p = kth(cur, l-1);
Node *right = p->r;
right->p = nullptr;
right = kth(right, r-l+1);
right->p = p;
p->r = right;
root = p;
pull(p);
return p;
}
int N, M;
uint a[100005];
void init() {
comb[0][0] = 1;
for(int i=1;i<=10;++i) {
comb[i][0] = comb[i][i] = 1;
for(int j=1;j<i;++j) comb[i][j] = comb[i-1][j-1] + comb[i-1][j];
}
Node *p = root = new Node(0);
p->dummy = true;
for(int i=1;i<=N;++i) {
p->r = new Node(a[i]);
p->r->p = p;
p = p->r;
}
p->r = new Node(0);
p->r->p = p;
p = p->r;
p->dummy = true;
while(p) {
pull(p);
p = p->p;
}
}
Node* query1(Node *cur, int pos, uint val) {
Node *p = kth(root, pos)->r;
while(p->l) p = p->l;
p->l = new Node(val);
p->l->p = p;
splay(p->l);
return root;
}
Node* query2(Node *cur, int pos) {
Node *p = kth(root, pos);
pair<Node*, Node*> pp = split(p);
if(pp.first) push(pp.first);
if(pp.second) push(pp.second);
root = merge(pp.first->l, pp.second);
return root;
}
Node *query3(Node *cur, int pos, int val) {
Node *p = kth(root, pos);
p->val = val;
pull(p);
return root;
}
uint query4(Node *cur, uint l, uint r, uint k) {
Node *p = interval(root, l, r);
return p->r->l->ans[k];
}
void print(Node *cur) {
if(cur->l) print(cur->l);
if(!cur->dummy) cout << cur->val << ' ';
if(cur->r) print(cur->r);
}
int main() {
cin.tie(nullptr); ios::sync_with_stdio(false);
cin >> N;
for(int i=1;i<=N;++i) cin >> a[i];
init();
cin >> M;
for(int i=0;i<M;++i) {
string s;
uint q, p, v, l, r, k;
cin >> q;
if(q == 1) {
cin >> p >> v;
query1(root, p, v);
}
else if(q == 2) {
cin >> p;
query2(root, p+1);
}
else if(q == 3) {
cin >> p >> v;
query3(root, p+1, v);
}
else if(q == 4) {
cin >> l >> r >> k;
++l; ++r;
cout << query4(root, l, r, k) << '\n';
}
// print(root); cout << '\n';
}
return 0;
}
'Problem Solving > 문제풀이' 카테고리의 다른 글
백준 16586번 Linked List (0) | 2021.02.17 |
---|---|
백준 17607번 수열과 쿼리 31 (0) | 2021.02.17 |
백준 3444번 Robotic Sort (0) | 2021.02.17 |
백준 13159번 배열 (0) | 2021.02.17 |
백준 16977번 히스토그램에서 가장 큰 직사각형과 쿼리 (0) | 2021.02.17 |