JZOJ5954. 【NOIP2018模拟11.5A组】走向巅峰

Description

众所周知,DH是一位人生赢家,他不仅能虐暴全场,而且还正在走向人生巅峰;
在巅峰之路上,他碰到了这一题:
给出一棵n个节点的树,我们每次随机染黑一个叶子节点(可以重复染黑),操作无限次后,这棵树的所有叶子节点必然全部会被染成黑色。
定义R为这棵树不经过黑点的直径,求使R第一次变小期望的步数。

Data Constraint

对于15%的数据,满足n<=10;
对于30%的数据,满足n<=1000;
另外有20%数据,满足树为菊花图;
另外有15%数据,满足每个点度数不超过3;
对于100%的数据,满足n<=5*10^5。

题解

考虑直径的长度,如果是奇数就是过定边,偶数过定点。
那么什么时候直径变化呢,
就是只剩下一个集合的时候。
于是枚举一个集合,
枚举它被染黑的个数,
染黑就可以得到概率了。
JZOJ5954. 【NOIP2018模拟11.5A组】走向巅峰

code

#include<cstdio>
#include<algorithm>
#include<cstring>
#define ll long long
using namespace std;
const int N=500003,mo=998244353;
int n,nxt[N*2],to[N*2],lst[N],x,y,d[N],m,sum,tot;
int mx,id,fa[N],w[N],t;
ll jc[N],ny[N],s[N],ans;
char ch;
void read(int&n)
{
	for(ch=getchar();ch<'0' || ch>'9';ch=getchar());
	for(n=0;'0'<=ch && ch<='9';ch=getchar())n=(n<<1)+(n<<3)+ch-48;
}
void write(int x){if(x>9)write(x/10);putchar(x%10+48);}
int max(int x,int y){return x>y?x:y;}
void ins(int x,int y)
{
	nxt[++tot]=lst[x];
	to[tot]=y;
	lst[x]=tot;
}
void dfs(int x,int len)
{
	if(len>mx)mx=len,id=x;
	for(int i=lst[x];i;i=nxt[i])
		if(to[i]^fa[x])fa[to[i]]=x,dfs(to[i],len+1);
}
int work(int x,int fa,int dep)
{
	if(dep==1)return 1;
	int s=0;
	for(int i=lst[x];i;i=nxt[i])
		if(to[i]^fa)s=s+work(to[i],x,dep-1);
	return s;
}
ll ksm(ll x,int y)
{
	ll s=1;
	for(;y;y>>=1,x=x*x%mo)
		if(y&1)s=s*x%mo;
	return s;
}
ll C(int x,int y)
{
	return jc[y]*ny[x]%mo*ny[y-x]%mo;
}
ll get(int v)
{
	ll S=0;
	for(int i=0;i<v;i++)
		S=(S+C(i,v)*jc[sum-v+i-1]%mo*jc[v-i]%mo*(sum-v)%mo*s[sum-v+i])%mo;
	return S;
}
int main()
{
	freopen("winer.in","r",stdin);
	freopen("winer.out","w",stdout);
	read(n);ny[0]=jc[0]=1;
	for(int i=1;i<=n;i++)jc[i]=jc[i-1]*i%mo;
	ny[n]=ksm(jc[n],mo-2);
	for(int i=n;i;i--)ny[i-1]=ny[i]*i%mo;
	for(int i=1;i<n;i++)
		read(x),read(y),ins(x,y),ins(y,x),d[x]++,d[y]++;
	for(int i=1;i<=n;i++)if(d[i]==1)m++;
	mx=0;dfs(1,0);
	memset(fa,0,sizeof(fa));
	mx=0;dfs(id,0);
	for(int i=1;i<=mx/2;i++)id=fa[id];
	if(mx&1)
	{
		w[t=work(id,fa[id],mx/2+1)]++;
		sum=sum+t;
		w[t=work(fa[id],id,mx/2+1)]++;
		sum=sum+t;
	}
	else
	{
		for(int i=lst[id];i;i=nxt[i])
			w[t=work(to[i],id,mx>>1)]++,sum=sum+t;
	}
	for(int i=1;i<=sum;i++)
		s[i]=(s[i-1]+ksm(sum-i+1,mo-2)*m)%mo;
	for(int i=0;i<=n;i++)if(w[i])ans=(ans+get(i)*w[i])%mo;
	printf("%lld",ans*ksm(jc[sum],mo-2)%mo);
}