一类换根问题

Posted by WHZ0325 on 2023-09-06, Viewed times

昨天训练的时候遇到一道树形 DP 题,做法似乎很经典,来记录一下。

参考 严格鸽的知乎

【CF 633F】The Chocolate Spree

求两条不相交链的最大权值和。

使用换根 DP,枚举将两条链分开的边,两条链分别在这条边所分成的两棵子树中。

可以转化为考虑维护一棵子树中的最长链以及在换根时的转移,记 $f[i]$ 为到结点 $i$ 的链的最大权值和,则以 $i$ 为根的子树中的答案 $ans[i]$ 为经过 $i$ 的最长链(由最长和次长两个子节点中的链拼成)和 $i$ 子树中的最长链 $ans[i]$。

换根的过程中需要转移信息,就用一个多重集合来保存用于计算最大权值和的若干个候选数,在更换根节点的过程中维护。

在所枚举边两侧的子树都独立时统计答案,时间复杂度为 $O(nlog_2n)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#include <cstdio>
#include <vector>
#include <set>
#define N 100005
typedef long long int ll;

char buf[1 << 16], *fs = buf, *ft = buf;
inline char gc() {
if(fs == ft) {
ft = (fs = buf) + fread(buf, 1, 1 << 16, stdin);
if(fs == ft) return 0;
}
return *fs++;
}
inline int read() {
int num = 0, f = 0; char c = gc();
while(c < '0' || '9' < c) {
if(c == '-') f = 1;
c = gc();
}
while('0' <= c && c <= '9') {
num = (num << 3) + (num << 1) + c - '0';
c = gc();
}
return f ? -num : num;
}

int a[N]; std::vector<int> g[N];

struct data {
std::multiset<ll> ms;
inline void add(ll x) { ms.insert(x); }
inline void del(ll x) { ms.erase(ms.find(x)); }
inline ll max() {
if(ms.size() < 1) return 0;
auto it = ms.end(); --it; return *it;
}
inline ll nax() {
if(ms.size() < 2) return 0;
auto it = ms.end(); --it; --it; return *it;
}
} f[N], ans[N];

void dfs(int x, int fa) {
f[x].add(a[x]);
for(int i = 0, ed = g[x].size(); i < ed; ++i) {
int v = g[x][i]; if(v == fa) continue;
dfs(v, x);
f[x].add(f[v].max() + a[x]);
ans[x].add(ans[v].max());
}
ans[x].add(f[x].nax() ? f[x].max() + f[x].nax() - a[x] : f[x].max());
}

ll res = 0;
void dp(int x, int fa) {
for(int i = 0, ed = g[x].size(); i < ed; ++i) {
int v = g[x][i]; if(v == fa) continue;

/* Remove v from x */
ans[x].del(f[x].nax() ? f[x].max() + f[x].nax() - a[x] : f[x].max());
f[x].del(f[v].max() + a[x]);
ans[x].del(ans[v].max());
ans[x].add(f[x].nax() ? f[x].max() + f[x].nax() - a[x] : f[x].max());

/* Calculate answer between x and v */
res = std::max(res, ans[x].max() + ans[v].max());

/* Add x to v */
ans[v].del(f[v].nax() ? f[v].max() + f[v].nax() - a[v] : f[v].max());
f[v].add(f[x].max() + a[v]);
ans[v].add(ans[x].max());
ans[v].add(f[v].nax() ? f[v].max() + f[v].nax() - a[v] : f[v].max());

dp(v, x);

/* Remove x from v */
ans[v].del(f[v].nax() ? f[v].max() + f[v].nax() - a[v] : f[v].max());
ans[v].del(ans[x].max());
f[v].del(f[x].max() + a[v]);
ans[v].add(f[v].nax() ? f[v].max() + f[v].nax() - a[v] : f[v].max());

/* Add v to x */
ans[x].del(f[x].nax() ? f[x].max() + f[x].nax() - a[x] : f[x].max());
ans[x].add(ans[v].max());
f[x].add(f[v].max() + a[x]);
ans[x].add(f[x].nax() ? f[x].max() + f[x].nax() - a[x] : f[x].max());
}
}

int main() {
int n = read();
for(int i = 1; i <= n; ++i) a[i] = read();
for(int i = 1; i < n; ++i) {
int u = read(), v = read();
g[u].emplace_back(v); g[v].emplace_back(u);
}
dfs(1, 0); dp(1, 0); printf("%lld\n", res);
return 0;
}

【CCPC 2022 桂林】Group Homework

求两条链的最大权值和,相交部分不被计算。

不难发现对于两条链,若重合部分不止一个点,那么可以去掉重合部分,加上两个端点,得到结果更大的两条链。因此答案转变为重合部分有一个点(它不被计算)时由它延伸出四条链的最大权值和及任意两条不相交链的最大权值和,在上述代码基础上稍作修改即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#include <cstdio>
#include <vector>
#include <set>
#define N 200005

char buf[1 << 16], *fs = buf, *ft = buf;
inline char gc() {
if(fs == ft) {
ft = (fs = buf) + fread(buf, 1, 1 << 16, stdin);
if(fs == ft) return 0;
}
return *fs++;
}
inline int read() {
int num = 0, f = 0; char c = gc();
while(c < '0' || '9' < c) {
if(c == '-') f = 1;
c = gc();
}
while('0' <= c && c <= '9') {
num = (num << 3) + (num << 1) + c - '0';
c = gc();
}
return f ? -num : num;
}

struct data {
std::multiset<int> ms;
inline void add(int x) { ms.insert(x); }
inline void del(int x) { ms.erase(ms.find(x)); }
inline int max() {
if(ms.size() < 1) return 0;
auto it = ms.end(); --it; return *it;
}
inline int nax() {
if(ms.size() < 2) return 0;
auto it = ms.end(); --it; --it; return *it;
}
inline int four(int ax) {
if(ms.size() < 4) return 0;
int res = 0; auto it = ms.end();
int cx = 4; while(cx--) { --it; res += (*it) - ax; }
return res;
}
} f[N], ans[N];

int a[N]; std::vector<int> g[N];
void dfs(int x, int fa) {
f[x].add(a[x]);
for(int i = 0, ed = g[x].size(); i < ed; ++i) {
int v = g[x][i]; if(v == fa) continue;
dfs(v, x);
f[x].add(f[v].max() + a[x]);
ans[x].add(ans[v].max());
}
ans[x].add(f[x].nax() ? f[x].max() + f[x].nax() - a[x] : f[x].max());
}

int res = 0;
void dp(int x, int fa) {
res = std::max(res, f[x].four(a[x]));
for(int i = 0, ed = g[x].size(); i < ed; ++i) {
int v = g[x][i]; if(v == fa) continue;

/* Remove v from x */
ans[x].del(f[x].nax() ? f[x].max() + f[x].nax() - a[x] : f[x].max());
f[x].del(f[v].max() + a[x]);
ans[x].del(ans[v].max());
ans[x].add(f[x].nax() ? f[x].max() + f[x].nax() - a[x] : f[x].max());

/* Calculate answer between x and v */
res = std::max(res, ans[x].max() + ans[v].max());

/* Add x to v */
ans[v].del(f[v].nax() ? f[v].max() + f[v].nax() - a[v] : f[v].max());
f[v].add(f[x].max() + a[v]);
ans[v].add(ans[x].max());
ans[v].add(f[v].nax() ? f[v].max() + f[v].nax() - a[v] : f[v].max());

dp(v, x);

/* Remove x from v */
ans[v].del(f[v].nax() ? f[v].max() + f[v].nax() - a[v] : f[v].max());
ans[v].del(ans[x].max());
f[v].del(f[x].max() + a[v]);
ans[v].add(f[v].nax() ? f[v].max() + f[v].nax() - a[v] : f[v].max());

/* Add v to x */
ans[x].del(f[x].nax() ? f[x].max() + f[x].nax() - a[x] : f[x].max());
ans[x].add(ans[v].max());
f[x].add(f[v].max() + a[x]);
ans[x].add(f[x].nax() ? f[x].max() + f[x].nax() - a[x] : f[x].max());
}
}

int main() {
int n = read();
for(int i = 1; i <= n; ++i) a[i] = read();
for(int i = 1; i < n; ++i) {
int u = read(), v = read();
g[u].emplace_back(v); g[v].emplace_back(u);
}
if(n == 1) puts("0");
else {
dfs(1, 0); dp(1, 0); printf("%d\n", res);
}
return 0;
}

注意事项

multiset.erase(x) 会删去所有值为 $x$ 的元素,而 $multiset.erase(multiset.find(x))$ 才能够仅删除一个值为 $x$ 的元素。