본문 바로가기

Problem Solving/문제풀이

백준 5051번 피타고라스의 정리

반응형

N이 주어졌을 때, 아래를 만족하는 $(a,b,c)$쌍의 개수를 찾는 것이 목표입니다.
$$
0 \lt a,b,c \lt N, \quad a \le b, \quad a^2 + b^2 \equiv c^2 \pmod N
$$
N이 최대 500,000이기 때문에 나이브하게 구하면 시간초과가 납니다.

위에서 요구하고 있는 식을 잘 보면 결국 두 수의 합이 특정 수가 되는 경우의 수를 원하는 것입니다. 생성함수를 이용해 접근합시다.

다음과 같은 다항식을 생각합시다.
$$
f(x) = \sum_{i=0}^{N-1}a_ix^i, \quad a_i = \vert S_i \vert \quad where \quad S_i=\{x \vert x^2 \equiv i \pmod N, 0 \lt x \lt N\}
$$
이제 이렇게 구한 f(x)를 제곱해주면 두 수의 제곱의 합이 $c^2$인 모든 $(a,b)$쌍의 개수를 구할 수 있습니다.

여기서 $a \le b$라는 조건을 적용해야 되기 때문에 $a=b$인 경우는 따로 세주고, 2를 나눠주면 원하는 경우의 수를 얻을 수 있습니다.

#include<bits/stdc++.h>
using namespace std;
typedef complex<double> cdbl;

void fft(vector<cdbl> &a, bool inv) {
    int n = a.size();
    for(int i=1,j=0;i<n;++i) {
        int bit = n>>1;
        while(!((j^=bit) & bit)) bit >>= 1;
        if(i<j) swap(a[i],a[j]);
    }
    for(int i=1;i<n;i<<=1) {
        double x = inv? M_PI/i : -M_PI/i;
        cdbl w(cos(x),sin(x));
        for(int j=0;j<n;j+=i<<1) {
            cdbl p(1,0);
            for(int k=0;k<i;++k) {
                cdbl tmp = a[i+j+k] * p;
                a[i+j+k] = a[j+k] - tmp;
                a[j+k] += tmp;
                p *= w;
            }
        }
    }
    if(inv) for(int i=0;i<n;++i) a[i] /= n;
}

vector<cdbl> multiply(vector<cdbl> &f, vector<cdbl> &g) {
    vector<cdbl> pf(f.begin(), f.end()), pg(g.begin(),g.end());
    int n=1; while(n<max(pf.size(),pg.size())) n<<=1;
    n <<= 1;
    pf.resize(n); pg.resize(n);
    fft(pf,false); fft(pg,false);
    vector<cdbl> ret(n);
    for(int i=0;i<n;++i) ret[i] = pf[i] * pg[i];
    fft(ret,true);
    for(int i=0;i<n;++i) ret[i] = cdbl(round(ret[i].real()),0);
    return ret;
}

int main() {
    int N;
    cin >> N;
    vector<int> tmp(N);
    vector<int> sq(N);
    vector<cdbl> p(N);
    for(int i=1;i<N;++i) {
        tmp[(1LL*i*i)%N]++;
        sq[(2LL*i*i)%N]++;
    }
    for(int i=0;i<N;++i) p[i] = cdbl(tmp[i],0);
    vector<cdbl> mul = multiply(p,p);
    long long ans = 0;
    for(int i=1;i<N;++i) {
        int k = 1LL*i*i%N;
        int total = mul[k].real() + mul[N+k].real();
        int eq = sq[k];
        ans += (total-eq)/2 + eq;
    }
    cout << ans << '\n';
    return 0;
}
반응형

'Problem Solving > 문제풀이' 카테고리의 다른 글

백준 10531번 Golf Bot  (0) 2021.02.04
백준 17104번 골드바흐 파티션 2  (0) 2021.02.04
백준 15576번 큰 수 곱셈 (2)  (0) 2021.02.04
백준 11714번 Midpoint  (0) 2021.01.30
백준 5829번 Luxury River Cruise  (0) 2021.01.30