标签:== ... string ++ a* 多点 -- limit bst
被神秘力量驱使去学这个东西...
由$(x_i,y_i)(0\leq i\leq n)$构造多项式$L(x)=\sum\limits_{i=0}^ny_i\prod\limits_{\substack{0\leq j\leq n\\i\ne j}}\frac{x-x_j}{x_i-x_j}$
观察$L(x_k)$中的$i,j$:如果$i=k$,那么右边的大pi每一项的分子分母都相同,值为$1$;否则右边的大pi在$j=k$时有一项的分子为$0$,整个乘积为$0$。所以求和中只剩下$y_k$这一项,即这个多项式对$\forall\,0\leq i\leq n$满足$L(x_i)=y_i$
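大pi中,分母是常数,所以我们要对每个$i$求出$\prod\limits_{\substack{0\leq j\leq n\\i\ne j}}(x_i-x_j)$。记$g(x)=\prod\limits_{j=0}^n(x-x_j)$,由乘积的求导法则,$g'(x)$是$n+1$项之和,每项去掉一个因子$(x-x_j)$;在$x=x_i$处,除去掉$(x-x_i)$的那一项外其余各项都含因子$(x_i-x_i)=0$,因此$g'(x_i)=\prod\limits_{\substack{0\leq j\leq n\\i\ne j}}(x_i-x_j)$。所以直接对$g'(x)$在所有$x_i$处多点求值即可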
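剩下的求和部分$\sum\limits_{i=0}^n\dfrac{y_i}{\prod\limits_{\substack{0\leq j\leq n\\i\ne j}}(x_i-x_j)}\prod\limits_{\substack{0\leq j\leq n\\i\ne j}}(x-x_j)$用分治求就可以了:合并区间时,答案等于(左半边的答案乘右半边的$\prod(x-x_i)$)加上(左半边的$\prod(x-x_i)$乘右半边的答案),递归到底层时返回预处理好的$\dfrac{y_i}{\prod\limits_{\substack{0\leq j\leq n\\i\ne j}}(x_i-x_j)}$就可以了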
总时间复杂度$O(n\log_2^2n)$,空间复杂度$O(n\log_2n)$,常数巨大,我写得太丑了,$n=50000$要跑$3$秒多
#include<stdio.h>
#include<string.h>
typedef long long ll;  // 64-bit intermediate for modular products
// NTT-friendly prime 119*2^23+1; ntt() uses 3 as its primitive root.
const int mod=998244353;
// Capacity (in ints) of every NTT scratch buffer: 2^17 = 131072.
const int maxn=1<<17;
// Exchange the two referenced ints; safe when a and b refer to the same object.
void swap(int&a,int&b){
	const int saved=b;
	b=a;
	a=saved;
}
// Larger of two ints (returns b on a tie).
int max(int a,int b){
	if(b<a)return a;
	return b;
}
int mul(int a,int b){return a*(ll)b%mod;}
// Modular sum of two signed residues; result lies in (-mod, mod).
int ad(int a,int b){
	const int sum=a+b;
	return sum%mod;
}
// Modular difference of two signed residues; result lies in (-mod, mod).
int de(int a,int b){
	const int diff=a-b;
	return diff%mod;
}
// Binary exponentiation: a^b modulo mod (as a signed residue).
// With b = mod-2 this yields the modular inverse of a.
int pow(int a,int b){
	int result=1;
	for(;b!=0;b>>=1){
		if(b&1)result=mul(result,a);
		a=mul(a,a);
	}
	return result;
}
int rev[maxn],N,iN;
// Prepare the NTT state for a transform of length >= n:
//   N   = smallest power of two >= n (k = log2 N),
//   rev = bit-reversal permutation of 0..N-1,
//   iN  = N^{-1} mod p (used to scale the inverse transform).
void pre(int n){
	int i,k;
	for(N=1,k=0;N<n;N<<=1)k++;
	// rev[0] is always 0. Starting the loop at 1 means N==1 (k==0) never
	// evaluates the shift by k-1 == -1, which is undefined behaviour.
	rev[0]=0;
	for(i=1;i<N;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
	iN=pow(N,mod-2);
}
// In-place iterative number-theoretic transform of a[0..N), with N/rev/iN set by pre().
// on==1 uses the forward roots 3^((mod-1)/len); any other value uses the inverse roots.
// on==-1 additionally scales every entry by 1/N.
void ntt(int*a,int on){
	// bit-reversal permutation
	for(int i=0;i<N;i++)
		if(i<rev[i])swap(a[i],a[rev[i]]);
	for(int len=2;len<=N;len<<=1){
		const int half=len>>1;
		const int step=pow(3,(on==1)?(mod-1)/len:(mod-1-(mod-1)/len));
		for(int base=0;base<N;base+=len){
			int w=1;
			for(int off=0;off<half;off++){
				int&lo=a[base+off];
				int&hi=a[base+off+half];
				const int t=mul(w,hi);
				hi=de(lo,t);
				lo=ad(lo,t);
				w=mul(w,step);
			}
		}
	}
	if(on==-1)
		for(int i=0;i<N;i++)a[i]=mul(a[i],iN);
}
// Scratch buffer used by getinv to hold the transformed copy of a.
int t0[maxn];
// Power-series inverse by Newton iteration: on return b[0..n) holds the
// coefficients of a^{-1} mod x^n and b[n..2n) is cleared.
// n must be a power of two (it is halved at every recursion level).
// b[1..2n) must be zero on entry (ntt reads b at length 2n); div() clears t1
// for this purpose before calling.
// Clobbers the global transform state (N, iN, rev) and t0 via pre()/ntt().
void getinv(int*a,int*b,int n){
if(n==1){
// base case: inverse of the constant term via a^(mod-2)
b[0]=pow(a[0],mod-2);
return;
}
int i;
// afterwards b holds the inverse mod x^(n/2)
getinv(a,b,n>>1);
// transform length 2n: a (deg < n) times b^2 (deg < n-1) does not wrap around
pre(n<<1);
memset(t0,0,N<<2);
memcpy(t0,a,n<<2);
ntt(t0,1);
ntt(b,1);
// Newton step b <- b*(2 - a*b), evaluated pointwise in the transformed domain
for(i=0;i<N;i++)b[i]=mul(b[i],2-mul(t0[i],b[i]));
ntt(b,-1);
// keep only the terms below x^n
for(i=n;i<N;i++)b[i]=0;
}
// Shared scratch buffers for the polynomial helpers below (add, dec, mul, div).
// Each helper overwrites them, so their contents only matter within one call.
int ta[maxn],tb[maxn],tc[maxn];
// Finish a polynomial operation: strip zero high-order coefficients from
// tc[0..k] (lowering k, but never below 0), then copy tc[0..k] into c.
// Relies on variables named k (int&, result degree) and c (output) being in scope.
#define clr while(k!=0&&tc[k]==0)k--; memcpy(c,tc,(k+1)<<2);
// Polynomial addition c = a + b, where a has degree n and b has degree m.
// k receives the result degree after zero leading terms are stripped.
void add(int*a,int n,int*b,int m,int*c,int&k){
	k=max(n,m);
	const int bytes=(k+1)<<2;
	// zero-pad both operands to the common length in the scratch buffers
	memset(ta,0,bytes);
	memcpy(ta,a,(n+1)<<2);
	memset(tb,0,bytes);
	memcpy(tb,b,(m+1)<<2);
	for(int i=k;i>=0;--i)tc[i]=ad(ta[i],tb[i]);
	clr
}
// Polynomial subtraction c = a - b, where a has degree n and b has degree m.
// k receives the result degree after zero leading terms are stripped.
void dec(int*a,int n,int*b,int m,int*c,int&k){
	k=max(n,m);
	const int bytes=(k+1)<<2;
	// zero-pad both operands to the common length in the scratch buffers
	memset(ta,0,bytes);
	memcpy(ta,a,(n+1)<<2);
	memset(tb,0,bytes);
	memcpy(tb,b,(m+1)<<2);
	for(int i=k;i>=0;--i)tc[i]=de(ta[i],tb[i]);
	clr
}
// Formal derivative: c = a', where a has degree n; k receives n-1.
void dif(int*a,int n,int*c,int&k){
	k=n-1;
	for(int i=0;i<n;i++)c[i]=mul(i+1,a[i+1]);
}
// Reverse the coefficient array a[0..n] in place (a no-op for n < 1).
void reverse(int*a,int n){
	int lo=0,hi=n;
	while(lo<hi)swap(a[lo++],a[hi--]);
}
// Polynomial product c = a * b via NTT, where a has degree n and b has degree m.
// k receives the product degree after zero leading terms are stripped.
void mul(int*a,int n,int*b,int m,int*c,int&k){
	k=n+m;
	// N >= k+1, so the cyclic convolution equals the linear one
	pre(k+1);
	const int len=N;
	memset(ta,0,len<<2);
	memcpy(ta,a,(n+1)<<2);
	memset(tb,0,len<<2);
	memcpy(tb,b,(m+1)<<2);
	ntt(ta,1);
	ntt(tb,1);
	for(int j=0;j<len;++j)tc[j]=mul(ta[j],tb[j]);
	ntt(tc,-1);
	clr
}
int t1[maxn];
// Polynomial quotient c = floor(a / b), where a has degree n and b has degree m.
// k receives the quotient degree (after zero leading terms are stripped).
// Method: reverse both polynomials, multiply rev(a) by the power-series
// inverse of rev(b) modulo x^(n-m+1), then reverse the result back.
// t1 serves as scratch for the inverse; callers may still pass t1 as c,
// because c is written only at the very end.
void div(int*a,int n,int*b,int m,int*c,int&k){
	if(n<m){
		// deg a < deg b: the quotient is the zero polynomial, so store it
		// instead of leaving c[0] unwritten
		k=0;
		c[0]=0;
		return;
	}
	int i,rn;
	// rn = smallest power of two >= n-m+1 (the number of quotient coefficients)
	for(rn=1;rn<n-m+1;rn<<=1);
	memset(ta,0,rn<<3);
	memcpy(ta,a,(n+1)<<2);
	memset(tb,0,rn<<3);
	memcpy(tb,b,(m+1)<<2);
	// tb = rev(b) truncated mod x^rn
	reverse(tb,m);
	for(i=rn;i<=m;i++)tb[i]=0;
	// t1 = rev(b)^{-1} mod x^rn; getinv needs t1 zeroed up to 2*rn
	memset(t1,0,rn<<3);
	getinv(tb,t1,rn);
	pre(rn<<1);
	// ta = rev(a) truncated mod x^rn
	reverse(ta,n);
	for(i=rn;i<=n;i++)ta[i]=0;
	ntt(ta,1);
	ntt(t1,1);
	for(i=0;i<N;i++)tc[i]=mul(ta[i],t1[i]);
	ntt(tc,-1);
	// the low n-m+1 coefficients of the product are rev(quotient)
	k=n-m;
	reverse(tc,k);
	clr
}
// Polynomial remainder c = a mod b, where a has degree n and b has degree m.
// k receives the trimmed remainder degree. Global t1 holds the quotient
// and then quotient*b as intermediates.
void modulo(int*a,int n,int*b,int m,int*c,int&k){
	if(n>=m){
		int qdeg,pdeg;
		div(a,n,b,m,t1,qdeg);      // t1 = floor(a / b)
		mul(t1,qdeg,b,m,t1,pdeg);  // t1 = floor(a / b) * b
		dec(a,n,t1,pdeg,c,k);      // c  = a - t1
		return;
	}
	// deg a < deg b: a is already its own remainder
	k=n;
	memcpy(c,a,(n+1)<<2);
}
int X[50010],Y[50010],*tr[200010],go;
// Build the subproduct tree: tr[x] gets r-l+2 coefficients (low to high) of
// prod_{l<=i<=r} (x - X[i]), a monic polynomial of degree r-l+1.
// Children of node x are 2x and 2x+1; go only absorbs the degree output of mul.
void build(int l,int r,int x){
	if(l!=r){
		const int mid=(l+r)>>1,lc=x<<1,rc=x<<1|1;
		build(l,mid,lc);
		build(mid+1,r,rc);
		tr[x]=new int[r-l+2];
		mul(tr[lc],mid-l+1,tr[rc],r-mid,tr[x],go);
		return;
	}
	// leaf: the linear factor x - X[l]
	tr[x]=new int[2];
	tr[x][1]=1;
	tr[x][0]=-X[l];
}
// Multipoint evaluation: stores f(X[i]) into ans[i] for every l <= i <= r.
// f has degree n; tr[x] must be the subproduct polynomial for node x over [l,r].
void solve(int*f,int n,int l,int r,int x,int*ans){
	// the remainder modulo a degree-(r-l+1) polynomial needs at most r-l+1 coefficients
	int*now=new int[r-l+1];
	modulo(f,n,tr[x],r-l+1,now,n);
	if(l==r){
		// f mod (x - X[l]) is the constant f(X[l])
		ans[l]=now[0];
	}else{
		int mid=(l+r)>>1;
		solve(now,n,l,mid,x<<1,ans);
		solve(now,n,mid+1,r,x<<1|1,ans);
	}
	// both children have taken their own remainders, so this buffer is no
	// longer needed; previously it was leaked at every tree node
	delete[]now;
}
int di[50010],res[50010];
int*solve(int l,int r,int x){
int mid=(l+r)>>1,*res,n1,n2,*t1,*t2;
res=new int[r-l+1];
if(l==r){
res[0]=Y[l];
return res;
}
t1=new int[r-mid+1];
mul(solve(l,mid,x<<1),mid-l,tr[x<<1|1],r-mid,t1,n1);
t2=new int[r-mid+1];
mul(tr[x<<1],mid-l+1,solve(mid+1,r,x<<1|1),r-mid-1,t2,n2);
add(t1,n1,t2,n2,res,go);
return res;
}
// Reads n and the n+1 points (x_i, y_i), then prints the n+1 coefficients
// (constant term first) of the interpolating polynomial of degree <= n.
int main(){
	int n,i,*ans;
	// X/Y/di/res are indexed 0..n, so n must stay below their size (50010)
	if(scanf("%d",&n)!=1||n<0||n>=50010)return 1;
	for(i=0;i<=n;i++){
		if(scanf("%d%d",X+i,Y+i)!=2)return 1;
		// reduce into (-mod, mod) so the signed-residue helpers never overflow int
		X[i]%=mod;
		Y[i]%=mod;
	}
	build(0,n,1);
	// di = g'(x) with g(x) = prod (x - X[i]); then g'(X[i]) = prod_{j!=i}(X[i] - X[j])
	dif(tr[1],n+1,di,i);
	solve(di,n,0,n,1,res);
	for(i=0;i<=n;i++){
		// a zero product means two x's coincide mod p: no unique interpolant
		if(res[i]==0){
			fputs("duplicate x\n",stderr);
			return 1;
		}
		// Y[i] becomes the Lagrange weight y_i / prod_{j!=i}(x_i - x_j)
		Y[i]=mul(Y[i],pow(res[i],mod-2));
	}
	ans=solve(0,n,1);
	// ad(v,mod) maps a signed residue in (-mod, mod) into [0, mod)
	for(i=0;i<=n;i++)printf("%d ",ad(ans[i],mod));
	delete[]ans;
	return 0;
}
标签:== ... string ++ a* 多点 -- limit bst
原文地址:https://www.cnblogs.com/jefflyy/p/9203230.html