본문 바로가기

Algorithm/Math

Miller–Rabin primality test && Pollard's rho algorithm

[BOJ 4149] 큰 수 소인수분해 문제의 풀이 코드로 설명을 대신하겠다. 왜 구조체(int_nt)의 이름을 거창하게 'number theory'라고 지었는지는 비밀이다. :)

 

#include<bits/stdc++.h>
using namespace std;
// Fast-I/O helper: disables C-stdio synchronization and unties cin/cout.
#define fastio ios::sync_with_stdio(0),cin.tie(0),cout.tie(0)
using ll = long long;
using i128 = __int128;      // GCC/Clang extension; holds any product of two ll values exactly
using vec = std::vector<ll>;

// Integer number-theory helpers ("nt" = number theory): modular arithmetic,
// deterministic Miller–Rabin primality testing and Pollard's rho factorization
// for signed 64-bit integers.
struct int_nt {
    // Returns (x * y) % mod. The product is formed in 128 bits so it cannot overflow.
    ll mul(ll x, ll y, ll mod) { return static_cast<ll>(static_cast<i128>(x) * y % mod); }

    // Euclid's algorithm. Intended for non-negative arguments; gcd(0, 0) == 0.
    ll gcd(ll x, ll y) {
        while (y != 0) {
            x %= y;
            std::swap(x, y);
        }
        return x;
    }

    // Least common multiple. Divides before multiplying to limit overflow.
    // A zero argument yields 0 (otherwise lcm(0, 0) would divide by gcd(0, 0) == 0).
    ll lcm(ll x, ll y) {
        if (x == 0 || y == 0) return 0;
        return x / gcd(x, y) * y;
    }

    // Returns x^y mod `mod` by binary exponentiation. Requires y >= 0 and mod >= 1.
    ll fpow(ll x, ll y, ll mod) {
        ll ret = 1 % mod;   // so that fpow(x, 0, 1) == 0 rather than 1
        while (y) {
            if (y & 1) ret = mul(ret, x, mod);
            x = mul(x, x, mod);
            y >>= 1;
        }
        return ret;
    }

    // One Miller–Rabin (strong probable prime) round of n to base a.
    // Returns false only when base a proves n composite; true means n is a strong
    // probable prime to base a (a base that is a multiple of n proves nothing).
    bool miller_rabin(ll n, ll a) {
        if (n < 2) return false;
        if (n % 2 == 0) return n == 2;      // even n: only 2 is prime
        a %= n;
        if (a < 0) a += n;
        if (a == 0) return true;
        // Write n - 1 = d * 2^s with d odd.
        ll d = n - 1;
        int s = 0;
        while ((d & 1) == 0) {
            d >>= 1;
            ++s;
        }
        ll t = fpow(a, d, n);
        if (t == 1 || t == n - 1) return true;
        // Square up through a^(d*2^(s-1)); hitting n-1 means no witness here.
        for (int r = 1; r < s; ++r) {
            t = mul(t, t, n);
            if (t == n - 1) return true;
        }
        return false;
    }

    // Deterministic primality test for any ll.
    bool is_prime(ll n) {
        if (n <= 1) return false;
        if (n <= 10000) {
            for (ll i = 2; i * i <= n; i++)
                if (n % i == 0) return false;
            return true;
        }
        // {2, 7, 61} is deterministic for n < 4,759,123,141 (so for every n <= INT_MAX).
        // The seven larger bases are deterministic for every n < 2^64.
        static const vec small_bases = {2, 7, 61};
        static const vec large_bases = {2, 325, 9375, 28178, 450775, 9780504, 1795265022};
        const vec& bases = n > INT_MAX ? large_bases : small_bases;
        for (ll a : bases)
            if (!miller_rabin(n, a)) return false;
        return true;
    }

    // Returns one prime factor of n using Pollard's rho with Floyd cycle detection.
    // Precondition: n is composite and n >= 4 (a prime n would never terminate).
    ll pollard_rho(ll n) {
        // One step of f(v) = v^2 + c (mod n). It is computed in 128 bits because
        // v^2 mod n + c can exceed LLONG_MAX once n > 2^62.
        auto step = [n](ll v, ll c) {
            return static_cast<ll>((static_cast<i128>(v) * v + c) % n);
        };
        while (true) {
            ll x = std::rand() % (n - 2) + 2;
            ll y = x;
            ll c = std::rand() % (n - 1) + 1;
            ll d = 1;
            while (d == 1) {
                x = step(x, c);
                y = step(step(y, c), c);
                d = gcd(x > y ? x - y : y - x, n);
            }
            if (d == n) continue;   // cycle closed without splitting n: retry with new x, c
            return is_prime(d) ? d : pollard_rho(d);
        }
    }

    // Returns the prime factorization of n as a sorted multiset of primes
    // (e.g. 360 -> {2, 2, 2, 3, 3, 5}). n == 1 and any n <= 0 yield an empty vector.
    vec factorize(ll n) {
        vec p;
        if (n <= 0) return p;   // without this, n == 0 would push 2s forever below
        while ((n & 1) == 0) {
            n >>= 1;
            p.push_back(2);
        }
        while (n > 1 && !is_prime(n)) {
            ll d = pollard_rho(n);
            while (n % d == 0) {
                n /= d;
                p.push_back(d);
            }
        }
        if (n > 1) p.push_back(n);
        std::sort(p.begin(), p.end());   // rho discovers factors in arbitrary order
        return p;
    }
} nt;
int main()
{
    ll n; cin >> n;
    for(auto p : nt.factorize(n)) cout << p << '\n';
}