题目描述
在 $[1,n-1]$ 之中的每个位置有一个权值 $H_i$,当经过时权值减一,权值为 $0$ 时不能经过,求最小的跳跃距离(每次移动的距离均不超过跳跃距离),使得能够在位置 $0$ 和位置 $n$ 之间往返 $2x$ 次。
$1\le n\le 10^5$,$1\le x\le 10^9$,$1\le H_i\le 10^4$。
算法分析
一眼二分答案,重点在于如何判断对于给定的跳跃距离 $y$ 是否能够往返 $2x$ 次。
第一种思路是判断每一个长度为 $y$ 的区间中的权值和是否都不低于 $2x$,如果低于则不能够往返,感性证明:
分别证充分性和必要性。对于任意长度为 $y$ 的区间,$2x$ 次往返中的每一次必经过一次该区间内的点,反证若不经过则最小跳跃距离应当超过 $y$,所以区间内的权值和至少应当是 $2x$。对于每个位置,能够跳到的位置必然是它后面长度为 $y$ 的区间中的位置,如果区间中的权值和不低于 $2x$ 则一定能让至少 $2x$ 次跑出去(位置 $n$ 也能够接收到前面 $2x$ 次)。
并不是非常好想,第二种思路则更为自然。
类似网络流的思想,把每个点推给它后面能到达的这 $y$ 个点,推到最后统计能够到达位置 $n$ 的次数是否达到 $2x$ 次(其实就相当于所有次数一齐跳)。推的过程肯定不能一个一个递,用并查集优化,一旦有权值 $H_i$ 满了,不能够再经过了则将它跳过,时间复杂度为 $O(nlog_2n)$。
注意 Java 在并查集 find
函数递归层数过多时会爆栈,因此采用非递归实现。
代码实现
算法一:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| import java.util.Scanner; public class Main { static int n, x, h[]; static boolean check(int y) { for(int i = y; i < n; ++i) if(h[i] - h[i - y] < x) return false; return true; } public static void main(String[] args) { Scanner in = new Scanner(System.in); n = in.nextInt(); x = in.nextInt() << 1; h = new int[n]; h[0] = 0; for(int i = 1; i < n; ++i) h[i] = h[i - 1] + in.nextInt(); int l = 1, r = n; while(l < r) { int mid = (l + r) >> 1; if(check(mid)) r = mid; else l = mid + 1; } System.out.println(l); } }
|
算法二:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
| import java.util.Scanner; public class Main { static int n, x, h[], contain[], fa[]; static int find(int x) { int rt = x; while(rt != fa[rt]) rt = fa[rt]; while(x != fa[x]) { fa[x] = rt; x = fa[x]; } return rt; } static boolean check(int y) { contain = new int[n]; contain[0] = x; fa = new int[n]; fa[0] = 0; for(int i = 1; i < n; ++i) fa[i] = (h[i] > 0 ? i : find(i - 1)); for(int i = 0; i < n; ++i) if(contain[i] > 0) { int loc = find(Math.min(i + y, n - 1)); while(loc > i && contain[i] > 0) { int flow = Math.min(contain[i], h[loc] - contain[loc]); contain[loc] += flow; contain[i] -= flow; if(h[loc] - contain[loc] == 0) fa[loc] = find(loc - 1); loc = find(loc - 1); } } int sum = 0; for(int i = n - y; i < n; ++i) sum += contain[i]; return sum >= x; } public static void main(String[] args) { Scanner in = new Scanner(System.in); n = in.nextInt(); x = in.nextInt() << 1; h = new int[n+1]; for(int i = 1; i < n; ++i) h[i] = in.nextInt(); int l = 1, r = n; while(l < r) { int mid = (l + r) >> 1; if(check(mid)) r = mid; else l = mid + 1; } System.out.println(l); } }
|