// Solution:
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;        // 64-bit integer type used for the values stored in f[]
const int N = 500000 + 100;  // capacity of every per-vertex array (presumably max n = 500000 plus slack)
// Reads one optionally signed decimal integer from `in` (stdin by default).
// Leading whitespace (space, tab, CR, LF, VT, FF) is skipped, then an optional
// '+' and/or '-' sign is consumed, then digits are accumulated until the first
// non-digit character, which is consumed and discarded.
// Returns 0 when no digits are found (e.g. at end of input).
// NOTE: the value is accumulated in an int; inputs outside int range overflow.
int scan(FILE *in = stdin){
    // Keep getc's result in an int so EOF stays distinguishable from a real byte.
    int ch = getc(in);
    while (ch == ' ' || ch == '\t' || ch == '\r' || ch == '\n' || ch == '\v' || ch == '\f')
        ch = getc(in);
    int sign = 1;
    if (ch == '+') { ch = getc(in); sign = 1; }
    if (ch == '-') { ch = getc(in); sign = -1; }
    int value = 0;
    while ('0' <= ch && ch <= '9') {
        value = value * 10 + (ch - '0');
        ch = getc(in);
    }
    return value * sign;
}
// Singly linked adjacency list node. point[v] acts as a head (sentinel) node
// for vertex v; the chain starting at point[v].next holds v's neighbours,
// most recently pushed first. Nodes are allocated with `new` and never freed.
struct Point{
    int d;        // neighbour vertex id (not read for head nodes)
    Point *next;  // following node in the chain, or NULL at the end
    Point() : next(NULL) {}
    // Inserts a new node holding `a` directly after this node.
    void push(int a){
        Point *node = new Point;
        node->next = this->next;
        node->d = a;
        this->next = node;
    }
}point[N];
// Global state shared by dfs() and main(); vertices are numbered from 1.
// NOTE(review): under `using namespace std;` in C++17, unqualified uses of
// `data` and `size` can be ambiguous with std::data / std::size — consider
// renaming these arrays (or qualify uses with `::`).
int n;          // number of vertices, read in main()
int data[N];    // per-vertex input value, read in main()
int list[N];    // scratch buffer used by dfs(): list[0] = count, list[1..count] = children
int size[N];    // subtree size of each vertex, filled by dfs()
int fa[N];      // parent of each vertex in the traversal from vertex 1
ll f[N];        // per-vertex value computed by dfs()

// Sort predicate: vertex a goes before vertex b when f[a] is larger.
bool cmp(int a, int b){
    return f[b] < f[a];
}
void dfs(int x = 1){
size[x] = 1;
for(Point *i = point[x].next;i!=NULL;i=i->next)
if(fa[x]!=i->d){
fa[i->d] = x;
dfs(i->d);//f[i->d]--;
size[x] += size[i->d];
}
list[0] = 0;for(Point *i = point[x].next;i!=NULL;i=i->next)if(fa[x]!=i->d)list[++list[0]] = i->d;
sort(list+1,list+1+list[0],cmp);
int tmp = 1;
for(int i=list[0];i;i--){
f[x] = max(f[x],f[list[i]] - tmp);
tmp += size[list[i]]*2;
}
if(x!=1)f[x] = max(f[x],(ll)data[x]-(size[x]-1)*2);
else f[x] = max(f[x],(ll)data[x]);
}
int main(){
//freopen("in.txt","r",stdin);
int i,j,a,b;
n = scan();
for(i=1;i<=n;i++)data[i] = scan();
for(i=1;i<n;i++){
a = scan();b = scan();
point[a].push(b);
point[b].push(a);
}
dfs(1);
printf("%lld",f[1]+(n-1)*2);
return 0;
}
// End of solution. (The original page continued with a "Comments" section.)