BZOJ 1588: [HNOI2002]营业额统计

描述

题面比较长,简化一下就是这样:给出一个长为 $n$ 序列,对于每一个数 $a_{i}$,找出之前与它相差最小的数,两者相减取绝对值加入答案。

序列长度 $n\leq 32767, a_{i} \leq 10^{6}$

分析

我们可以用 Splay 来实现,时间复杂度 $\Theta(nlog_{2}n)$ 。

这里考虑如何用更基础的数据结构来实现。

将 $n$ 个元素从小到大排序得到序列 $c[]$ ,用 $rank[]$ 数组记录原来的第 $i$ 个元素在排序之后的位置。

将排序后的序列在数组中建成一个双向链表。然后从 $n \rightarrow 1$ 依次计算贡献。

对于 $a[i]$ ,查看 $rank[i]$ 的前驱 $pre[rank[i]]$ 和其后继 $nxt[rank[i]]$ 所指向的数。

最小的波动必然是 $a[i]$ 与 $a[pre[i]]$ 之间或是 $a[i]$ 与 $a[nxt[i]]$ 之间的绝对值。

处理完 $a[i]$ 之后我们把 $a[i]$ 从双向链表中删除,然后处理 $a[i - 1]$ 。

这么说,可能不是很容易懂。举一个例子来看看吧:

- 1 2 3 4
$a[]$ 9 3 7 5
$b[]$ 3 5 7 9
- 1 2 3 4
$rank[]$ 2 4 3 1

从 $5$ 开始 且 $i = 4$ ,$pre[i] = 2, nxt[i] = 3$ ,那么这个点的贡献 $Add_{i} = min(\left | a_{4} - a_{2} \right |,\left | a_{4} - a_{3} \right | ) = 2$ 。在序列 $b[]$ 中删除 $a_{4}$ 。

然后是 $7$ 且$i = 3$ ,$pre[i] = 2, nxt[i] = 1$ ,那么这个点的贡献为 $Add_{i} = min(\left | a_{3} - a_{2} \right |,\left | a_{1} - a_{3} \right | ) = 2$ ,在序列 $b[]$ 中删除 $a_{3}$ 。

恩,就酱。时间复杂度依然是 $\Theta(nlog_{2}n)$ 。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
//  Created by ZJYelizaveta on Thursday, October 26, 2017 PM04:31:33 CST
// Copyright (c) 2017年 ZJYelizaveta. All rights reserved.

#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
typedef unsigned long long ull;
template<typename T> T readIn() {
T x(0), f(1);
char ch = getchar();
while (ch < '0' || ch > '9') {if (ch == '-') f = -1; ch = getchar();}
while (ch >= '0' && ch <= '9') {x = x * 10 + ch - '0'; ch = getchar();}
return x * f;
}
const int MAX_N = 32767 + 3;
const int INF = 0x3f3f3f3f;
int n;
int val[MAX_N];
int pre[MAX_N], nxt[MAX_N], rank[MAX_N];

inline bool cmp(int i, int j) {
return val[i] < val[j];
}

int main()
{
n = readIn<int>();
for (int i = 1; i <= n; ++i) {
val[i] = readIn<int>();
rank[i] = i;
}
sort(rank + 1, rank + n + 1, cmp);

/*
for (int i = 1; i <= n; ++i) printf("%d ", i); printf("\n");
for (int i = 1; i <= n; ++i) printf("%d ", val[i]); printf("\n");
for (int i = 1; i <= n; ++i) printf("%d ", val[rank[i]]); printf("\n");
for (int i = 1; i <= n; ++i) printf("%d ", rank[i]); printf("\n");
*/

for (int i = 1; i <= n; ++i) {
pre[rank[i]] = rank[i - 1];
nxt[rank[i]] = rank[i + 1];
}

int ans = val[1];
for (int i = n; i >= 2; --i) {
int l = INT_MAX, r = INT_MAX;

if (pre[i] != 0) l = val[i] - val[pre[i]];
if (nxt[i] != 0) r = val[nxt[i]] - val[i];

ans += min(l, r);

nxt[pre[i]] = nxt[i];
pre[nxt[i]] = pre[i];
}

printf("%d\n", ans);

return 0;
}
Share