Sol:首先这道题显然点分治,然后利用以下结论,维护路径上的和还有最大值。
\(Lemma\):\(n\) 条边,长度为 \(a_1,a_2…a_n(a_i≤a_{i+1})\) ,其能构成一个 面积大于0的凸多边形 ,当且仅当 \(n>2\) 且 \(\sum_{i=1}^{n-1}a_i\gt 2\times a_n\).
在当前子树中,找出重心后求出其他点到重心的距离、最大值后,在合并跨过重心的答案时,考虑利用树状数组,因为如果两两枚举的话容易T,利用树状数组,将距离离散化后当做树状数组的下标,将最大值排序,依次枚举最大值,(也是在枚举边)在树状数组中找到第一个大于等于\(2*mx-d\)的值,然后查询\(query(m)-query(id[d-1])\),其中\(m\)为离散化后的最大下标,这样为什么可行是因为\(id[d-1]\)后的距离一定大于等于\(d\),那么\(id[d-1]\)后面有几个数,说明就有多少条边可以和当前枚举的边配对,因为\(2*mx-d<x\)等价于\(2*mx<d+x\)。
还有错误,待查
#include <bits/stdc++.h> #define ll long long //using namespace std; constexpr int N=2e5+10; int n,S,MX,root; std::vector<int>E[N]; int sz[N],mxson[N]; ll res; int w[N]; bool vis[N]; void getroot(int u,int fa) { sz[u]=1,mxson[u]=0; for(auto &v:E[u]) { if(vis[v]||v==fa) continue; sz[u]+=sz[v]; mxson[u]=std::max(mxson[u],sz[v]); } mxson[u]=std::max(mxson[u],S-sz[u]); if(mxson[u]<MX) root=u,MX=mxson[u]; } std::vector<std::pair<int,ll>>t; void getdist(int u,ll d,int mx,int fa) { for(auto &v:E[u]) { if(v==fa||vis[v]) continue; t.emplace_back(std::max(mx,w[v]),d+w[v]); getdist(v,d+w[v],std::max(mx,w[v]),u); } } std::vector<ll>dist; int tr[N]; int lowbit(int x) {return x&-x;}; void add(int p,int x,int m) { for(;p<=m;p+=lowbit(p)) tr[p]+=x; } ll query(int p,int m) { ll ans=0; for(;p;p-=lowbit(p)) ans+=tr[p]; return ans; } void solve(int u,int type) { t.clear(),dist.clear(); getdist(u,1ll*w[u],w[u],0); for(auto &[mx,d]:t) dist.emplace_back(d); std::sort(dist.begin(),dist.end()); dist.erase(unique(dist.begin(),dist.end()),dist.end()); int m=dist.size(); std::sort(t.begin(),t.end()); for(auto &[mx,d]:t) { int i=std::lower_bound(dist.begin(),dist.end(),2*mx-d)-dist.begin()+1; if(i<=m) res+=1ll*type*(query(m,m)-query(i-1,m)); add(i,1,m); } for(auto &[mx,d]:t) { int i=std::lower_bound(dist.begin(),dist.end(),d)-dist.begin()+1; add(i,-1,m); } } void Divide(int u) { solve(u,1); vis[u]=1; for(auto &v:E[u]) { if(vis[v]) continue; solve(v,-1);//消除子树重复计算的答案 S=sz[v],root=0,MX=N; getroot(v,0); Divide(v); } } int main() { std::ios::sync_with_stdio(false); std::cin.tie(nullptr); int T; std::cin>>T; while(T--) { int n; std::cin>>n; for(int i=1;i<=n;i++) std::cin>>(w[i]),E[i].clear(),vis[i]=0; for(int i=1;i<n;i++) { int u,v; std::cin>>u>>v; E[u].emplace_back(v); E[v].emplace_back(u); } res=0; S=n,MX=N,root=0; getroot(1,0); Divide(root); std::cout<<res<<'\n'; } }