UOJ#351.新年的叶子
瞎bb
noip全真模拟赛又挂了。。出题人居然又贺了三道原题。。
T3.走向巅峰新年的叶子//原题链接
被出题人魔改之后的题面…T1暴力T2爆蛋,,于是只好来做T3
思路
树的多条直径一定会相交 所以我们用最暴力的做法(去考提高的应该都会吧 先随便选一个点 找到离这个点最远的一些点 作为直径的左端点们 在随便选一个左端点找到与她最远的一些点 也就是右端点们 然后再树上乱搞即可)算出这段区间的左右两个端点
- 如果,所以就会有上面讲的直径的左端点们和右端点们,于是就在叶子结点和一堆端点之间做期望dp!然后发现只会的dp。。再见!!!!
- 第一种情况暴力但第二种情况总会好考虑一些。吧?如果,类似于菊花图,在个叶子结点中选个(最后只能剩下一个直径的端点)简单的暴力计数(求个结点 还有个结点未被染黑 再染黑一个的期望步数是)!!!!!然而考后大佬又给出了反例。。。。
wa的一声就哭了。。这不是要炸的节奏吗??所以上面那个思路是假的
真·思路
反正出题人也搬了原题 所以我也去学(hè)了题解
其实直径还有一个特别好的性质,就是树的每条直径的中点都是在同一个点上的(证明略,形象理解一下就行quq)
- 如果直径的长度是偶数 那么中点一定是在树上的某个点上的 我们只需要把这个点拎到root上 于是几个直径的端点(深度为)就被划分到了几个不同的集合 窝门只需要各个区间求期望就好了
- 如果直径的长度是奇数 那中点不是在树边上了吗??其实没有关系我们假装那有个点就好了 于是类似于第一种情况 但是发现集合只剩下两个了
我们每次都枚举一个集合,算出其他集合全部被染黑需要的期望时间,再把这些期望时间加起来,就相当于全部的点被染黑了(集合数-1)次,所以窝门再把这个期望时间和染黑整个端点的集合的期望时间(集合数-1),这个数就是啦
最后再加一个特别重要的预处理:的逆元
系不系简单粗暴又好打ヽ( ̄▽ ̄)ノ
Code
还有AC代码是从原来的zz代码魔改过来的 奇丑无比 所以大佬别打我
#include <cstdio>
#include <algorithm>
#define MOD 998244353
#define N 500005
using namespace std;
typedef long long LL;
struct Node {
int to, nxt;
}e[N << 1];
int cnt, lst[N], d[N], du[N], st[N], maxi, leaves, tot, d1[N];
LL pre_inv[N];
LL dp[N];
inline void add(int u, int v) {
e[++cnt].to = v;
e[cnt].nxt = lst[u];
lst[u] = cnt;
}
inline LL qui_pow(LL x, int y) {
if (y == 1) return x;
LL t = qui_pow(x, y / 2);
if (y & 1) return t * t % MOD * x % MOD;
else return t * t % MOD;
}
inline void dfs(int x, int fa, int dep) {
d[x] = dep;
if (d[x] > d[maxi]) maxi = x;
for (int i = lst[x]; i; i = e[i].nxt) {
if (e[i].to == fa) continue;
dfs(e[i].to, x, dep + 1);
}
}
inline int countt(int x, int fa, int len) {
if (du[x] == 1 && d[x] == len) return 1;
int sum = 0;
for (int i = lst[x]; i; i = e[i].nxt) {
if (e[i].to == fa) continue;
sum += countt(e[i].to, x, len);
}
return sum;
}
int main() {
int n, u, v, f = 0;
scanf("%d", &n);
for (int i = 1; i < n; ++i) {
scanf("%d%d", &u, &v);
du[u]++;
du[v]++;
add(u, v);
add(v, u);
}
LL inv;
for (int i = 1; i <= n; ++i) {
inv = qui_pow(i, MOD - 2);
pre_inv[i] = (pre_inv[i - 1] + inv) % MOD;
}
for (int i = 1; i <= n; ++i) {
if (du[i] == 1) leaves++;
}
maxi = 0;
dfs(1, 1, 0);
int x = maxi;
maxi = 0;
dfs(x, x, 0);
for (int i = 1; i <= n; ++i) {
d1[i] = d[i];
}
x = maxi;
maxi = 0;
dfs(x, x, 0);
int dia = d[maxi], mid, md, all = 0;
if (dia & 1) {
for (int i = 1; i <= n; ++i) {
if (d[i] == (dia >> 1) && d1[i] == (dia >> 1) + 1) mid = i;
if (d[i] == (dia >> 1) + 1 && d1[i] == (dia >> 1)) md = i;
}
// printf("%d %d\n", mid, md);
dfs(mid, mid, 0);
int num = countt(mid, md, (dia >> 1));
if (num > 0) st[++tot] = num;
all += num;
// printf("%d\n", num);
dfs(md, md, 0);
num = countt(md, mid, (dia >> 1));
if (num > 0) st[++tot] = num;
all += num;
// printf("%d\n", num);
}
else {
for (int i = 1; i <= n; ++i) {
if (d[i] == (dia >> 1) && d1[i] == (dia >> 1)) mid = i;
}
dfs(mid, mid, 0);
for (int i = lst[mid]; i; i = e[i].nxt) {
int num = countt(e[i].to, mid, (dia >> 1));
if (num > 0) st[++tot] = num;
all += num;
}
}
LL ans = 0;
for (int i = 1; i <= tot; ++i) {
ans += pre_inv[all - st[i]];
if (ans >= MOD) ans -= MOD;
}
ans -= 1LL * (tot - 1) * pre_inv[all] % MOD;
if (ans < 0) ans += MOD;
ans = ans * leaves % MOD;
printf("%lld\n", ans);
return 0;
}