虚树

题目链接

分析:

  • 维护子树内有多少关键点,最长链和最短链进行转移

代码:

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
#define N 1000050
#define inf 0x3f3f3f3f
int head[N],to[N<<1],nxt[N<<1],cnt,n,m,fa[N],top[N],dep[N],son[N],siz[N];
int a[N],la,dfn[N],S[N],tp,vis[N];
char buf[1000000],*p1,*p2;
#define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1000000,stdin),p1==p2)?EOF:*p1++)
int rd() {
    int x=0; char s=nc();
    while(s<'0'||s>'9') s=nc();
    while(s>='0'&&s<='9') x=(((x<<2)+x)<<1)+s-'0',s=nc();
    return x;
}
inline void add(int u,int v) {
    to[++cnt]=v; nxt[cnt]=head[u]; head[u]=cnt;
}
void df1(int x,int y) {
    int i; siz[x]=1; dfn[x]=++dfn[0];
    fa[x]=y; dep[x]=dep[y]+1;
    for(i=head[x];i;i=nxt[i]) if(to[i]!=y) {
        df1(to[i],x);
        siz[x]+=siz[to[i]];
        if(siz[to[i]]>siz[son[x]]) son[x]=to[i];
    }
}
void df2(int x,int t) {
    int i; top[x]=t;
    if(son[x]) df2(son[x],t);
    for(i=head[x];i;i=nxt[i]) if(to[i]!=fa[x]&&to[i]!=son[x]) df2(to[i],to[i]);
}
ll ans1,sz[N];
int f[N],g[N],ans2,ans3;
void df3(int x) {
    int i;
    if(vis[x]) {
        sz[x]=1;
        f[x]=g[x]=0;
    }else {
        sz[x]=0;
        f[x]=inf;
        g[x]=-inf;
    }
    for(i=head[x];i;i=nxt[i]) {
        df3(to[i]);
        int len=dep[to[i]]-dep[x];
        ans1+=ll(len)*sz[to[i]]*(la-sz[to[i]]);
        ans2=min(ans2,len+f[to[i]]+f[x]);
        ans3=max(ans3,len+g[to[i]]+g[x]);
        f[x]=min(f[x],f[to[i]]+len);
        g[x]=max(g[x],g[to[i]]+len);
        sz[x]+=sz[to[i]];
    }
    head[x]=0;
}
int lca(int x,int y) {
    for(;top[x]!=top[y];y=fa[top[y]]) if(dep[top[x]]>dep[top[y]]) swap(x,y);
    return dep[x]<dep[y]?x:y;
}
inline bool cmp(const int &x,const int &y) {return dfn[x]<dfn[y];}
int main() {
    n=rd();
    int i,x,y,l;
    for(i=1;i<n;i++) {
        x=rd(); y=rd();
        add(x,y); add(y,x);
    }
    df1(1,0); df2(1,1);
    m=rd();
    memset(head,0,sizeof(head)); cnt=0;
    while(m--) {
        cnt=0;
        la=rd();
        for(i=1;i<=la;i++) a[i]=rd(),vis[a[i]]=1;
        sort(a+1,a+la+1,cmp);
        tp=0;
        S[++tp]=1;
        for(i=1;i<=la;i++) {
            x=a[i],l=lca(x,S[tp]);
            while(dep[l]<dep[S[tp]]) {
                if(dep[l]>=dep[S[tp-1]]) {
                    add(l,S[tp]); tp--;
                    if(S[tp]!=l) S[++tp]=l;
                    break;
                }
                add(S[tp-1],S[tp]); tp--;
            }
            if(S[tp]!=x) S[++tp]=x;
        }
        while(tp>1) add(S[tp-1],S[tp]),tp--;
        ans1=0; ans3=0; ans2=1ll<<30;
        df3(1);
        printf("%lld %d %d\n",ans1,ans2,ans3);
        for(i=1;i<=la;i++) vis[a[i]]=0;
    }
}

题目链接

分析:

  • 建出虚树,对于虚树上的一条边,考虑边上的一个非关键点。
  • 这个点子树内不在虚树上的点和这个点被相同的点管理。
  • 先求出每个虚树上的点被谁管理,然后对于虚树上的一条边,倍增找一个分界点。
  • 容斥一下,将每个点的答案设成虚树上被这个点管理的深度最小的点的子树大小,然后分析每条边时减掉不合法的。

代码:

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cstdlib>
using namespace std;
#define N 300050
int head[N],to[N<<1],nxt[N<<1],cnt,n,m,a[N],vis[N],la;
int fa[N],dep[N],son[N],siz[N],top[N],dfn[N],f[20][N],b[N],ans[N];
int S[N],tp,wv[N];
inline void add(int u,int v) {
    to[++cnt]=v; nxt[cnt]=head[u]; head[u]=cnt;
}
void df1(int x,int y) {
    int i;
    siz[x]=1; fa[x]=y; dep[x]=dep[y]+1; dfn[x]=++dfn[0];
    for(i=head[x];i;i=nxt[i]) if(to[i]!=y) {
        df1(to[i],x); siz[x]+=siz[to[i]];
        if(siz[to[i]]>siz[son[x]]) son[x]=to[i];
    }
}
void df2(int x,int t) {
    int i;
    top[x]=t;
    if(son[x]) df2(son[x],t);
    for(i=head[x];i;i=nxt[i]) if(to[i]!=fa[x]&&to[i]!=son[x]) df2(to[i],to[i]);
}
int lca(int x,int y) {
    for(;top[x]!=top[y];y=fa[top[y]]) if(dep[top[x]]>dep[top[y]]) swap(x,y);
    return dep[x]<dep[y]?x:y;
}
inline bool cmp(const int &x,const int &y) {return dfn[x]<dfn[y];}
inline bool cmp2(int x,int y) {
    return dep[x]==dep[y]?x < y:dep[x]<dep[y];
}
int dis(int x,int y) {
    int l=lca(x,y);
    return dep[x]+dep[y]-2*dep[l];
}
inline bool cmp3(int x,int y,int z) {
    int l1=dis(x,z),l2=dis(y,z);
    return l1 == l2 ? x < y : l1 < l2;
}
void df3(int x) {
    int i;
    wv[x]=0;
    for(i=head[x];i;i=nxt[i]) {
        df3(to[i]);
        if(wv[to[i]] && (!wv[x]||cmp2(wv[to[i]],wv[x]))) {
            wv[x]=wv[to[i]];
        }
    }
    if(vis[x]) {
        wv[x]=x;
    }
}
void df4(int x,int y) {
    int i;
    if(wv[y] && (!wv[x] || cmp3(wv[y],wv[x],x))) wv[x]=wv[y];
    //ans[wv[x]]+=siz[x];
    for(i=head[x];i;i=nxt[i]) {
        df4(to[i],x);
    }
    ans[wv[x]]=siz[x];
}
int jmp(int x,int y) {
    int i;
    for(i=19;i>=0;i--) {
        if(f[i][x]&&dep[f[i][x]]>dep[y]) x=f[i][x];
    }
    return x;
}
int jmp_half(int x,int y) {
    int t=x;
    int i;
    for(i=19;i>=0;i--) {
        if(f[i][x]&&cmp3(wv[t],wv[y],f[i][x])) x=f[i][x];
    }
    return x;
}
void df5(int x) {
    int i;
    for(i=head[x];i;i=nxt[i]) {
        if(wv[x]==wv[to[i]]) {

        }else {
            int h=jmp_half(to[i],x);
            ans[wv[x]] -= siz[h];
            ans[wv[to[i]]] += siz[h] - siz[to[i]];
        }
        df5(to[i]);
    }

    head[x]=0;
}
void prt(int x) {
    int i;
    printf("x=%d wv=%d\n",x,wv[x]);
    for(i=head[x];i;i=nxt[i]) {
        prt(to[i]);
    }
}
int main() {
    scanf("%d",&n);
    int i,x,y,j;
    for(i=1;i<n;i++) scanf("%d%d",&x,&y),add(x,y),add(y,x);
    df1(1,0); df2(1,1);
    for(i=1;i<=n;i++) f[0][i]=fa[i];
    for(i=1;(1<<i)<=n;i++) for(j=1;j<=n;j++) f[i][j]=f[i-1][f[i-1][j]];
    memset(head,0,sizeof(head)); cnt=0;
    scanf("%d",&m);
    while(m--) {
        scanf("%d",&la); cnt=0;
        for(i=1;i<=la;i++) scanf("%d",&a[i]),b[i]=a[i],ans[a[i]]=0,vis[a[i]]=1;
        sort(a+1,a+la+1,cmp);
        S[tp=1]=1; 
        for(i=1;i<=la;i++) {
            x=a[i],y=lca(x,S[tp]);
            while(dep[y]<dep[S[tp]]) {
                if(dep[y]>=dep[S[tp-1]]) {
                    add(y,S[tp]); tp--;
                    if(S[tp]!=y) S[++tp]=y;
                    break;
                }
                add(S[tp-1],S[tp]); tp--;
            }
            if(S[tp]!=x) S[++tp]=x;
        }
        while(tp>1) add(S[tp-1],S[tp]),tp--;
        df3(1); df4(1,vis[1] ? 1 : 0);
        //prt(1);
        
        df5(1);
        for(i=1;i<=la;i++) printf("%d ",ans[b[i]]); 
        puts("");
        for(i=1;i<=la;i++) vis[a[i]]=0;
    }
}
/*
10
2 1
3 2
4 3
5 4
6 1
7 3
8 3
9 4
10 1
1
5
2 7 3 6 9
*/

题目链接

重写一遍,贴代码。

代码:

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define N 250050
typedef long long ll;
const ll inf = 1ll<<60;
int head[N],to[N<<1],nxt[N<<1],val[N<<1],cnt,n,m;
int fa[N],top[N],dep[N],son[N],siz[N],a[N],la,vis[N];
int dfn[N],S[N],tp;
ll w[N],f[N];
char buf[1000000],*p1,*p2;
#define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1000000,stdin),p1==p2)?EOF:*p1++)
int rd() {
    int x=0; char s=nc();
    while(s<'0'||s>'9') s=nc();
    while(s>='0'&&s<='9') x=(((x<<2)+x)<<1)+s-'0',s=nc();
    return x;
}
inline void add(int u,int v,int w) {
    to[++cnt]=v; nxt[cnt]=head[u]; head[u]=cnt; val[cnt]=w;
}
inline void Add(int u,int v) {to[++cnt]=v; nxt[cnt]=head[u]; head[u]=cnt;}
inline bool cmp(const int &x,const int &y) {return dfn[x]<dfn[y];}
void df1(int x,int y) {
    int i; siz[x]=1; fa[x]=y; dep[x]=dep[y]+1; dfn[x]=++dfn[0];
    for(i=head[x];i;i=nxt[i]) if(to[i]!=y) {
        w[to[i]]=min(w[x],(ll)val[i]); df1(to[i],x); siz[x]+=siz[to[i]];
        if(siz[to[i]]>siz[son[x]]) son[x]=to[i];
    }
}
void df2(int x,int t) {
    int i; top[x]=t;
    if(son[x]) df2(son[x],t);
    for(i=head[x];i;i=nxt[i]) if(to[i]!=fa[x]&&to[i]!=son[x]) df2(to[i],to[i]);
}
int lca(int x,int y) {
    for(;top[x]!=top[y];y=fa[top[y]]) if(dep[top[x]]>dep[top[y]]) swap(x,y);
    return dep[x]<dep[y]?x:y;
}
void df3(int x) {
    f[x]=0;
    int i;
    for(i=head[x];i;i=nxt[i]) {
        df3(to[i]);
        f[x]+=min(f[to[i]],w[to[i]]);
    }
    if(vis[x]) f[x]=w[x];
    head[x]=0;
}
int main() {
    n=rd();
    int i,x,y,z;
    for(i=1;i<n;i++) x=rd(),y=rd(),z=rd(),add(x,y,z),add(y,x,z);
    w[1]=inf; df1(1,0); df2(1,1);
    m=rd();
    memset(head,0,sizeof(head));
    while(m--) {
        la=rd();
        cnt=0;
        for(i=1;i<=la;i++) a[i]=rd(),vis[a[i]]=1;
        S[tp=1]=1;
        sort(a+1,a+la+1,cmp);
        for(i=1;i<=la;i++) {
            x=a[i],y=lca(x,S[tp]);
            while(dep[y]<dep[S[tp]]) {
                if(dep[y]>=dep[S[tp-1]]) {
                    Add(y,S[tp]); tp--;
                    if(S[tp]!=y) S[++tp]=y;
                    break;
                }
                Add(S[tp-1],S[tp]); tp--;
            }
            if(S[tp]!=x) S[++tp]=x;
        }
        while(tp>1) Add(S[tp-1],S[tp]),tp--;
        df3(1);
        printf("%lld\n",f[1]);
        for(i=1;i<=la;i++) vis[a[i]]=0;
    }
}

题目链接

分析:

  • 可以发现有用的点都在虚树上,进行树形DP。
  • \(f[x]\)表示全部堵住,\(g[x]\)表示还剩一个上来。

代码:

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define N 100050
typedef long long ll;
ll f[N],g[N];
int head[N],to[N<<1],nxt[N<<1],cnt,n,m;
int fa[N],dep[N],top[N],son[N],siz[N];
int S[N],a[N],la,tp,vis[N],dfn[N];
inline void add(int u,int v) {
    to[++cnt]=v; nxt[cnt]=head[u]; head[u]=cnt;
}
void df1(int x,int y) {
    int i; siz[x]=1; fa[x]=y; dep[x]=dep[y]+1; 
    dfn[x]=++dfn[0];
    for(i=head[x];i;i=nxt[i]) if(to[i]!=y) {
        df1(to[i],x); siz[x]+=siz[to[i]];
        if(siz[to[i]]>siz[son[x]]) son[x]=to[i];
    }
}
void df2(int x,int t) {
    top[x]=t;
    if(son[x]) df2(son[x],t);
    int i;
    for(i=head[x];i;i=nxt[i]) if(to[i]!=fa[x]&&to[i]!=son[x]) df2(to[i],to[i]);
}
int lca(int x,int y) {
    for(;top[x]!=top[y];y=fa[top[y]]) if(dep[top[x]]>dep[top[y]]) swap(x,y);
    return dep[x]<dep[y]?x:y;
}
inline bool cmp(const int &x,const int &y) {return dfn[x]<dfn[y];}
void df3(int x) {
    int i;
    ll sf=0,smn=0,mx=0,ss=0;
    for(i=head[x];i;i=nxt[i]) {
        df3(to[i]);
        sf+=f[to[i]];
        smn+=min(f[to[i]],g[to[i]]);
        mx=max(mx,f[to[i]]-g[to[i]]);
        if(dep[to[i]]-dep[x]>1) {
            ss+=min(f[to[i]],g[to[i]]+1);
        }else ss+=f[to[i]];
    }
    if(vis[x]) {
        f[x]=int(0x3f3f3f3f);
        g[x]=ss;
    }else {
        f[x]=min(ss,smn+1);
        g[x]=sf-mx;
    }
    head[x]=0;
}
int main() {
    scanf("%d",&n);
    int i,x,y;
    for(i=1;i<n;i++) scanf("%d%d",&x,&y),add(x,y),add(y,x);
    df1(1,0); df2(1,1);
    memset(head,0,sizeof(head)); cnt=0;
    scanf("%d",&m);
    while(m--) {
        scanf("%d",&la); cnt=0;
        for(i=1;i<=la;i++) scanf("%d",&a[i]),vis[a[i]]=1;
        int flg=0;
        for(i=1;i<=la;i++) {
            if(vis[fa[a[i]]]) {
                flg=1; break;
            }
        }
        if(flg) {
            puts("-1");
            for(i=1;i<=la;i++) vis[a[i]]=0;
            continue;
        }
        sort(a+1,a+la+1,cmp);
        S[tp=1]=1;
        for(i=1;i<=la;i++) {
            x=a[i],y=lca(x,S[tp]);
            while(dep[y]<dep[S[tp]]) {
                if(dep[y]>=dep[S[tp-1]]) {
                    add(y,S[tp]); tp--;
                    if(S[tp]!=y) S[++tp]=y;
                    break;
                }
                add(S[tp-1],S[tp]); tp--;
            }
            if(S[tp]!=x) S[++tp]=x;
        }
        while(tp>1) add(S[tp-1],S[tp]),tp--;
        df3(1);
        printf("%lld\n",min(f[1],g[1]));
        for(i=1;i<=la;i++) vis[a[i]]=0;
    }
}

题目链接

分析:

  • 答案即圆方树所有子树内外都有关键点的圆点数量。
  • 每次建虚树然后统计每条边的答案即可。
  • 这道题有非虚树做法。

代码:

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cstdlib>
#include <vector>
using namespace std;
#define N 400050
#define M 400050
#define mem(x) memset(x,0,sizeof(x))
int head[N],to[M],nxt[M],n,m,cnt=1;
int dfn[N],low[N],vis[M],S[N],tp,bl[N],bcc;
int siz[N],fa[N],dep[N],top[N],son[N];
int a[N],la,dis[N],ans;
char buf[100000],*p1,*p2;
#define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++)
int rd() {
    int x=0; char s=nc();
    while(s<'0'||s>'9') s=nc();
    while(s>='0'&&s<='9') x=(((x<<2)+x)<<1)+s-'0',s=nc();
    return x;
}
vector<int>V[N];
inline void add(int u,int v) {
    to[++cnt]=v; nxt[cnt]=head[u]; head[u]=cnt;
}
void tarjan(int x) {
    int i;
    dfn[x]=low[x]=++dfn[0];
    for(i=head[x];i;i=nxt[i]) if(!vis[i]) {
        vis[i]=vis[i^1]=1; S[++tp]=i;
        if(!dfn[to[i]]) {
            tarjan(to[i]);
            low[x]=min(low[x],low[to[i]]);
            if(low[to[i]]>=dfn[x]) {
                int t=0,u,v; bcc++; V[bcc].clear();
                while(t!=i) {
                    t=S[tp--];
                    u=to[t],v=to[t^1];
                    if(bl[u]!=bcc) bl[u]=bcc,V[bcc].push_back(u);
                    if(bl[v]!=bcc) bl[v]=bcc,V[bcc].push_back(v);
                }
            }
        }else low[x]=min(low[x],dfn[to[i]]);
    }
}
inline bool cmp(const int &x,const int &y) {return dfn[x]<dfn[y];}
void df1(int x,int y) {
    dfn[x]=++dfn[0];
    int i; siz[x]=1; fa[x]=y; dep[x]=dep[y]+1; son[x]=0;
    dis[x]=dis[y]+(x<=n);
    for(i=head[x];i;i=nxt[i]) if(to[i]!=y) {
        df1(to[i],x); siz[x]+=siz[to[i]];
        if(siz[to[i]]>siz[son[x]]) son[x]=to[i];
    }
}
void df2(int x,int t) {
    top[x]=t;
    if(son[x]) df2(son[x],t);
    int i;
    for(i=head[x];i;i=nxt[i]) if(to[i]!=fa[x]&&to[i]!=son[x]) df2(to[i],to[i]);
}
int lca(int x,int y) {
    for(;top[x]!=top[y];y=fa[top[y]]) if(dep[top[x]]>dep[top[y]]) swap(x,y);
    return dep[x]<dep[y]?x:y;
}
inline void clr() {
    mem(head); mem(dfn); tp=0; cnt=1; mem(vis); mem(bl);
}
void df3(int x) {
    if(x) ans+=(x<=n&&!vis[x]);
    int i;
    for(i=head[x];i;i=nxt[i]) {
        if(x) ans+=dis[fa[to[i]]]-dis[x];
        df3(to[i]);
    }
    head[x]=0;
}
void solve() {
    clr();
    n=rd(),m=rd();
    int i,x,y;
    for(i=1;i<=m;i++) x=rd(),y=rd(),add(x,y),add(y,x);
    bcc=n;
    for(i=1;i<=n;i++) if(!dfn[i]) tarjan(i);
    mem(head); cnt=0;
    for(x=n+1;x<=bcc;x++) {
        int lim=V[x].size(); 
        for(i=0;i<lim;i++) {
            add(x,V[x][i]); add(V[x][i],x);
        }
    }
    mem(dfn); mem(vis);
    df1(1,0); df2(1,1);
    int q;
    q=rd();
    mem(head);
    while(q--) {
        cnt=0;
        la=rd();
        for(i=1;i<=la;i++) a[i]=rd(),vis[a[i]]=1;
        sort(a+1,a+la+1,cmp);
        S[tp=1]=0;
        for(i=1;i<=la;i++) {
            x=a[i],y=lca(a[i],S[tp]);
            while(dep[y]<dep[S[tp]]) {
                if(dep[y]>=dep[S[tp-1]]) {
                    add(y,S[tp]); tp--;
                    if(S[tp]!=y) S[++tp]=y;
                    break;
                }
                add(S[tp-1],S[tp]); tp--;
            }
            if(S[tp]!=x) S[++tp]=x;
        }
        while(tp>1) add(S[tp-1],S[tp]),tp--;
        ans=0; df3(0);
        printf("%d\n",ans);
        for(i=1;i<=la;i++) vis[a[i]]=0;
    }
}
int main() {
    int T;
    T=rd();
    while(T--) solve();
}

版权声明:本文为suika原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://www.cnblogs.com/suika/p/10014779.html