Subway Lines(树上两条路的交点数)
原题: https://cn.vjudge.net/problem/Gym-101908L
题意: 给出一棵树,n节点,每次询问给两对叶子,求这两对叶子产生路径的交集
解析:
找被走过两次的点->走被走过两次的所有Lca,Lca所构成的那一段长度就是点的数量
显然,目标线段的端点一定是这些叶子节点的某个Lca
步骤:
- 找到所有Lca,放入set
- 统计哪些Lca被走过两次
- 怎么判断走过几次:一对叶子(x1,x2)与它的节点l1,假设某个点在x1到l1之间或x2到l2之间,则在路径上
- 怎么判断在x1和l1之间:
Lca(x1,p)==p&&Lca(p,l1)==l1
- 对于走过两次的点,求出组成路径的长度
- 怎么求路径:首先是所有点的Lca,作为Root,删除除了最底下点以外的所有点
- 怎么删除:
if(Lca(p1,p2)==p2)Erase(p2)
#include<bits/stdc++.h>
using namespace std;
#define maxn 500005
int head[maxn<<1],fa[maxn][25],vis[maxn],cnt,dep[maxn],dis[maxn];
struct node
{
int to,next,wei;
}e[maxn<<1];
void add(int x,int y)
{
e[cnt].to=y;
e[cnt].next=head[x];
head[x]=cnt++;
}
void bfs()
{
fa[1][0]=1;
dep[1]=0;
dis[1]=0;
queue<int>Q;
Q.push(1);
while(!Q.empty())
{
int u,v;
u=Q.front();
Q.pop();
for(int i=1;i<=16;i++)
fa[u][i]=fa[fa[u][i-1]][i-1];
for(int i=head[u];~i;i=e[i].next)
{
v=e[i].to;
if(v==fa[u][0])
continue;
dis[v]=dis[u]+e[i].wei;
dep[v]=dep[u]+1;
fa[v][0]=u;
Q.push(v);
}
}
}
int lca(int x,int y){
if(dep[x]<dep[y])swap(x,y);
for(int i=16;i>=0;i--)if(dep[fa[x][i]]>=dep[y])x=fa[x][i];
if(x==y)return x;
for(int i=16;i>=0;i--)if(fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
set<int>st;
map<int,int>mp;
vector<int>fin;
int main()
{
int n,q;
scanf("%d%d",&n,&q);
int l,r;
memset(head,-1,sizeof(head));
for(int i=1;i<n;i++)
scanf("%d%d",&l,&r),add(l,r),add(r,l);
bfs();
while(q--)
{
st.clear();
fin.clear();
mp.clear();
int a,b,c,d;
scanf("%d%d%d%d",&a,&b,&c,&d);
st.insert(lca(a,b)),st.insert(lca(a,c)),st.insert(lca(a,d));
st.insert(lca(b,c)),st.insert(lca(b,d)); st.insert(lca(c,d));
int l1=lca(a,b),l2=lca(c,d);
for(set<int>::iterator it=st.begin();it!=st.end();it++)
{
int p=*it;
if((lca(a,p)==p&&lca(p,l1)==l1)||(lca(b,p)==p&&lca(p,l1)==l1))
mp[p]++;
}
for(set<int>::iterator it=st.begin();it!=st.end();it++)
{
int p=*it;
if((lca(c,p)==p&&lca(p,l2)==l2)||(lca(d,p)==p&&lca(p,l2)==l2))
mp[p]++;
}
int ans=0;
for(map<int,int>::iterator it=mp.begin();it!=mp.end();it++)
{
if(it->second>=2)
fin.push_back(it->first);
}
if(fin.size()==0)
{
printf("0\n");
continue;
}
if(fin.size()==1)
{
printf("1\n");
continue;
}
int l3,deep=1e9;
for(int i=0;i<fin.size();i++)
{
if(dep[fin[i]]<deep)
l3=fin[i],deep=dep[fin[i]];
}
while(fin.size()>2)
{
for(int i=0;i<fin.size()-1;i++)
{
for(int j=1;j<fin.size();j++)
{
if(lca(fin[i],fin[j])==fin[i])
fin.erase(fin.begin()+i);
else if(lca(fin[i],fin[j])==fin[j])
fin.erase(fin.begin()+j);
}
}
}
if(lca(fin[0],fin[1])==fin[0])
fin.erase(fin.begin());
else if(lca(fin[0],fin[1])==fin[1])
fin.erase(fin.begin()+1);
for(int i=0;i<fin.size();i++)
ans+=(dep[fin[i]]-dep[l3]);
ans++;
printf("%d\n",ans);
}
return 0;
}