BZOJ 4034

树链剖分板子题。。也可以用 DFS序来做

当然无论是 DFS序还是树链都是一个思路

把当前的 树形数据结构转换为 线性数据结构

DFS序更加明显 我们根据先序遍历方法 挨个记录访问的点,当然离开这个点的时候也要记录,可以看作记录先序和后序

举例 BZOJ 4034

那么 DFS序 就是   24771133465562

这样 我们找到相同两个数字之间的全部数字 就是一颗子树了

而对于这道题的 区间修改 区间求和 和单点修改,直接线段树优化,就能成功解决了

当然 这并不是重头戏

接下来我们来深入了解下树链剖分

 

首先 为什么要树链剖分 ?

假设 我们直接的去求任意树子树权值和。 如果暴力的话 肯定 O( n ) ,但是 我们有没有想过,再遍历儿子的过程中,理论上总会有一段是直接到底部叶子节点的,这条链就是一个线性的,如果可以对这个线性结构线段树优化求和 肯定是非常不错的,但是问题也来了,如果去确定这个链子,以什么条件去找出这种可优化结构,是比较困难的。

那么树链剖分就是个非常不错的东西了。(就是找一个规律去对树的儿子进行区分)

以下取自大犇 XLightGod PPT 中讲解:

重儿子:其父亲的子树大小最大的儿子

重边:重儿子与其父亲的连边

轻儿子:如果一个点不是其父亲的重儿子,则其为轻儿子

轻边:轻儿子与其父亲的连边

重链:重边组成的链 一个点到根的路径上所经过的重链和轻边数量均不超过 logn 。

按照重儿子优先的顺序进行 dfs 并依次编号,那么一条重链上的点编 号连续且随深度递增。按照 dfs 序建立线段树。 树上的任何链均可拆分成 O(logn) 条轻边与重链的组合,即各种链上 的查询和修改都只需要在线段树中进行 O(logn) 次操作。 由于是 dfs 序,对于子树操作也同样支持。
实现方法:

首先进行一次 dfs,求出每个点的深度、大小和重儿子等信息。 再按照重儿子优先的顺序进行一次 dfs,确定每个点在线段树中的编号 及所在的重链的顶端结点 top。初始化线段树。(如果是边上的信息, 让每个点表示其父边即可) 如何确定一条链对应哪些重链和轻边? 每次让 top 深度更大的点(注意不是该点本身深度更大)向上跳,如果 其是轻儿子则跳一条轻边,否则跳到所在重链的顶端,沿途进行修改 或者查询操作,直到两个点在同一条重链中。最后处理剩下的这一段。
可以看下模板代码

void dfs1(int x)
{ 
    sz[x]=1; // sz 保存子树大小
    for(int i=b[x];i;i=nxt[i]) // 前向星链表
    { 
        if(to[i]==f[x])continue; 
        f[to[i]]=x;  //确保不会返祖
        dep[to[i]]=dep[x]+1; //深度增加
        dfs1(to[i]); 
        sz[x]+=sz[to[i]]; //更新树的大小
        if(sz[to[i]]>sz[son[x]])
            son[x]=to[i]; //更新重儿子
    } 
} 
void dfs2(int x,int y)
{ 
    top[x]=y; //用 top 来记录重链起点   
    id[x]=++tot; //更新点的编号
    if(son[x])
        dfs2(son[x],y); //重儿子优先 dfs
    for(int i=b[x];i;i=nxt[i])//之后正常 dfs 开始看轻儿子
    { 
        if(to[i]==f[x]||to[i]==son[x])continue; 
        dfs2(to[i],to[i]); 
    } 
} 
int query(int x,int y)//查询操作 (此处指树上两点距离)
{ 
    int ans=0; 
    while(top[x]!=top[y]) //同祖判定
    { 
        if(dep[top[x]]<dep[top[y]]) swap(x,y); 
        ans+=get(id[top[x]],id[x]);  //线段树求和部分
        x=f[top[x]]; 
    } 
    if(dep[x]<dep[y])swap(x,y); 
    ans+=get(id[y],id[x]); 
    return ans; 
}  

然后对于本题,树链剖分部分 AC代码 如下

#include<bits/stdc++.h>
using namespace std;
#define ll long long
inline ll read()
{
    ll x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
int n,m;

int id,pos[100005],mx[100005],v[100005];

int head[100005],tot;
struct edge
{
	int to,nxt;
}e[200005];
void add(int u,int v)
{
	e[++tot] = (edge){v,head[u]}; head[u]=tot;
	e[++tot] = (edge){u,head[v]}; head[v]=tot;
}
int top[100005],size[100005],fa[100005];  
void dfs(int x)
{
	size[x]=1;
	for(int i=head[x];i;i=e[i].nxt)
		if(e[i].to != fa[x])
		{
			fa[e[i].to] = x;
			dfs(e[i].to);
			size[x] += size[e[i].to];
			mx[x] = max(mx[x],mx[e[i].to]);
		}
}
void dfs2(int x,int cha)
{
	top[x]=cha;pos[x]=mx[x]=++id;
	int k=0;
	for(int i=head[x];i;i=e[i].nxt)
		if(e[i].to!=fa[x]&&size[e[i].to]>size[k])
			k=e[i].to;
	if(k){ dfs2(k,cha); mx[x]=max(mx[x],mx[k]); }
	for(int i=head[x];i;i=e[i].nxt)
		if(e[i].to!=fa[x]&&e[i].to!=k)
		{
			dfs2(e[i].to,e[i].to);
			mx[x]=max(mx[x],mx[e[i].to]);
		}
}
// 线段树
#define root 1,1,n
#define ls k<<1,l,mid
#define rs k<<1|1,mid+1,r
ll tag[400005],sum[400005];
void pushdown(int l,int r,int k)
{
	if(l==r) return;
	int mid=(l+r)>>1;
	ll t=tag[k]; tag[k]=0;
	tag[k<<1] += t;
	tag[k<<1|1] += t;
	sum[k<<1] += t*(mid-l+1);
	sum[k<<1|1] += t*(r-mid);
}
void insert(int k,int l,int r,int x,int y,ll val)
{
	if(tag[k])pushdown(l,r,k);
	if(l==x&&y==r)
    {
        tag[k]+=val;sum[k]+=(r-l+1)*val;
        return;
    }
	int mid=(l+r)>>1;
	if(x<=mid)
        insert(ls,x,min(mid,y),val);
	if(y>=mid+1)
        insert(rs,max(mid+1,x),y,val);
	sum[k]=sum[k<<1]+sum[k<<1|1];
}
ll query(int k,int l,int r,int x,int y)
{
	if(tag[k])pushdown(l,r,k);
	if(l==x&&y==r)
        return sum[k];
	int mid=(l+r)>>1;
	ll ans=0;
	if(x<=mid)
		ans+=query(ls,x,min(mid,y));
	if(y>=mid+1)
		ans+=query(rs,max(mid+1,x),y);
	return ans;
}
//
ll query(int x)
{
	ll ans=0;
	while(top[x]!=1)
	{
		ans+=query(root,pos[top[x]],pos[x]);
		x=fa[top[x]];
	}
	ans+=query(root,1,pos[x]);
	return ans;
}
int main()
{
	n=read(); m=read();
	for(int i=1;i<=n;i++)
        v[i]=read();
	for(int i=1;i<n;i++)
	{
		int u=read(),v=read();
		add(u,v);
	}
	dfs(1);
	dfs2(1,1);
	for(int i=1;i<=n;i++)
		insert(root,pos[i],pos[i],v[i]);
	int opt,x,a;
	while(m--)
	{
		opt=read();x=read();
		if(opt==1)
		{
			a=read();insert(root,pos[x],pos[x],a);
		}
		if(opt==2)
		{
			a=read();insert(root,pos[x],mx[x],a);
		}
		if(opt==3)printf("%lld\n",query(x));
	}
	return 0;
}