【CSP 2022】聚集方差

Posted by WHZ0325 on 2023-03-12, Viewed times

题目描述

给你一棵 $n$ 个节点的树,记以节点 $x$ 为根的子树组成集合为 $T(x)$,求 $\sum_{y\in T(x)}\min_{z\in T(x),z\neq y}(a_z-a_y)^2$。

$2\le n\le 3\times 10^5$,$0\le a_i\le 10^9$。

算法分析

思维难度其实不大,对每个节点开一个集合,树上启发式合并就可以了。

怎么合并呢?不难发现插入新元素时会影响到的只有与它距离最近的两个节点的贡献,动态维护即可。

时间复杂度大概是 $O(nlog_2^2n)$。

题目的坑点在于卡常。

第一次是用 multiset 实现的,直接卡成 40 分的暴力,后来改成 map 变成 55 分($a_i$ 值相同的点合并为一个),又优化了一下,把所有二分操作合并为一个才 A 掉。

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <map>
#include <vector>
#include <algorithm>
#define N 300005
typedef long long int ll;
char buf[1<<16], *fs = buf, *ft = buf;
inline char gc() {
if(fs == ft) ft = (fs = buf) + fread(buf, 1, 1<<16, stdin);
return *fs++;
}
char c;
inline void read(int &num) {
num = 0; c = gc();
while('0' <= c && c <= '9') {
num = (num << 3) + (num << 1) + c - '0';
c = gc();
}
}
int n, a[N], id[N]; ll res[N], ans[N];
std::map<int, int> s[N];
std::vector<int> son[N];
inline int getDist(int o, std::map<int, int>::iterator it) {
if(s[o].size() == 1 || it->second > 1) return 0;
int ans = INT_MAX, x = it->first;
if(it != s[o].begin()) {
--it;
ans = std::min(ans, abs(x - it->first));
++it;
}
if((++it) != s[o].end()) ans = std::min(ans, abs(x - it->first));
return ans;
}
inline int merge(int x, int y) {
if(s[y].size() > s[x].size()) std::swap(x, y);
std::map<int, int>::iterator it;
for(it = s[y].begin(); it != s[y].end(); ++it) {
std::map<int, int>::iterator it0 = s[x].lower_bound(it->first);
if(it0 != s[x].end() && it0->first == it->first) {
int oldDist = getDist(x, it0);
res[x] -= (ll)oldDist * oldDist;
s[x][it->first] += it->second;
continue;
}
int nearest = it->second > 1 ? 0 : INT_MAX;
if(it0 != s[x].end()) {
int newDist = abs(it->first - it0->first);
if(s[x][it0->first] == 1) {
int oldDist = getDist(x, it0);
if(s[x].size() == 1 || oldDist > newDist) {
res[x] -= (ll)oldDist * oldDist;
res[x] += (ll)newDist * newDist;
}
}
nearest = std::min(nearest, newDist);
}
if(it0 != s[x].begin()) {
--it0;
int newDist = abs(it->first - it0->first);
if(s[x][it0->first] == 1) {
int oldDist = getDist(x, it0);
if(s[x].size() == 1 || oldDist > newDist) {
res[x] -= (ll)oldDist * oldDist;
res[x] += (ll)newDist * newDist;
}
}
nearest = std::min(nearest, newDist);
}
res[x] += (ll)nearest * nearest;
s[x][it->first] += it->second;
}
return x;
}
void calc(int x) {
ans[x] = 0; ++s[x][a[x]]; id[x] = x;
for(int i = 0, end = son[x].size(); i < end; ++i) {
int v = son[x][i];
calc(v);
id[x] = merge(id[x], id[v]);
}
ans[x] = res[id[x]];
}
int main() {
read(n);
for(int i = 2; i <= n; ++i) {
int p; read(p);
son[p].push_back(i);
}
for(int i = 1; i <= n; ++i) read(a[i]);
calc(1);
for(int i = 1; i <= n; ++i) printf("%lld\n", ans[i]);
return 0;
}