码迷,mamicode.com
首页 > 其他好文 > 详细

跟多项式运算相关代码

时间:2021-02-18 13:16:23      阅读:0      评论:0      收藏:0      [点我收藏+]

标签:lang   c++   idf   cout   orm   amp   pen   pac   cer   

  1. 共轭优化 FFT,P3803 多项式乘法(原页面此处代码缺失)

  2. NTT,P3803 多项式乘法
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
// db is not used by the NTT code below.
typedef double db;
#define IL inline

// NTT-friendly prime: MOD - 1 = 2^23 * 7 * 17, so power-of-two roots of unity exist up to 2^23.
const int MOD = 998244353;
// Size of the root-of-unity tables (2^21). trans() indexes them with stride N / len,
// so every transform length must be a power of two not exceeding N.
const int N = 2097152;
// Primitive root modulo MOD.
const int G = 3;

// Computes a^b mod m by binary exponentiation; the result is always in [0, m).
// The base is normalised into [0, m) first, so negative bases are handled.
// The accumulator starts at 1 % m, so m == 1 correctly yields 0.
// Negative exponents are not supported; they are treated like b == 0 instead of
// looping forever (b >>= 1 on a negative value never reaches 0).
inline int ksm(int a,int b,int m) {
	a %= m;
	if(a < 0) a += m;
	int res = 1 % m;
	while(b > 0) {
		if(b&1) res = 1LL * res * a % m;
		a = 1LL * a * a % m;
		b >>= 1;
	}
	return res;
}
// Modular inverse modulo the prime MOD via Fermat's little theorem: x^(MOD-2).
// Returns 0 when x is a multiple of MOD (no inverse exists).
IL int inv(int x) {
	const int exponent = MOD - 2;
	return ksm(x, exponent, MOD);
}

int rev[N];
int eps[N], ieps[N];
// Precomputes the root tables for trans(). With w = G^((MOD-1)/N), a primitive
// N-th root of unity mod MOD:
//   eps[i]  = w^i
//   ieps[i] = w^(-i)
IL void initeps() {
	const int w = ksm(G, (MOD-1) / N, MOD);
	eps[0] = 1;
	for(int i=1;i<N;i++) eps[i] = 1LL * eps[i-1] * w % MOD;
	// w^N == 1, so w^(-i) == w^(N-i): read the inverse powers back out of eps.
	ieps[0] = 1;
	for(int i=1;i<N;i++) ieps[i] = eps[N-i];
}
// Chooses the NTT length for multiplying a degree-degA by a degree-degB polynomial.
// lim is the smallest power of two strictly greater than degA+degB, so the cyclic
// convolution of length lim does not wrap, and p = log2(lim).
// Also fills rev[0..lim) with the bit-reversal permutation for that length.
IL void cal_rev(int degA, int degB, int& lim, int& p) {
	lim = 1; p = 0;
	while(lim <= (degA+degB)) {lim <<= 1; ++p;}
	rev[0] = 0;
	// Start at i = 1. When lim == 1 (both inputs constant) p == 0, and
	// (i&1) << (p-1) would shift by -1, which is undefined behaviour.
	for(int i=1;i<lim;i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(p-1));
}
// In-place iterative Cooley-Tukey NTT of length n (a power of two not exceeding N).
// w is a root table (eps for the forward transform, ieps for the inverse).
// rev must already hold the bit-reversal permutation for n (see cal_rev).
IL void trans(int *x, int* w, int n) {
	// Put the input into bit-reversed order so the butterflies can run in place.
	for(int i=0; i<n; i++) {
		if(i < rev[i]) swap(x[i], x[rev[i]]);
	}
	for(int len=2; len<=n; len<<=1) {
		const int half = len >> 1;
		const int step = N / len; // w[step*k] is the k-th power of a len-th root of unity
		for(int base=0; base<n; base+=len) {
			int *lo = x + base;
			int *hi = lo + half;
			for(int k=0; k<half; k++) {
				const int t = 1LL * hi[k] * w[step*k] % MOD;
				hi[k] = (1LL * lo[k] - t + MOD) % MOD;
				lo[k] = (lo[k] + t) % MOD;
			}
		}
	}
}
// Forward NTT of length n using the forward root table.
IL void dft(int* x, int n) {
	trans(x, eps, n);
}
// Inverse NTT: transform with the inverse roots, then scale every entry by n^(-1) mod MOD.
IL void idft(int* x, int n) {
	trans(x, ieps, n);
	const int scale = inv(n);
	for(int* it = x; it != x + n; ++it) *it = 1LL * (*it) * scale % MOD;
}

int ntt_a[N], ntt_b[N];
// Multiplies A (degree degA) by B (degree degB) modulo MOD using the NTT.
// Writes the degA+degB+1 product coefficients to C[0..degA+degB].
// Returns the product's degree (degA + degB), not its length.
// ntt_a, ntt_b and rev are overwritten as scratch.
IL int mul_normal(int *C, int *A, int *B, int degA, int degB) {
	int lim, p;
	cal_rev(degA, degB, lim, p); // if every call uses the same length, this could be hoisted into main
	// Copy both inputs into the scratch buffers, zero-padding up to lim.
	for(int i=0;i<lim;i++) {
		ntt_a[i] = (i <= degA) ? A[i] : 0;
		ntt_b[i] = (i <= degB) ? B[i] : 0;
	}
	dft(ntt_a, lim);
	dft(ntt_b, lim);
	for(int i=0;i<lim;i++) ntt_a[i] = 1LL * ntt_a[i] * ntt_b[i] % MOD;
	idft(ntt_a, lim);
	const int deg = degA + degB;
	copy(ntt_a, ntt_a + deg + 1, C);
	return deg;
}

int f[N], g[N], ans[N];

// Reads n, m and the coefficients f[0..n], g[0..m] from stdin.
// Prints the n+m+1 coefficients of f*g modulo MOD on one space-separated line.
int main() {
	initeps();
	int n, m;
	scanf("%d%d", &n, &m);
	for(int i=0; i<=n; i++) scanf("%d", f + i);
	for(int i=0; i<=m; i++) scanf("%d", g + i);
	const int deg = mul_normal(ans, f, g, n, m);
	for(int i=0; i<=deg; i++) printf("%d%c", ans[i], " \n"[i == deg]);
	return 0;
}
  3. 任意模数多项式乘法(P4245),使用拆系数 + 共轭优化 FFT 通过。
#include<bits/stdc++.h>
using namespace std;

typedef long long ll;
// Floating-point type used for the FFT.
typedef double db;
// ld is not used in the visible code.
typedef long double ld;
#define IL inline
// Debug helpers that print "name = value" to stderr; dbg1 continues the line, dbg2 ends it.
#define dbg1(x) cerr << #x << " = " << x << ", "
#define dbg2(x) cerr << #x << " = " << x << endl

// Reads one optionally signed decimal integer from stdin into x.
// Any '-' seen before the first digit makes the value negative.
// At end of input x is set to 0 instead of looping forever.
// ch is an int, not a char, so EOF stays distinguishable from a real byte and
// isdigit() is never called with a negative char value.
// Character literals use ASCII quotes; the scraped ‘-‘ / ‘0‘ did not compile.
template<typename Tp> inline void read(Tp& x) {
    x = 0; int f = 1; int ch = getchar();
    while(ch != EOF && !isdigit(ch)) { if(ch == '-') f = -1; ch = getchar(); }
    while(ch != EOF && isdigit(ch)) { x = x*10 + (ch - '0'); ch = getchar(); }
    x *= f;
}
int buf[22];
// Writes the decimal representation of integer x to stdout (no trailing separator).
// Digits of a negative x are taken from the negative value itself, since
// x % 10 is in (-9..0] in C++11 and later. This avoids the overflow of -x when x is
// the most negative value of Tp (e.g. INT_MIN).
// Character literals use ASCII quotes; the scraped ‘-‘ / ‘0‘ did not compile.
template<typename Tp> inline void write(Tp x) {
    int p = 0;
    if(x < 0) {
        putchar('-');
        do { buf[++p] = -(int)(x % 10); x /= 10; } while(x);
    } else {
        do { buf[++p] = (int)(x % 10); x /= 10; } while(x);
    }
    // buf holds the digits least-significant first; emit them in reverse.
    for(int i=p;i;i--) putchar('0' + buf[i]);
}

struct cp {
    db x, y;
    cp(db x=0.0, db y=0.0):x(x), y(y) {}
    IL cp operator + (const cp& o) const { return cp(x+o.x, y+o.y);}
    IL cp operator - (const cp& o) const { return cp(x-o.x, y-o.y);}
    IL cp operator * (const cp& o) const { return cp(x*o.x-y*o.y, x*o.y+y*o.x);}
    IL cp operator * (const db& p) const { return cp(x*p, y*p);}
    IL cp operator / (const db& p) const { return cp(x/p, y/p);}
    IL cp operator ! () const { return cp(x, -y);}
};
IL cp polar(const db& rho, const db& theta) {return cp(rho*cos(theta), rho*sin(theta));}

// Size of the complex root tables (2^17). solve() packs two real coefficients per
// complex slot, so a length-N transform carries up to 2*N real coefficients.
const int N = 131072;
const db PI = acos(-1.0);

// Bit-reversal permutation for the current transform length (filled by cal_rev).
int rev[N];
// eps[k] = e^{2*pi*i*k/N} and ieps[k] = eps[N-k], the conjugate root (filled by initeps).
cp eps[N], ieps[N];
// Fills the FFT root tables:
//   eps[k]  = e^{2*pi*i*k/N}
//   ieps[k] = e^{-2*pi*i*k/N}
// eps[0] and ieps[0] are set to exactly 1.
IL void initeps() {
    for(int k=0;k<N;k++) eps[k] = polar(1.0, 2*PI*k/N);
    eps[0] = cp(1.0, 0.0);
    ieps[0] = eps[0];
    // e^{-2*pi*i*k/N} == e^{2*pi*i*(N-k)/N}: read the inverse roots out of the forward table.
    for(int k=1;k<N;k++) ieps[k] = eps[N-k];
}
// Chooses the complex transform length for solve(). Two real coefficients share one
// complex slot, so lim is the smallest power of two with 2*lim > degA+degB. That leaves
// room for all degA+degB+1 product coefficients without cyclic wrap-around.
// p = log2(lim). Also fills rev[0..lim) with the bit-reversal permutation.
IL void cal_rev(int degA, int degB, int& lim, int &p) {
    lim = 1; p = 0;
    while(lim <=  ((degA+degB) >> 1)) {lim <<= 1; p++;}
    rev[0] = 0;
    // Start at i = 1. When lim == 1 (degA+degB <= 1) p == 0, and
    // (i&1) << (p-1) would shift by -1, which is undefined behaviour.
    for(int i=1;i<lim;i++) rev[i] = (rev[i>>1]>>1) | ((i&1) << (p-1));
}
// In-place iterative radix-2 FFT of length n (a power of two not exceeding N).
// w is the root table (eps forward, ieps inverse); rev must already hold the
// bit-reversal permutation for n (see cal_rev).
IL void trans(cp *x, cp *w, int n) {
    for(int i=0;i<n;i++) {
        if(i < rev[i]) swap(x[i], x[rev[i]]);
    }
    for(int len=2; len<=n; len<<=1) {
        const int half = len >> 1;
        const int stride = N / len; // w[stride*k] is the k-th power of a len-th root of unity
        for(int start=0; start<n; start+=len) {
            cp *lo = x + start;
            cp *hi = lo + half;
            for(int k=0;k<half;k++) {
                const cp t = hi[k] * w[stride * k];
                hi[k] = lo[k] - t;
                lo[k] = lo[k] + t;
            }
        }
    }
}
// Forward FFT of length n with the e^{+2*pi*i/N} root table.
IL void dft(cp *x, int n) {
    trans(x, eps, n);
}
// Inverse FFT: transform with the conjugate roots, then divide every entry by n.
IL void idft(cp *x, int n) {
    trans(x, ieps, n);
    for(cp *it = x; it != x + n; ++it) *it = *it / n;
}

// Degrees of the two inputs and the output modulus, read in main.
int n, m, P;

// f, g: input coefficients; ans: product coefficients (degree n+m, sized 2*N).
int f[N], g[N], ans[N<<1];
// Split inputs, packed two real coefficients per complex slot:
// f1/g1 hold the high halves (v / p), f2/g2 the low halves (v % p).
cp f1[N], f2[N], g1[N], g2[N];
// Products in the frequency domain: h1 = high*high, h2 = the two cross terms, h3 = low*low.
cp h1[N], h2[N], h3[N];

// Multiplies A (degree degA) by B (degree degB) modulo an arbitrary `mod`.
// Writes the degA+degB+1 product coefficients to C and returns degA+degB.
//
// Each coefficient v is reduced into [0, mod) and split as v = hi*p + lo with
// p = 31624 (> sqrt(1e9+7)). The three real products hi*hi, hi*lo + lo*hi and lo*lo
// then stay small enough to be rounded back to integers from double-precision
// FFT output.
//
// Real coefficient i is packed into complex slot i>>1: the real part for even i,
// the imaginary part for odd i. A length-lim transform therefore carries 2*lim
// real coefficients. The conjugate terms in the pointwise loop recover the
// packed product of two such real sequences.
IL int solve(int *C, int *A, int *B, int degA, int degB, int mod) {
    int p = 31624; // sqrt(1e9 + 7) = 31622.77
    int lim, lglim;
    cal_rev(degA, degB, lim, lglim);
    // Clear every slot the transforms touch. Previously only coefficients
    // deg+1 .. lim-1 were cleared. That left slots lim/2 .. lim-1 holding the
    // previous call's in-place DFT output, so any second call was wrong.
    for(int i=0;i<lim;i++) f1[i] = f2[i] = g1[i] = g2[i] = cp(0, 0);
    for(int i=0;i<=degA;i++) {
        // Normalise first, so negative or >= mod inputs still split into small halves.
        ll v = ((ll)A[i] % mod + mod) % mod;
        (i & 1 ? f1[i>>1].y : f1[i>>1].x) = v / p;
        (i & 1 ? f2[i>>1].y : f2[i>>1].x) = v % p;
    }
    for(int i=0;i<=degB;i++) {
        ll v = ((ll)B[i] % mod + mod) % mod;
        (i & 1 ? g1[i>>1].y : g1[i>>1].x) = v / p;
        (i & 1 ? g2[i>>1].y : g2[i>>1].x) = v % p;
    }
    dft(f1, lim); dft(f2, lim); dft(g1, lim); dft(g2, lim);
    // X = f1 * p + f2
    // Y = g1 * p + g2
    // Z = (f1*g1)p^2 + (f1*g2+f2*g1)p + f2*g2
    int d = N / lim;
    for(int i=0;i<lim;i++) {
        // j = (-i) mod lim: the conjugate-symmetric partner index.
        int j = (lim-1) & (lim-i);
        h1[i] = (f1[i] * g1[i] * 4 - (f1[i]-!f1[j]) * 
                (g1[i]-!g1[j]) * (eps[d*i] + cp(1,0))) * 0.25;
        h2[i] = (f1[i] * g2[i] * 4 - (f1[i]-!f1[j]) * 
                (g2[i]-!g2[j]) * (eps[d*i] + cp(1,0))) * 0.25
               +(f2[i] * g1[i] * 4 - (f2[i]-!f2[j]) *
                (g1[i]-!g1[j]) * (eps[d*i] + cp(1,0))) * 0.25;
        h3[i] = (f2[i] * g2[i] * 4 - (f2[i]-!f2[j]) *
                (g2[i]-!g2[j]) * (eps[d*i] + cp(1,0))) * 0.25;
    }
    idft(h1, lim); idft(h2, lim); idft(h3, lim);
    for(int i=0;i<=degA+degB;i++) {
        // Round to the nearest integer; the true values are non-negative.
        ll h1v = ((i&1) ? h1[i>>1].y : h1[i>>1].x) + 0.5;
        ll h2v = ((i&1) ? h2[i>>1].y : h2[i>>1].x) + 0.5;
        ll h3v = ((i&1) ? h3[i>>1].y : h3[i>>1].x) + 0.5;
        h1v %= mod; h2v %= mod; h3v %= mod;
        // dbg1(i); dbg1(h1v); dbg1(h2v); dbg2(h3v);
        C[i] = (h1v*p%mod*p%mod + h2v*p%mod + h3v) % mod;
    }
    return degA + degB;
}

// Reads n, m and the modulus P, then f[0..n] and g[0..m].
// Prints the coefficients of f*g modulo P on one space-separated line.
int main() {
    initeps();
    read(n);
    read(m);
    read(P);
    for(int* it = f; it != f + n + 1; ++it) read(*it);
    for(int* it = g; it != g + m + 1; ++it) read(*it);
    const int deg = solve(ans, f, g, n, m, P);
    for(int i=0;i<=deg;i++) {
        write(ans[i]);
        putchar(" \n"[i == deg]);
    }
    return 0;
}
  4. 多项式乘法逆 P4238
#include <bits/stdc++.h>
using namespace std;

#define IL inline
// NOTE(review): `register` was removed in C++17, so any use of `ri` fails to compile
// under -std=c++17. The macro is not used in the visible code.
#define ri register int 
// Debug helpers that print "name = value" to stdout; dbg1 continues the line, dbg2 ends it.
#define dbg1(x) cout << #x << " = " << x << ", "
#define dbg2(x) cout << #x << " = " << x << endl

// Reads one optionally signed decimal integer from stdin into x.
// Any '-' seen before the first digit makes the value negative.
// At end of input x is set to 0 instead of looping forever.
// ch is an int, not a char, so EOF stays distinguishable from a real byte and
// isdigit() is never called with a negative char value.
// Character literals use ASCII quotes; the scraped ‘-‘ / ‘0‘ did not compile.
template<typename Tp> inline void read(Tp &x) {
    x = 0; int f = 1; int ch = getchar();
    while(ch != EOF && !isdigit(ch)) { if(ch == '-') f = -1; ch = getchar(); }
    while(ch != EOF && isdigit(ch)) { x = x*10 + (ch - '0'); ch = getchar(); }
    x *= f;
}
int buf[22];
// Writes the decimal representation of integer x to stdout (no trailing separator).
// Digits of a negative x are taken from the negative value itself, since
// x % 10 is in (-9..0] in C++11 and later. This avoids the overflow of -x when x is
// the most negative value of Tp (e.g. INT_MIN).
// Character literals use ASCII quotes; the scraped ‘-‘ / ‘0‘ did not compile.
template<typename Tp> inline void write(Tp x) {
    int p = 0;
    if(x < 0) {
        putchar('-');
        do { buf[++p] = -(int)(x % 10); x /= 10; } while(x);
    } else {
        do { buf[++p] = (int)(x % 10); x /= 10; } while(x);
    }
    // buf holds the digits least-significant first; emit them in reverse.
    for(int i=p;i;i--) putchar('0' + buf[i]);
}

// NTT-friendly prime: mod - 1 = 2^23 * 7 * 17.
const int mod = 998244353;
// Primitive root modulo mod.
const int G = 3;
// Root-table size (2^18). trans() indexes the tables with stride N / len, so
// transform lengths must be powers of two not exceeding N.
const int N = 262144;

// Computes a^b mod m by binary exponentiation; the result is always in [0, m).
// The base is normalised into [0, m) first, so negative bases are handled.
// The accumulator starts at 1 % m, so m == 1 correctly yields 0.
// Negative exponents are not supported; they are treated like b == 0 instead of
// looping forever (b >>= 1 on a negative value never reaches 0).
inline int ksm(int a, int b, int m) {
    a %= m;
    if(a < 0) a += m;
    int ret = 1 % m;
    while(b > 0) {
        if(b&1) ret = 1ll * ret * a % m;
        a = 1ll * a * a % m;
        b >>= 1;
    }
    return ret;
}
// Modular inverse modulo the prime `mod` via Fermat's little theorem: x^(mod-2).
// Returns 0 when x is a multiple of mod (no inverse exists).
IL int inv(int x) {
    const int exponent = mod - 2;
    return ksm(x, exponent, mod);
}

int rev[N], eps[N], ieps[N];
// Precomputes the root tables for trans(). With w = G^((mod-1)/N), a primitive
// N-th root of unity mod `mod`:
//   eps[i]  = w^i
//   ieps[i] = w^(-i), computed as w^(N-i)
IL void initeps() {
    const int w = ksm(G, (mod-1) / N, mod);
    eps[0] = 1;
    for(int i=1;i<N;i++) eps[i] = 1ll * eps[i-1] * w % mod;
    ieps[0] = 1;
    for(int i=N-1;i>=1;i--) ieps[i] = eps[N-i];
}

// In-place iterative Cooley-Tukey NTT of length n (a power of two not exceeding N).
// w is the root table (eps forward, ieps inverse); rev must already hold the
// bit-reversal permutation for n.
IL void trans(int *x, int *w, int n) {
    for(int i=0;i<n;i++) {
        if(i < rev[i]) swap(x[i], x[rev[i]]);
    }
    for(int len=2; len<=n; len<<=1) {
        const int half = len >> 1;
        const int stride = N / len; // w[stride*k] is the k-th power of a len-th root of unity
        for(int start=0; start<n; start+=len) {
            for(int k=0;k<half;k++) {
                int &u = x[start + k];
                int &v = x[start + k + half];
                const int t = 1ll * v * w[stride * k] % mod;
                v = (u - t + mod) % mod;
                u = (u + t) % mod;
            }
        }
    }
}

// Forward NTT of length n using the forward root table.
IL void dft(int *x, int n) {
    trans(x, eps, n);
}
// Inverse NTT: transform with the inverse roots, then scale every entry by n^(-1) mod `mod`.
IL void idft(int *x, int n) {
    trans(x, ieps, n);
    const int scale = inv(n);
    for(int *it = x; it != x + n; ++it) *it = 1ll * (*it) * scale % mod;
}

int ntt_a[N];
// Computes the first lenA coefficients of A^{-1} (mod x^lenA, coefficients modulo
// `mod`) into B. Uses the Newton step B <- B * (2 - A*B), starting from the inverse
// modulo x^ceil(lenA/2).
// Requires A[0] to be invertible modulo `mod`. ntt_a and rev are used as scratch.
// On return B[lenA .. lim) is zero, where lim is this level's transform length.
IL void polyinv(int *B, int *A, int lenA) {
    if(lenA <= 0) return;   // nothing to invert; also stops (0+1)>>1 == 0 recursing forever
    if(lenA == 1) {B[0] = inv(A[0]); return;}
    const int half = (lenA+1) >> 1;
    polyinv(B, A, half);
    int lim = 1, lglim = 0;
    // lim >= 2*lenA: A*B*B has degree <= 2*lenA-2, so the cyclic product never wraps.
    while(lim < (lenA<<1)) { lim <<= 1; lglim++;}
    for(int i=0;i<lim;i++) rev[i] = (rev[i>>1]>>1) | ((i&1) << (lglim-1));
    for(int i=0;i<lenA;i++) ntt_a[i] = A[i];
    for(int i=lenA;i<lim;i++) ntt_a[i] = 0;
    // Only B[0..half) is meaningful here. Clear the rest explicitly rather than
    // relying on the caller's buffer being zero beyond what deeper levels cleared.
    for(int i=half;i<lim;i++) B[i] = 0;
    dft(ntt_a, lim); dft(B, lim);
    // Pointwise Newton step: B = B * (2 - A*B) mod `mod`.
    for(int i=0;i<lim;i++) B[i] = 1ll*(2-1ll*ntt_a[i]*B[i]%mod+mod)%mod*B[i]%mod;
    idft(B, lim);
    // Truncate to x^lenA.
    for(int i=lenA;i<lim;i++) B[i] = 0;
}

int n;
int f[N], ans[N];

// Reads n and the n coefficients of f.
// Prints the first n coefficients of f^{-1} mod x^n, space-separated.
int main() {
    initeps();
    read(n);
    for(int* it = f; it != f + n; ++it) read(*it);
    polyinv(ans, f, n);
    for(int i=0;i<n;i++) {
        write(ans[i]);
        putchar(" \n"[i + 1 == n]);
    }
    return 0;
}

跟多项式运算相关代码

标签:lang   c++   idf   cout   orm   amp   pen   pac   cer   

原文地址:https://www.cnblogs.com/bringlu/p/14406143.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!