维护子树内有多少关键点,最长链和最短链进行转移
代码:
#include <cstdio> #include <cstring> #include <algorithm> using namespace std; typedef long long ll; #define N 1000050 #define inf 0x3f3f3f3f int head[N],to[N<<1],nxt[N<<1],cnt,n,m,fa[N],top[N],dep[N],son[N],siz[N]; int a[N],la,dfn[N],S[N],tp,vis[N]; char buf[1000000],*p1,*p2; #define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1000000,stdin),p1==p2)?EOF:*p1++) int rd() { int x=0; char s=nc(); while(s<'0'||s>'9') s=nc(); while(s>='0'&&s<='9') x=(((x<<2)+x)<<1)+s-'0',s=nc(); return x; } inline void add(int u,int v) { to[++cnt]=v; nxt[cnt]=head[u]; head[u]=cnt; } void df1(int x,int y) { int i; siz[x]=1; dfn[x]=++dfn[0]; fa[x]=y; dep[x]=dep[y]+1; for(i=head[x];i;i=nxt[i]) if(to[i]!=y) { df1(to[i],x); siz[x]+=siz[to[i]]; if(siz[to[i]]>siz[son[x]]) son[x]=to[i]; } } void df2(int x,int t) { int i; top[x]=t; if(son[x]) df2(son[x],t); for(i=head[x];i;i=nxt[i]) if(to[i]!=fa[x]&&to[i]!=son[x]) df2(to[i],to[i]); } ll ans1,sz[N]; int f[N],g[N],ans2,ans3; void df3(int x) { int i; if(vis[x]) { sz[x]=1; f[x]=g[x]=0; }else { sz[x]=0; f[x]=inf; g[x]=-inf; } for(i=head[x];i;i=nxt[i]) { df3(to[i]); int len=dep[to[i]]-dep[x]; ans1+=ll(len)*sz[to[i]]*(la-sz[to[i]]); ans2=min(ans2,len+f[to[i]]+f[x]); ans3=max(ans3,len+g[to[i]]+g[x]); f[x]=min(f[x],f[to[i]]+len); g[x]=max(g[x],g[to[i]]+len); sz[x]+=sz[to[i]]; } head[x]=0; } int lca(int x,int y) { for(;top[x]!=top[y];y=fa[top[y]]) if(dep[top[x]]>dep[top[y]]) swap(x,y); return dep[x]<dep[y]?x:y; } inline bool cmp(const int &x,const int &y) {return dfn[x]<dfn[y];} int main() { n=rd(); int i,x,y,l; for(i=1;i<n;i++) { x=rd(); y=rd(); add(x,y); add(y,x); } df1(1,0); df2(1,1); m=rd(); memset(head,0,sizeof(head)); cnt=0; while(m--) { cnt=0; la=rd(); for(i=1;i<=la;i++) a[i]=rd(),vis[a[i]]=1; sort(a+1,a+la+1,cmp); tp=0; S[++tp]=1; for(i=1;i<=la;i++) { x=a[i],l=lca(x,S[tp]); while(dep[l]<dep[S[tp]]) { if(dep[l]>=dep[S[tp-1]]) { add(l,S[tp]); tp--; if(S[tp]!=l) S[++tp]=l; break; } add(S[tp-1],S[tp]); tp--; } if(S[tp]!=x) S[++tp]=x; } while(tp>1) add(S[tp-1],S[tp]),tp--; ans1=0; ans3=0; ans2=1ll<<30; df3(1); printf("%lld %d %d\n",ans1,ans2,ans3); for(i=1;i<=la;i++) vis[a[i]]=0; } }题目链接
分析:
建出虚树,对于虚树上的一条边,考虑边上的一个非关键点。
这个点子树内不在虚树上的点和这个点被相同的点管理。
先求出每个虚树上的点被谁管理,然后对于虚树上的一条边,倍增找一个分界点。
容斥一下,将每个点的答案设成虚树上被这个点管理的深度最小的点的子树大小,然后分析每条边时减掉不合法的。
代码:
#include <cstdio> #include <cstring> #include <algorithm> #include <cstdlib> using namespace std; #define N 300050 int head[N],to[N<<1],nxt[N<<1],cnt,n,m,a[N],vis[N],la; int fa[N],dep[N],son[N],siz[N],top[N],dfn[N],f[20][N],b[N],ans[N]; int S[N],tp,wv[N]; inline void add(int u,int v) { to[++cnt]=v; nxt[cnt]=head[u]; head[u]=cnt; } void df1(int x,int y) { int i; siz[x]=1; fa[x]=y; dep[x]=dep[y]+1; dfn[x]=++dfn[0]; for(i=head[x];i;i=nxt[i]) if(to[i]!=y) { df1(to[i],x); siz[x]+=siz[to[i]]; if(siz[to[i]]>siz[son[x]]) son[x]=to[i]; } } void df2(int x,int t) { int i; top[x]=t; if(son[x]) df2(son[x],t); for(i=head[x];i;i=nxt[i]) if(to[i]!=fa[x]&&to[i]!=son[x]) df2(to[i],to[i]); } int lca(int x,int y) { for(;top[x]!=top[y];y=fa[top[y]]) if(dep[top[x]]>dep[top[y]]) swap(x,y); return dep[x]<dep[y]?x:y; } inline bool cmp(const int &x,const int &y) {return dfn[x]<dfn[y];} inline bool cmp2(int x,int y) { return dep[x]==dep[y]?x < y:dep[x]<dep[y]; } int dis(int x,int y) { int l=lca(x,y); return dep[x]+dep[y]-2*dep[l]; } inline bool cmp3(int x,int y,int z) { int l1=dis(x,z),l2=dis(y,z); return l1 == l2 ? x < y : l1 < l2; } void df3(int x) { int i; wv[x]=0; for(i=head[x];i;i=nxt[i]) { df3(to[i]); if(wv[to[i]] && (!wv[x]||cmp2(wv[to[i]],wv[x]))) { wv[x]=wv[to[i]]; } } if(vis[x]) { wv[x]=x; } } void df4(int x,int y) { int i; if(wv[y] && (!wv[x] || cmp3(wv[y],wv[x],x))) wv[x]=wv[y]; //ans[wv[x]]+=siz[x]; for(i=head[x];i;i=nxt[i]) { df4(to[i],x); } ans[wv[x]]=siz[x]; } int jmp(int x,int y) { int i; for(i=19;i>=0;i--) { if(f[i][x]&&dep[f[i][x]]>dep[y]) x=f[i][x]; } return x; } int jmp_half(int x,int y) { int t=x; int i; for(i=19;i>=0;i--) { if(f[i][x]&&cmp3(wv[t],wv[y],f[i][x])) x=f[i][x]; } return x; } void df5(int x) { int i; for(i=head[x];i;i=nxt[i]) { if(wv[x]==wv[to[i]]) { }else { int h=jmp_half(to[i],x); ans[wv[x]] -= siz[h]; ans[wv[to[i]]] += siz[h] - siz[to[i]]; } df5(to[i]); } head[x]=0; } void prt(int x) { int i; printf("x=%d wv=%d\n",x,wv[x]); for(i=head[x];i;i=nxt[i]) { prt(to[i]); } } int main() { scanf("%d",&n); int i,x,y,j; for(i=1;i<n;i++) scanf("%d%d",&x,&y),add(x,y),add(y,x); df1(1,0); df2(1,1); for(i=1;i<=n;i++) f[0][i]=fa[i]; for(i=1;(1<<i)<=n;i++) for(j=1;j<=n;j++) f[i][j]=f[i-1][f[i-1][j]]; memset(head,0,sizeof(head)); cnt=0; scanf("%d",&m); while(m--) { scanf("%d",&la); cnt=0; for(i=1;i<=la;i++) scanf("%d",&a[i]),b[i]=a[i],ans[a[i]]=0,vis[a[i]]=1; sort(a+1,a+la+1,cmp); S[tp=1]=1; for(i=1;i<=la;i++) { x=a[i],y=lca(x,S[tp]); while(dep[y]<dep[S[tp]]) { if(dep[y]>=dep[S[tp-1]]) { add(y,S[tp]); tp--; if(S[tp]!=y) S[++tp]=y; break; } add(S[tp-1],S[tp]); tp--; } if(S[tp]!=x) S[++tp]=x; } while(tp>1) add(S[tp-1],S[tp]),tp--; df3(1); df4(1,vis[1] ? 1 : 0); //prt(1); df5(1); for(i=1;i<=la;i++) printf("%d ",ans[b[i]]); puts(""); for(i=1;i<=la;i++) vis[a[i]]=0; } } /* 10 2 1 3 2 4 3 5 4 6 1 7 3 8 3 9 4 10 1 1 5 2 7 3 6 9 */题目链接
重写一遍,贴代码。