[JZOJ6050]【NOI2019模拟2019.3.11】树上四次求和 【数据结构】【树链剖分】【点分治】

Description

[JZOJ6050]【NOI2019模拟2019.3.11】树上四次求和 【数据结构】【树链剖分】【点分治】
[JZOJ6050]【NOI2019模拟2019.3.11】树上四次求和 【数据结构】【树链剖分】【点分治】

Solution

一道很简单的题,自己没有想出来,相当不应该
一开始我先想的假如直接询问w(i,j)怎么做,但是完全没有思路(后来发现询问w(i,j)比这题难的多)

首先那个排列a似乎没什么用,那么以下的表示都忽略掉(把x看做a[x])
先化一波式子,可以得到这个结果
ansk=x=1ky=xkdis(x,y)×x(ky+1)ans_k=\sum\limits_{x=1}^{k}\sum\limits_{y=x}^{k}dis(x,y)\times x(k-y+1)

我们令Si=x=1idis(x,i)×xS_i=\sum\limits_{x=1}^{i}dis(x,i)\times x,假如求出了Si,再做一个Si的前缀和和Si*i的前缀和就可以算出ans_k了,这是线性的。

问题在于如何求Si

一种方法是轻重链剖分
我们将两点距离dis(x,y)dis(x,y)拆成dep[x]+dep[y]2dep[lca]dep[x]+dep[y]-2dep[lca]的形式

那么有Si=x=1i(dep[x]+dep[i]2dep[lca(x,i)])×xS_i=\sum\limits_{x=1}^{i}(dep[x]+dep[i]-2dep[lca(x,i)])\times x
其中dep[x]*x以及dep[i]*x的和都是容易计算的

现在就是要算x=1i2dep[lca(x,i)]×x\sum\limits_{x=1}^{i}-2dep[lca(x,i)]\times x

考虑dep[lca]等于什么,它显然可以看做是lca到根路径上的点的个数
从左到右扫,每扫到一个点i,就在它祖先到根的链上每个节点打上+i的标记,那么后面的点求dep[lca]就只需要祖先到根的路径求和即可,这个就用轻重链剖分+线段树维护。

另一种方法是点分治。
考虑直接计算所有的SiS_i

对于每个分治中心,将分治子树中的所有点拉出来排序,从前向后一个个加,这个是容易计算的。
此时还要减掉同一个子树中的,那么再分别将每个子树同样做一遍减去即可。

Code

#include <bits/stdc++.h>
#define fo(i,a,b) for(int i=a;i<=b;++i)
#define fod(i,a,b) for(int i=a;i>=b;--i)
#define N 200005
#define LL long long
#define mo 998244353
using namespace std;
int fs[N],nt[2*N],dt[2*N],n,q,a[N],m1,top[N],sz[N],son[N],dfn[N],n1,t[N][2],dep[N],ft[N];
LL s1[N],s2[N],sm[N],lz[N];
void link(int x,int y)
{
	nt[++m1]=fs[x];
	dt[fs[x]=m1]=y;
}
void dfs(int k,int fa)
{
	sz[k]=1;
	ft[k]=fa;
	dep[k]=dep[fa]+1;
	for(int i=fs[k];i;i=nt[i])
	{
		int p=dt[i];
		if(p!=fa) 
		{
			dfs(p,k);
			sz[k]+=sz[p];
			if(sz[p]>sz[son[k]]) son[k]=p;
		}
	}	
}
void make(int k,int fa)
{
	dfn[k]=++dfn[0];
	if(son[k]) top[son[k]]=top[k],make(son[k],k);
	for(int i=fs[k];i;i=nt[i])
	{
		int p=dt[i];
		if(p!=fa&&p!=son[k]) top[p]=p,make(p,k);
	}
}
void build(int k,int l,int r)
{
	if(l==r) return;
	int mid=(l+r)>>1;
	t[k][0]=++n1,build(n1,l,mid);
	t[k][1]=++n1,build(n1,mid+1,r);
}
void down(int k,LL le,LL re)
{
	if(lz[k])
	{
		lz[t[k][0]]=(lz[t[k][0]]+lz[k])%mo;
		lz[t[k][1]]=(lz[t[k][1]]+lz[k])%mo;
		sm[t[k][0]]=(sm[t[k][0]]+lz[k]*le)%mo;
		sm[t[k][1]]=(sm[t[k][1]]+lz[k]*re)%mo;
		lz[k]=0;
	}
}
void add(int k,int l,int r,int x,int y,LL v)
{
	if(x>y||y<l||x>r) return;
	if(x<=l&&r<=y) lz[k]=(lz[k]+v)%mo,sm[k]=(sm[k]+(LL)(r-l+1)*v)%mo;
	else
	{
		int mid=(l+r)>>1;
		down(k,mid-l+1,r-mid);
		add(t[k][0],l,mid,x,y,v);
		add(t[k][1],mid+1,r,x,y,v);
		sm[k]=(sm[t[k][0]]+sm[t[k][1]])%mo;
	}
}
LL query(int k,int l,int r,int x,int y)
{
	if(x>y||y<l||x>r||!sm[k]) return 0;
	if(x<=l&&r<=y) return sm[k];
	int mid=(l+r)>>1;
	down(k,mid-l+1,r-mid);
	return (query(t[k][0],l,mid,x,y)+query(t[k][1],mid+1,r,x,y))%mo;
}
void put(int k,LL v)
{
	while(k)
	{
		add(1,1,n,dfn[top[k]],dfn[k],v);
		k=ft[top[k]];
	}
}
LL get(int k)
{
	LL s=0;
	while(k)
	{
		s=(s+query(1,1,n,dfn[top[k]],dfn[k]))%mo;
		k=ft[top[k]];
	}
	return s;
}
int main()
{
	cin>>n>>q;
	bool pd=1;
	fo(i,1,n-1)
	{
		int x,y;
		scanf("%d%d",&x,&y);
		link(x,y),link(y,x);
	}
	fo(i,1,n) scanf("%d",&a[i]);
	dfs(1,0);
	top[1]=1;
	make(1,0);
	n1=1;
	build(1,1,n);
	LL sp=0,si=0;
	fo(i,1,n)
	{
		s1[i]=(si*(LL)dep[a[i]]%mo+sp-(LL)2*get(a[i])+mo+mo)%mo;
		put(a[i],i);
		s2[i]=(s2[i-1]+s1[i]*(LL)i)%mo;
		s1[i]=(s1[i-1]+s1[i])%mo;
		si=(si+i)%mo;
		sp=(sp+(LL)i*(LL)dep[a[i]])%mo;
	}
	fo(i,1,q)
	{
		int x;
		scanf("%d",&x);
		printf("%lld\n",(s1[x]*(LL)(x+1)-s2[x]+mo)%mo);
	}
}