码迷,mamicode.com
首页 > 其他好文 > 详细

多项式多点求值

时间:2020-01-02 21:00:50      阅读:73      评论:0      收藏:0      [点我收藏+]

标签:printf   复杂度   http   math   部分   using   求值   fas   线段   

给定一个\(n\)次多项式\(A(x)\)\(m\)个值\(a_i\),求出对于\(任意i\in [0,m-1],A(a_i)\)的值

前置知识:

分治FFT

多项式除法

一般优化多项式要么倍增要么分治……

然而这题看上去不像能倍增的亚子,所以就分治吧

考虑先将要求的点分为两部分

\(x[0]=\{x_0,x_1,……,x_{\frac{m}{2}}\},x[1]=\{x_{\frac{m}{2}+1},x_{\frac{m}{2}+2},……x_{m-1}\}\)

我们记\(p[0]=\prod\limits_{i=1}^{\frac{m}{2}}(x-x_i),p[1]=\prod\limits_{i=\frac{m}{2}+1}^{m-1}(x-x_i)\)

显然\(p\)可以用类似线段树建树的方法求出

考虑对\(A(x)\)进行分治

\(A(x)=D(x)p[0](x)+A[0](x)\)

\(x\in x[0]\)的时候,\(A(x)≡A[0](x)\, (mod\, p[0])\)

\(A[0]\)的次数是小于\(p[0]\)

这里\(A[0]\)是可以用多项式除法求出来的

\(A[1]\)同理

\(A\)的次数小于\(100\)的时候其实就可以暴力求解了

时间复杂度\(O(nlog^2n)\)

……其实我自己也觉得没理解透彻,有锅欢迎指出

#include<bits/stdc++.h>
using namespace std;
namespace red{
#define int long long
#define ls(p) (p<<1)
#define rs(p) (p<<1|1)
#define eps (1e-8)
    inline int read()
    {
        int x=0;char ch,f=1;
        for(ch=getchar();(ch<'0'||ch>'9')&&ch!='-';ch=getchar());
        if(ch=='-') f=0,ch=getchar();
        while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
        return f?x:-x;
    }
    const int N=64444,mod=998244353;
    int n,m,limit,len;
    vector<int> poly[266666],a;
    int pos[266666],b[N],ret[N];
    int g[21][266666];
    inline int fast(int x,int k)
    {
        int ret=1;
        while(k)
        {
            if(k&1) ret=ret*x%mod;
            x=x*x%mod;
            k>>=1;
        }
        return ret;
    }
    inline int add(int x,const int &y)//卡常
    {
        x+=y;
        return x>mod?x-mod:x;
    }
    inline int del(int x,const int &y)
    {
        x-=y;
        return x<0?x+mod:x;
    }
    inline void init(int x)//封装
    {
        limit=1,len=0;
        while(limit<(x<<1)) limit<<=1,++len;
        for(int i=0;i<limit;++i) pos[i]=(pos[i>>1]>>1)|((i&1)<<(len-1));
    }
    inline void ntt(vector<int> &a,int inv)
    {
        while(a.size()<limit) a.push_back(0);
        for(int i=0;i<limit;++i)
            if(i<pos[i]) swap(a[i],a[pos[i]]);
        for(int mid=1,t=1;mid<limit;mid<<=1,++t)
        {
            for(int r=mid<<1,j=0;j<limit;j+=r)
            {
                for(int k=0;k<mid;++k)
                {
                    int x=a[j+k],y=g[t][k]*a[j+k+mid]%mod;
                    a[j+k]=add(x,y);
                    a[j+k+mid]=del(x,y);
                }
            }
        }
        if(inv) return;
        inv=fast(limit,mod-2);reverse(a.begin()+1,a.begin()+limit);
        for(int i=0;i<limit;++i) a[i]=a[i]*inv%mod;
    }
    inline void NTT(vector<int> a,vector<int> b,vector<int> &c)//封装一下短一点
    {
        c.clear();
        ntt(a,1);ntt(b,1);
        for(int i=0;i<limit;++i) c.push_back(a[i]*b[i]%mod);
        ntt(c,0);
    }
    inline void poly_inv(int pw,vector<int> a,vector<int> &B)//多项式乘法逆
    {
        if(pw==1){B.push_back(fast(a[0],mod-2));return;}
        poly_inv((pw+1)>>1,a,B);
        init(pw);
        while(a.size()<limit) a.push_back(0);
        for(int i=pw;i<limit;++i) a[i]=0;
        ntt(a,1);ntt(B,1);
        for(int i=0;i<limit;++i) B[i]=del(2,a[i]*B[i]%mod)*B[i]%mod;
        ntt(B,0);
        for(int i=pw;i<limit;++i) B[i]=0;
    }
    inline void get_poly(int l,int r,int p)//求出p数组
    {
        if(l==r)
        {
            poly[p].push_back(b[l]?mod-b[l]:0);
            poly[p].push_back(1);
            return;
        }
        int mid=(l+r)>>1;
        get_poly(l,mid,ls(p));get_poly(mid+1,r,rs(p));
        init(r-l+1);
        NTT(poly[ls(p)],poly[rs(p)],poly[p]);
    }
    inline void poly_mod(vector<int> a,vector<int> b,vector<int> &d,int n,int m)//多项式取模(除法)
    {
        while(a.size()<=n) a.push_back(0);
        while(b.size()<=m) b.push_back(0);
        if(n<m) return (void)(d=a);
        vector<int> apos,bpos,bposinv,c,cpos;
        d.clear();
        for(int i=0;i<=n;++i) apos.push_back(a[n-i]);
        for(int i=0;i<=m;++i) bpos.push_back(b[m-i]);
        for(int i=n-m+1;i<apos.size();++i) apos[i]=0;
        for(int i=n-m+1;i<bpos.size();++i) bpos[i]=0;
        poly_inv(n-m+1,bpos,bposinv);
        init(n-m+1);
        NTT(apos,bposinv,cpos);
        for(int i=0;i<=n-m;++i) c.push_back(cpos[n-m-i]);
        init(n);
        NTT(b,c,d);
        for(int i=0;i<m;++i) d[i]=del(a[i],d[i]);
        for(int i=m;i<limit;++i) d[i]=0;
    }
    inline void solve(vector<int> a,int p,int l,int r)
    {
        if(r-l<=100)
        {
            for(int i=l;i<=r;++i)
            {
                int s=0;
                for(int j=a.size()-1;~j;--j)
                {
                    s=add(s*b[i]%mod,a[j]);
                }
                ret[i]=s;
            }
            return;
        }
        vector<int> b;
        int mid=(l+r)>>1;
        poly_mod(a,poly[ls(p)],b,r-l,mid-l+1);
        solve(b,ls(p),l,mid);
        poly_mod(a,poly[rs(p)],b,r-l,r-mid);
        solve(b,rs(p),mid+1,r);
    }
    inline void main()
    {
        n=read(),m=read();
        for(int mid=1,t=1;mid<266666;mid<<=1,++t)//预处理原根,稍微快一点
        {
            g[t][0]=1;int Wn=fast(3,(mod-1)/(mid<<1));
            for(int k=1;k<mid;++k)
            {
                g[t][k]=g[t][k-1]*Wn%mod;
            }
        }
        for(int i=0;i<=n;++i) a.push_back(read());
        for(int i=1;i<=m;++i) b[i]=read();
        get_poly(1,m,1);
        poly_mod(a,poly[1],a,n,m);
        solve(a,1,1,m);
        for(int i=1;i<=m;++i) printf("%lld\n",ret[i]);
    }
}
signed main()
{
    red::main();
    return 0;
}

多项式多点求值

标签:printf   复杂度   http   math   部分   using   求值   fas   线段   

原文地址:https://www.cnblogs.com/knife-rose/p/12139633.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!