문제 요약
$m \times n$ 행렬 $T$와 $m \times l$행렬 $P$, 그리고 숫자 $W$가 주어집니다.
$W_k=\sum_{i=1}^{m}\sum_{j=1}^l T(i,j+k)P(i,j)$라고 할 때, $W_k > W$인 횟수를 출력하는 문제입니다.
풀이
문제의 그림처럼 $m \times n$ 행렬에서 $m \times l$ 행렬을 슬라이딩 시키면서 그 위치에서 두 행렬의 pointwise multiplication sum을 구합니다. 그 값이 $W$를 넘는 횟수를 구하는 문제입니다.
나이브하게 구하면 $O(nml)$로 시간초과를 받습니다.
이 문제는 FFT를 통해 Convolution을 수행함으로 빠르게 가능합니다.
Convolution 연산은 하나의 수열은 가만히 놔두고 다른 수열이 뒤집어진 채로 슬라이딩 하면서 pointwise multiplication sum을 구하는 형태를 하고 있습니다. 이는 이 문제에서 원하는 것과 동일합니다
수열이 뒤집힌 채로 슬라이딩을 진행하기 때문에 행렬 P를 좌우로 뒤집은 행렬을 P'라고 합시다.
그러면 T와 P'의 행 별로 Convolution을 취하면 각 행에서 $W_k$를 계산하는 데에 사용되는 pointwise multiplication sum을 구할 수 있습니다. 이에 걸리는 시간 복잡도는 $O(m(n+l)log(n+l))$입니다.
그리고 $W_k$를 하나 구하는 데에는 각 행의 원소를 더해주는 데에 $O(m)$이 걸리고 $W_k$는 $O(n)$개 만큼 구해야되기 때문에 $O(nm)$이 걸립니다. 이걸로 문제를 해결할 수 있습니다.
#include<iostream>
#include<vector>
#include<complex>
#include<cmath>
#include<algorithm>
using namespace std;
using ll = long long;
typedef complex<double> cdbl;
int N, M, L;
ll W;
const double PI = acos(-1);
vector<vector<ll>> S(105), T(105), conv(105);
void fft(vector<cdbl> &a, bool inv) {
int n = a.size();
// bit reversal
for (int i = 1, j = 0; i<n; ++i) {
int bit = n >> 1;
while (!((j ^= bit) & bit)) bit >>= 1;
if (i<j) swap(a[i], a[j]);
}
for (int i = 1; i<n; i <<= 1) {
double x = inv ? PI / i : -PI / i;
cdbl w = cdbl(cos(x), sin(x));
for (int j = 0; j<n; j += i << 1) {
cdbl p = cdbl(1, 0);
for (int k = 0; k<i; ++k) {
cdbl tmp = a[i + j + k] * p;
a[i + j + k] = a[j + k] - tmp;
a[j + k] += tmp;
p *= w;
}
}
}
if (inv) {
for (int i = 0; i<n; ++i) a[i] /= n;
}
}
// h = fg
vector<ll> multiply(vector<ll> &f, vector<ll> &g) {
vector<cdbl> pf(f.begin(), f.end()), pg(g.begin(), g.end());
int n = 1; while (n < f.size() + g.size()) n <<= 1;
pf.resize(n); pg.resize(n);
fft(pf, false); fft(pg, false);
for (int i = 0; i < n; ++i) pf[i] *= pg[i];
fft(pf, true);
vector<ll> ret(n);
for (int i = 0; i < n; ++i) {
ret[i] = (ll)round(pf[i].real());
}
return ret;
}
int main() {
int N, M;
cin >> N >> L >> M >> W;
for (int i = 0; i < M; ++i) {
for (int j = 0; j < N; ++j) {
int x; cin >> x;
S[i].push_back(x);
}
}
for (int i = 0; i < M; ++i) {
for (int j = 0; j < L; ++j) {
int x; cin >> x;
T[i].push_back(x);
}
reverse(T[i].begin(), T[i].end());
}
for (int i = 0; i < M; ++i) {
conv[i] = multiply(S[i], T[i]);
}
int ans = 0;
for (int i = L - 1; i < N; ++i) {
ll sum = 0;
for (int j = 0; j < M; ++j) {
sum += conv[j][i];
}
if (sum > W) ++ans;
}
cout << ans << '\n';
return 0;
}
'Problem Solving > 문제풀이' 카테고리의 다른 글
백준 20176번 Needle (0) | 2021.02.15 |
---|---|
백준 17134번 르모앙의 추측 (0) | 2021.02.15 |
백준 10793번 Tile Cutting (2) | 2021.02.15 |
백준 10531번 Golf Bot (0) | 2021.02.04 |
백준 17104번 골드바흐 파티션 2 (0) | 2021.02.04 |