标签:
Description
Edward has a tree with n vertices conveniently labeled with 1,2,…,n.
Edward finds a pair of paths on the tree which share no more than k common vertices. Now Edward is interested in the number of such ordered pairs of paths.
Note that the path from vertex a to b is the same as the path from vertex b to a. An ordered pair means (A, B) is different from (B, A) unless A is equal to B.
Input
There are multiple test cases. The first line of input contains an integer T indicating the number of test cases. For each test case:
The first line contains two integers n, k (1 ≤ n, k ≤ 88888). Each of the following n - 1 lines contains two integers ai, bi, denoting an edge between vertices ai and bi (1 ≤ ai, bi ≤ n).
The sum of values n for all the test cases does not exceed 888888.
Output
For each case, output a single integer denoting the number of ordered pairs of paths sharing no more than k vertices.
Sample Input
1
4 2
1 2
2 3
3 4
Sample Output
| path A | paths share 2 vertices with A | total | 
|---|---|---|
| 1-2-3-4 | 1-2, 2-3, 3-4 | 3 | 
| 1-2-3 | 1-2, 2-3, 2-3-4 | 3 | 
| 2-3-4 | 1-2-3, 2-3, 3-4 | 3 | 
| 1-2 | 1-2, 1-2-3, 1-2-3-4 | 3 | 
| 2-3 | 1-2-3, 1-2-3-4, 2-3, 2-3-4 | 4 | 
| 3-4 | 1-2-3-4, 2-3-4, 3-4 | 3 | 
93
The number of path pairs that share no common vertex is 30.
The number of path pairs that share exactly 1 common vertex is 44.
The number of path pairs that share exactly 2 common vertices is 19.
Hence 30 + 44 + 19 = 93 ordered pairs share at most k = 2 vertices.
这种题对于脑细胞的损耗实在有点大，WA 了 10 发才过，实在是有点疲惫。
树分治（点分治），然后每个点统计边的种类，细节处理有点复杂。
#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;
// Unsigned 64-bit counter type. With n up to 88888, the total number of ordered
// path pairs, (n(n+1)/2)^2, is about 1.56e19. That is above the signed 64-bit
// range but below the unsigned one.
typedef unsigned long long LL;
// Lowest set bit of x; the step size of the Fenwick tree used in Tree::find.
// Returns plain `int`: a top-level const on a by-value return type is ignored
// and only triggers -Wignored-qualifiers.
inline int low(int x) { return x & -x; }
// Capacity of the per-vertex and per-edge-slot arrays in Tree
// (the n - 1 undirected edges occupy 2(n - 1) slots).
const int maxn = 300000 + 10;
// Sentinel size larger than any component; stored in Tree::mx[0].
const int INF = 0x7FFFFFFF;
// Input shared with main: T test cases, n vertices, m = k (the maximum number
// of shared vertices allowed), and x/y = endpoints of the edge being read.
int T, n, m, x, y;
// Solver for one test case, based on centroid decomposition of the tree.
// work(1, n) returns the number of ordered pairs of paths whose common part
// has more than m vertices; main() subtracts that from the total pair count.
// NOTE(review): dfs/getdep/get recurse once per tree level, so the recursion
// depth can reach n on a path-shaped tree; a large stack may be required.
struct Tree
{
	// Adjacency lists stored as singly linked lists of edge slots:
	// ft[v] = first slot of vertex v (-1 = none), nt[e] = next slot,
	// u[e] = neighbour reached through slot e, sz = number of slots used.
	int ft[maxn], nt[maxn], u[maxn], sz;
	// vis[v]: 1 once v has been chosen as a centroid (removed from the tree).
	// ct[v]: subtree size computed by the centroid search dfs(x, fa, sum).
	// mx[v]: for a live vertex, the size of the largest part left after removing
	//   v (used by the centroid search). For a removed centroid, work() overwrites
	//   it with the whole-tree number of vertices on its side of the edge leading
	//   into the part currently being processed, so get() and dfs(x, fa) can still
	//   measure whole-tree sizes without crossing removed vertices.
	// mx[0] = INF acts as the "no candidate yet" sentinel.
	int vis[maxn], mx[maxn], ct[maxn];
	// d: Fenwick tree indexed by depth, accumulating D[] of earlier subtrees in find().
	// D[dep]: sum of get()'s per-vertex pair counts over the vertices at depth dep.
	LL d[maxn], D[maxn];
	// Prepares an empty graph on vertices 1..n and clears the removal flags.
	void clear(int n)
	{
		mx[sz = 0] = INF;
		for (int i = 1; i <= n; i++) vis[i] = 0, ft[i] = -1;
	}
	// Adds the undirected edge x-y as two directed slots.
	void AddEdge(int x, int y)
	{
		u[sz] = y;	nt[sz] = ft[x];	ft[x] = sz++;
		u[sz] = x;	nt[sz] = ft[y]; ft[y] = sz++;
	}
	// Centroid search inside the live component of size `sum`.
	// Fills ct[] (subtree sizes with the call root as root) and mx[] for every
	// visited vertex. Returns the vertex of x's subtree with the smallest mx[]
	// (largest remaining part); called on the component root, that is a centroid.
	int dfs(int x, int fa, int sum)
	{
		// ct[x] = 1, mx[x] = 0, and the best candidate y starts at sentinel 0.
		int y = mx[x] = (ct[x] = 1) - 1;
		for (int i = ft[x]; i != -1; i = nt[i])
		{
			if (vis[u[i]] || u[i] == fa) continue;
			int z = dfs(u[i], x, sum);
			ct[x] += ct[u[i]];
			mx[x] = max(mx[x], ct[u[i]]);
			y = mx[y] < mx[z] ? y : z;
		}
		// The part "above" x (towards the root) has sum - ct[x] vertices.
		mx[x] = max(mx[x], sum - ct[x]);
		return mx[x] < mx[y] ? x : y;
	}
	// Largest depth among live vertices reachable from x without passing fa,
	// where x itself has depth `dep`.
	int getdep(int x, int fa, int dep)
	{
		int ans = dep;
		for (int i = ft[x]; i != -1; i = nt[i])
		{
			if (u[i] == fa || vis[u[i]]) continue;
			ans = max(ans, getdep(u[i], x, dep + 1));
		}
		return ans;
	}
	// Walks the live part below x (away from fa). Sizes are whole-tree sizes:
	// a removed neighbour is not traversed but contributes mx[] vertices.
	// With cnt = vertices on x's side (x included) and ans = sum of the squared
	// sizes of the parts hanging below x, cnt*cnt - ans is the number of ordered
	// endpoint pairs (a, b) on x's side whose paths to x share only x itself.
	// That count is added to D[dep].
	// Returns ans when dep == 1 (the centroid itself), otherwise cnt.
	LL get(int x, int fa, int dep)
	{
		LL cnt = 1, ans = 0;
		for (int i = ft[x]; i != -1; i = nt[i])
		{
			if (u[i] == fa) continue;
			LL y = vis[u[i]] ? mx[u[i]] : get(u[i], x, dep + 1);
			cnt += y;	ans += y*y;
		}
		D[dep] += cnt*cnt - ans;
		if (dep == 1) return ans;
		return cnt;
	}
	// For the centroid x, counts ordered path pairs whose common part is a path
	// through x with more than m vertices and whose two ends lie either in two
	// different live subtrees of x, or one at x and one in a subtree. Each such
	// common path is weighted by the product of the pair counts at its two ends.
	LL find(int x)
	{
		LL ans = 0, sum = 0, tot = 0;
		// Depth of the component, with x at depth 1.
		int len = getdep(x, -1, 1);
		// A path through x has at most 2*len - 1 vertices: nothing exceeds m.
		if (len + len <= m + 1) return 0;
		// Sum of the squared whole-tree sizes of all parts around x. The D[]
		// values written as a side effect are reset per subtree below.
		sum = get(x, -1, 1);
		for (int i = 1; i <= len; i++) d[i] = 0;
		for (int i = ft[x]; i != -1; i = nt[i])
		{
			if (vis[u[i]]) continue;
			int y = getdep(u[i], x, 2);
			for (int j = 2; j <= y; j++) D[j] = 0;
			// z = whole-tree size of this side of x; fills D[2..y] for it.
			LL z = get(u[i], x, 2);
			// Pair depth j here with depth k in an earlier subtree: the common
			// path has j + k - 1 vertices, which exceeds m iff k > m + 1 - j.
			// s = Fenwick prefix sum of the earlier weight with k <= m + 1 - j.
			for (int j = 2; j <= y; j++)
			{
				LL s = 0;
				for (int k = min(m + 1 - j, len); k > 0; k -= low(k)) s += d[k];
				ans += D[j] * (tot - s);
			}
			for (int j = 2; j <= y; j++)
			{
				for (int k = j; k <= len; k += low(k)) d[k] += D[j];
				// Common path from x down to depth j has j vertices. At x the two
				// endpoints lie outside this subtree (n - z vertices) and not in the
				// same other part: (n - z)^2 - (sum - z*z) pairs.
				if (j > m) ans += (((LL)n - z)*((LL)n - z) + z*z - sum)*D[j];
				tot += D[j];
			}
		}
		return ans;
	}
	// Whole-tree number of vertices on x's side of the edge (fa, x): live
	// vertices are traversed, removed neighbours contribute mx[].
	int dfs(int x,int fa)
	{
		int cnt=1;
		for (int i=ft[x];i!=-1;i=nt[i])
		{
			if (u[i]==fa) continue;
			cnt+=vis[u[i]]?mx[u[i]]:dfs(u[i],x);
		}
		return cnt;
	}
	// Handles the live component of size `sum` that contains x. It picks the
	// centroid y, counts the pairs centred there, removes y and recurses into
	// every remaining part. Returns the number of ordered path pairs sharing more
	// than m vertices whose common path lies inside this component.
	LL work(int x, int sum)
	{
		int y = dfs(x, -1, sum);
		LL ans = find(y);	vis[y] = 1;
		for (int i = ft[y]; i != -1; i = nt[i])
		{
			if (vis[u[i]]) continue;
			// Whole-tree number of vertices on y's side, as seen from this part.
			mx[y] = n-dfs(u[i],y);
			// ct[] is rooted at x: if u[i] is y's parent there, its part has
			// sum - ct[y] vertices; otherwise it is u[i]'s subtree.
			ans += work(u[i], ct[u[i]] > ct[y] ? sum - ct[y] : ct[u[i]]);
		}
		return ans;
	}
}solve;
// Reads T test cases. Each gives n and m (= k), followed by n - 1 edges.
// For each case, prints total^2 minus the number of ordered path pairs sharing
// more than m vertices, where total = n(n+1)/2 counts every path, including
// single-vertex ones. Stops quietly on truncated or malformed input instead of
// reprocessing stale values.
int main()
{
	if (scanf("%d", &T) != 1) return 0;
	while (T--)
	{
		if (scanf("%d%d", &n, &m) != 2) break;
		LL ans = (LL)n * (n + 1) >> 1;
		solve.clear(n);
		for (int i = 1; i < n; i++)
		{
			// An incomplete edge list would leave the tree disconnected.
			if (scanf("%d%d", &x, &y) != 2) return 0;
			solve.AddEdge(x, y);
		}
		// Computed in unsigned 64-bit: ans*ans can exceed the signed range.
		printf("%llu\n", ans*ans - solve.work(1, n));
	}
	return 0;
}
原文地址:http://blog.csdn.net/jtjy568805874/article/details/51367412