Subway Lines(树上两条路的交点数)

原题: https://cn.vjudge.net/problem/Gym-101908L

题意: 给出一棵树,n节点,每次询问给两对叶子,求这两对叶子产生路径的交集

解析:

找被走过两次的点->走被走过两次的所有Lca,Lca所构成的那一段长度就是点的数量
Subway Lines(树上两条路的交点数)

显然,目标线段的端点一定是这些叶子节点的某个Lca


步骤:

  1. 找到所有Lca,放入set
  2. 统计哪些Lca被走过两次
  3. 怎么判断走过几次:一对叶子(x1,x2)与它的节点l1,假设某个点在x1到l1之间或x2到l2之间,则在路径上
  4. 怎么判断在x1和l1之间:Lca(x1,p)==p&&Lca(p,l1)==l1
  5. 对于走过两次的点,求出组成路径的长度
  6. 怎么求路径:首先是所有点的Lca,作为Root,删除除了最底下点以外的所有点
  7. 怎么删除:if(Lca(p1,p2)==p2)Erase(p2)
#include<bits/stdc++.h>
using namespace std;
#define maxn 500005
int head[maxn<<1],fa[maxn][25],vis[maxn],cnt,dep[maxn],dis[maxn];
struct node
{
    int to,next,wei;
}e[maxn<<1];
void add(int x,int y)
{
    e[cnt].to=y;
    e[cnt].next=head[x];
    head[x]=cnt++;
}
void bfs()
{
    fa[1][0]=1;
    dep[1]=0;
    dis[1]=0;
    queue<int>Q;
    Q.push(1);
    while(!Q.empty())
    {
        int u,v;
        u=Q.front();
        Q.pop();
        for(int i=1;i<=16;i++)
            fa[u][i]=fa[fa[u][i-1]][i-1];
        for(int i=head[u];~i;i=e[i].next)
        {
            v=e[i].to;
            if(v==fa[u][0])
                continue;
            dis[v]=dis[u]+e[i].wei;
            dep[v]=dep[u]+1;
            fa[v][0]=u;
            Q.push(v);
        }
    }
}
int lca(int x,int y){
    if(dep[x]<dep[y])swap(x,y);
    for(int i=16;i>=0;i--)if(dep[fa[x][i]]>=dep[y])x=fa[x][i];
    if(x==y)return x; 
    for(int i=16;i>=0;i--)if(fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i];
    return fa[x][0]; 
}
set<int>st;
map<int,int>mp;
vector<int>fin;
int main()
{
    int n,q;
    scanf("%d%d",&n,&q);
    int l,r;
    memset(head,-1,sizeof(head));
    for(int i=1;i<n;i++)
        scanf("%d%d",&l,&r),add(l,r),add(r,l);
    bfs();
    while(q--)
    {
        st.clear();
        fin.clear();
        mp.clear();
        int a,b,c,d;
        scanf("%d%d%d%d",&a,&b,&c,&d);
        st.insert(lca(a,b)),st.insert(lca(a,c)),st.insert(lca(a,d));
        st.insert(lca(b,c)),st.insert(lca(b,d)); st.insert(lca(c,d));
        int l1=lca(a,b),l2=lca(c,d);
        for(set<int>::iterator it=st.begin();it!=st.end();it++)
        {
            int p=*it;
            if((lca(a,p)==p&&lca(p,l1)==l1)||(lca(b,p)==p&&lca(p,l1)==l1))
                mp[p]++;
        }
        for(set<int>::iterator it=st.begin();it!=st.end();it++)
        {
            int p=*it;
            if((lca(c,p)==p&&lca(p,l2)==l2)||(lca(d,p)==p&&lca(p,l2)==l2))
                mp[p]++;
        }
        int ans=0;
        for(map<int,int>::iterator it=mp.begin();it!=mp.end();it++)
        {
            if(it->second>=2)
                fin.push_back(it->first);
        }
        if(fin.size()==0)
        {
            printf("0\n");
            continue;
        }
        if(fin.size()==1)
        {
            printf("1\n");
            continue;
        }
        int l3,deep=1e9;
        for(int i=0;i<fin.size();i++)
        {
            if(dep[fin[i]]<deep)
                l3=fin[i],deep=dep[fin[i]];
        }
        while(fin.size()>2)
        {
            for(int i=0;i<fin.size()-1;i++)
            {
                for(int j=1;j<fin.size();j++)
                {
                    if(lca(fin[i],fin[j])==fin[i])
                        fin.erase(fin.begin()+i);
                    else if(lca(fin[i],fin[j])==fin[j])
                        fin.erase(fin.begin()+j);
                }
            }
        }
        if(lca(fin[0],fin[1])==fin[0])
            fin.erase(fin.begin());
        else if(lca(fin[0],fin[1])==fin[1])
            fin.erase(fin.begin()+1);
        for(int i=0;i<fin.size();i++)
            ans+=(dep[fin[i]]-dep[l3]);
        ans++;
        printf("%d\n",ans);
    }
    return 0;
}