Codeforce 622 F. The Sum of the k-th Powers(拉格朗日插值求k次幂之和,拉格朗日插值公式)

标签: 多项式  拉格朗日插值  k次幂之和

在这里插入图片描述

题目大意:求 i=1nik\displaystyle\sum_{i = 1}^ni^k
求k次幂有多种求法,例如:
伯努利数求k次幂之和(待补)
斯特林数求k次幂之和
拉格朗日插值法求k次幂之和

这里采用拉格朗日插值法进行求解。
拉格朗日可以通过 k+1k + 1 个点唯一确定一个 kk 次多项式,它的公式为:f(x)=i=1ny[i]ijxx[j]x[i]x[j]f(x) = \sum_{i = 1}^ny[i] \prod_{i \neq j}\frac{x - x[j]}{x[i]-x[j]}
其中x[i],y[i]x[i],y[i]对应已知的点值,对已知的点很容易通过代入验证正确性,带入 x[i]x[i] 将会得到 y[i]y[i]

这个式子在一般情况下的复杂度为 O(n2)O(n^2),比高斯消元的 n3n^3 更加优秀,在已知点的 xx 取值连续的情况下,复杂度能降低到 O(n)O(n),只要预处理阶乘逆元,以及 xx 的 k + 1 项倒阶乘:xfacxfac
f(x)=i=1ny[i]xfacfac[i]fac[ni](xi)f(x)=\sum_{i = 1}^ny[i]*\frac{xfac}{fac[i]*fac[n - i]*(x-i)}

为什么这题可以用拉格朗日插值
当然是因为 i=1nik\displaystyle\sum_{i = 1}^ni^k 是一个以n为自变量的多项式,并且是 k+1k + 1 次多项式
证明:
S(n,k)=i=1nik\displaystyle S(n,k)=\sum_{i = 1}^ni^k
对这个序列两两差分可以得到:(n+1)k+1nk+1=i=0k+1C(k+1,i)nink+1=i=0kC(k+1,i)ni(n + 1)^{k+1} - n^{k+1}=\sum_{i = 0}^{k+1}C(k+1,i)*n^i - n^{k+1}=\sum_{i = 0}^kC(k+1,i)*n^ink+1(n1)k+1=i=0kC(k+1,i)(n1)in^{k+1} - (n-1)^{k+1}=\sum_{i = 0}^{k}C(k+1,i)*(n-1)^i......1k+10k+1=i=0kC(k+1,i)0i1^{k+1}-0^{k+1}=\sum_{i = 0}^{k}C(k+1,i)0^i

逐项求和可以得到 (n+1)k+1=i=0kC(k+1,i)S(n,k)\displaystyle (n+1)^{k+1} =\sum_{i=0}^kC(k+1,i)*S(n,k),即得到S(n,k)S(n,k)是以 nn 为自变量的 k+1k + 1 次多项式

f(x)=xkf(x) = x^k,可以得到一个更一般的推广结论:kk 次多项式的前 nn 项和 g(n)g(n) 是一个以 nn 为自变量的 k+1k + 1 次多项式

回到这题,前k+2k + 2项可以 klogkk \log k 暴力计算,对nk+2n \leq k + 2 直接输出答案,对 n>k+2n > k + 2 只要插值一下,根据插值公式计算即可。


代码:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod = 1e9 + 7;
const int maxn = 1e6 + 100;
int n,k;
int x[maxn],y[maxn];				//拉格朗日差值的计算 
int fac[maxn],ifac[maxn];			//阶乘的逆元 
inline int add(int x, int y) {
  	x += y;
  	if (x >= mod)
    	x -= mod;
  	return x;
}

inline int sub(int x, int y) {
  	x -= y;
  	if (x < 0)
    	x += mod;
  	return x;
}

inline int mul(int x, int y) {
  	return (long long) x * y % mod;
}
int fpow(int a,int b) {
	int r = 1;
	while (b) {
		if (b & 1) r = mul(r,a);
		b >>= 1;
		a = mul(a,a);
	}
	return r;
}
int main() {
	scanf("%d%d",&n,&k);
	for (int i = 1; i <= k + 2; i++) {			//暴力计算 k + 2 个点,根据这 k + 2个点就可以通过插值唯一确定 k + 1次多项式 
		x[i] = i;
		y[i] = add(y[i - 1],fpow(x[i],k));
	}
	if (n <= k + 2) {							//n <= k + 2就直接输出,否则下面的处理会出错 
		printf("%d\n",y[n]);
		return 0;
	}
	fac[0] = 1;
	for (int i = 1; i <= k + 2; i++) {			//由于k+2个点x取值连续,预处理阶乘,使复杂度降低到O(k) 
		fac[i] = mul(fac[i - 1],i);
	}
	ifac[k + 2] = fpow(fac[k + 2],mod - 2);
	for (int i = k + 1; i >= 0; i--) {
		ifac[i] = mul(ifac[i + 1],i + 1);
	}
	int tmp = 1;								//n的倒阶乘 ,同样也是为了加速 
	for (int i = 1; i <= k + 2; i++) {
		tmp = mul(tmp,(n - i) % mod);
	}
	int ans = 0;
	for (int i = 1; i <= k + 2; i++) {			//插值迭代,得到 f(n) 
		int t = k + 2 - i;
		int p = (t & 1) ? -1 : 1;
		int inv = fpow((n - i) % mod,mod - 2);
		int res = mul(mul(ifac[i - 1],ifac[t]),mul(tmp,inv));
		res = mul(mul(res,p),y[i]);
		if (res < 0) res += mod;
		ans = add(ans,res);
	}
	printf("%d\n",ans);
	return 0;
}
版权声明:本文为qq_41997978原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_41997978/article/details/104237443