【JZOJ6053】Mas的仙人掌
Description
Solution
直接计算每条非树边合法出现的概率,它的贡献就是它的两端点之间的路径与其它非树边的路径没有边相交的概率。
考虑如何计算,首先乘上这条边不脱落的概率,然后乘上所有其它与其有交的非树边路径脱落的概率。可以在树上打标记,每条路径的标记如何只算一次?注意到一条路径的边数减去长度为2的简单路径数为1,我们可以分别打上标记,注意lca处的路径为2的数目要统计一下。
Code
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<map>
#define fo(i,j,k) for(int i=j;i<=k;++i)
#define fd(i,j,k) for(int i=j;i>=k;--i)
#define rep(i,x) for(int i=ls[x];i;i=nx[i])
using namespace std;
typedef long long ll;
typedef pair<int,int> pr;
const int N=1e6+10,M=2e6+10,mo=998244353;
int to[M],nx[M],ls[N],vl[M],num=0;
void link(int u,int v){
to[++num]=v,nx[num]=ls[u],ls[u]=num;
}
int read(){
char ch=' ';int t=0;
for(;ch<'0' || ch>'9';ch=getchar());
for(;ch>='0' && ch<='9';ch=getchar()) t=(t<<1)+(t<<3)+ch-48;
return t;
}
int pow(int x,int y){
int s=1;
for(;y;y>>=1,x=(ll)x*x%mo) if(y&1) s=(ll)s*x%mo;
return s;
}
int lg[N],dep[N];
int f[N][21],tot=0;
struct node{
int u,v,w;
int lc,u1,v1;
}e[N];
struct P{
int x,y;
P(int _x=1,int _y=0) {x=_x,y=_y;}
int f() {return y==0?x:0;}
}f1[N],f2[N];
P operator *(P x,P y) {return P((ll)x.x*y.x%mo,x.y+y.y);}
P operator /(P x,P y) {return P((ll)x.x*pow(y.x,mo-2)%mo,x.y-y.y);}
P operator *(P x,int y){
return !y?P(x.x,x.y+1):P((ll)x.x*y%mo,x.y);
}
P operator /(P x,int y){
return !y?P(x.x,x.y-1):P((ll)x.x*pow(y,mo-2)%mo,x.y);
}
map<pr,P> mp;
int z[N];
void pre(int x,int fr){
f[x][0]=fr,dep[x]=dep[fr]+1;
fo(i,1,lg[dep[x]]) f[x][i]=f[f[x][i-1]][i-1];
z[++tot]=x;
rep(i,x){
int v=to[i];
if(v==fr) continue;
pre(v,x);
}
}
int u1,v1;
int lca(int u,int v){
int t=lg[dep[u]];
fd(i,t,0) if(dep[f[u][i]]>dep[v]) u=f[u][i];
u1=v1=u;
if(dep[u]>dep[v]) u=f[u][0];
if(u==v) return u;
fd(i,t,0) if(f[u][i]!=f[v][i]) u=f[u][i],v=f[v][i];
u1=u,v1=v;
return f[u][0];
}
void make(int x,int y,int w){
if(x>y) swap(x,y);
pr t=make_pair(x,y);
if(!mp.count(t)) mp[t]=P(1,0);
mp[t]=mp[t]/w;
}
P get(int x,int y){
if(x>y) swap(x,y);
pr t=make_pair(x,y);
if(!mp.count(t)) return P(1,0);
return mp[t];
}
int main()
{
freopen("cactus.in","r",stdin);
freopen("cactus.out","w",stdout);
int n=read(),m=read();
fo(i,2,n){
int u=read(),v=read();
link(u,v),link(v,u);
}
fo(i,2,n) lg[i]=lg[i>>1]+1;
pre(1,0);
fo(i,1,m){
int u=read(),v=read(),w=read();
if(dep[u]<dep[v]) swap(u,v);
e[i].u=u,e[i].v=v,e[i].w=w;
int lc=lca(u,v);
e[i].u1=u1,e[i].v1=v1,e[i].lc=lc;
f1[u]=f1[u]*w,f1[v]=f1[v]*w,f1[lc]=f1[lc]/w/w;
if(lc!=v) f2[u]=f2[u]/w,f2[v]=f2[v]/w,f2[u1]=f2[u1]*w,f2[v1]=f2[v1]*w,make(u1,v1,w);
else f2[u]=f2[u]/w,f2[u1]=f2[u1]*w;
}
fd(i,n,2){
int x=z[i],fr=f[x][0];
f1[fr]=f1[fr]*f1[x],f2[fr]=f2[fr]*f2[x];
}
fo(i,2,n){
int x=z[i],fr=f[x][0];
f1[x]=f1[fr]*f1[x],f2[x]=f2[fr]*f2[x];
}
ll ans=0;
fo(i,1,m){
node p=e[i];
P t=f1[p.u]*f1[p.v]/f1[p.lc]/f1[p.lc];
if(p.lc!=p.v) t=t*f2[p.u]/f2[p.u1]*f2[p.v]/f2[p.v1]*get(p.u1,p.v1);
else t=t*f2[p.u]/f2[p.u1];
t=t/p.w;
ans=(ans+(ll)t.f()*(mo+1-p.w))%mo;
}
printf("%lld",ans);
}