BZOJ 1036: [ZJOI2008]树的统计Count

描述

一棵树上有 $n$个 节点,编号分别为 $1$ 到 $n$ ,每个节点都有一个权值 $w$ 。我们将以下面的形式来要求你对这棵树完成
一些操作:

  • $\textrm{CHANGE} \ u \ t$ : 把结点 $u$ 的权值改为 $t$
  • $\textrm{QMAX} \ u \ v$ :询问从点 $u$ 到点 $v$ 的路径上的节点的最大权值
  • $\textrm{QSUM} \ u \ v$ :询问从点 $u$ 到点 $v$ 的路径上的节点的权值和 注意:从点 $u$ 到点 $v$ 的路径上的节点包括 $u$ 和 $v$ 本身。

$1 \leq n \leq 30000, 0 \leq q \leq 200000, -30000 \leq w \leq 30000$

分析

马上 NOIP 2017 了,就当顺便复习一下树链剖分和线段树吧 qwq

树链剖分

树链剖分用一句话概括就是:把一棵树剖分为若干条链,然后利用数据结构(树状数组,Splay,线段树等等)去维护每一条链,复杂度为 $\Theta(logn)$ 。

这里所说的树链剖分是轻重链剖分,不是长链剖分,表示蒟蒻不会长链剖分,轻重链剖分由两次 dfs 组成。

令 $size[u]$ 为以 $u$ 为根的子树的结点数,$depth[u]$ 为结点 $u$ 的深度,$fa[u]$ 为结点 $u$ 的父亲。

首先无根树转有根树,以 $1$ 为根,记录每个节点的子节点总数(包括自己) $size[u]$ ,每个节点的父亲 $fa[u]$ ,每个节点到父亲的边权 $f[u]$ ,深度 $depth[u]$ 。这是第一次 dfs。

然后就是第二次 dfs ,一这道题目为栗子,用线段树来维护将树剖分之后的线性结构,线段树每一个结点的值对应结点到其父亲的边权。

令 $bel[i]$ 为 $i$ 结点所属的链的链头,$id[i]$ 为结点 $i$ 的 dfs 序标号。

首先维护一个时间戳 $\textrm{timestamp}$ ,然后 dfs2(int u, int num) 表示从结点 $u$ 开市构建链,结点 $u$ 属于的链编号为 $num$ 。每次 dfs 的时候结点 $id[u] = ++\textrm{timestamp}$ ,dfs 序保证了同一条链中结点的 $id[]$ 是连续的,所以只需要维护一棵线段树就可以了。

然后找结点 $u$ 的儿子中 $size[]$ 最大的结点 $v$ (若不存在,递归到叶子,则返回),然后仍然执行 dfs2(v, num) 使得这一条链尽可能长。其余的儿子节点构成新链,而他们自身则为新链的链头,执行 dfs2(v, v)

然后,构建线段树,因为在一条链上的点的编号是连续的,所以可以在一颗线段树上进行修改或者询问操作。将所有的重链首尾相连放入线段树中,$id[u]$ 为结点 $u$ 在线段树中的编号。

单点修改,很容易,直接在线段树上修改即可。

修区间修改,修改点 $u$ 到点 $v$ 之间路径上的权值,分两种情况讨论:

  • 结点 $u$ 和结点 $v$ 在同一条链上,那么直接在线段树上修改 $id[u]$ 到 $id[v]$ 之间的点的权值。
  • 若结点 $u$ 与结点 $v$ 不在一条链上,那么以便修改,一边将 $u$ 和 $v$ 向同一条重链上靠,直到变成在一条链上的情况。

至于查询操作和修改操作的思想是一样的,这里还是详细的写一下:

单点查询同单点修改。

区间询问:

  • 在同一条链上,直接询问
  • 不在同一条链上,先交换,保证 $u$ 的链头深度大于 $v$ 的链头深度。然后将 $u$ 移动到 $fa[bel[u]]$ 即上一条链,一边移动一边更新询问的答案,直到 $bel[u] = bel[v]$ ,然后同情况一。

单次操作的时间复杂度为 $\Theta(log^{2}n)$ 。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
//  Created by ZJYelizaveta on Tuesday, November 07, 2017 AM09:01:56 CST
//
// Copyright (c) 2017年 ZJYelizaveta. All rights reserved.

#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
typedef unsigned long long ull;
template<typename T> T readIn() {
T x(0), f(1);
char ch = getchar();
while (ch < '0' || ch > '9') {if (ch == '-') f = -1; ch = getchar();}
while (ch >= '0' && ch <= '9') {x = x * 10 + ch - '0'; ch = getchar();}
return x * f;
}
const int MAX_N = 30000 + 3;
const int INF = 0x3f3f3f3f;
const int MAXNODE = MAX_N * 4;
int n, q;
vector<int> G[MAX_N << 1];
int w[MAX_N];

inline void addEdge(int from, int to) {
G[from].push_back(to);
G[to].push_back(from);
}

namespace segmentTree{
#define mid (((l) + (r)) >> 1)
#define lc ((o) << 1)
#define rc (((o) << 1) + 1)
struct Node{
int l, r, sum, maxVal;
}node[MAXNODE];

void build(int o, int l, int r) {
node[o].l = l, node[o].r = r;
if (r == l) return;
build(lc, l, mid);
build(rc, mid + 1, r);
}

inline void pushUp(int o) {
node[o].sum = node[lc].sum + node[rc].sum;
node[o].maxVal = max(node[lc].maxVal, node[rc].maxVal);
}

inline void modify(int o, int pos, int val) { //change the single point on the segemnt tree;
int l = node[o].l, r = node[o].r;
if (l == r) {
node[o].sum = node[o].maxVal = val;
return;
}
if (pos <= mid) modify(lc, pos, val);
else modify(rc, pos, val);

pushUp(o);
}

int querySum(int o, int L, int R) {
int l = node[o].l, r = node[o].r;
if (l == L && r == R) return node[o].sum;

if (R <= mid) return querySum(lc, L, R);
else if (L > mid) return querySum(rc, L, R);
else return querySum(lc, L, mid) + querySum(rc, mid + 1, R);
}

int queryMax(int o, int L, int R) {
int l = node[o].l, r = node[o].r;
if (l == L && R == r) return node[o].maxVal;//

if (R <= mid) return queryMax(lc, L, R);
else if (L > mid) return queryMax(rc, L, R);
else return max(queryMax(lc, L, mid), queryMax(rc, mid + 1, R));
}
}
using namespace segmentTree;

namespace heavyLightDecomposition {
int size[MAX_N], depth[MAX_N], fa[MAX_N];

inline void dfs1(int u, int f) {
size[u] = 1;
depth[u] = f == 0 ? 0 : depth[f] + 1; fa[u] = f;

for (int i = 0; i < (int)G[u].size(); ++i) {
int v = G[u][i];
if (v == f) continue;

dfs1(v, u);
size[u] += size[v];
}
}

int bel[MAX_N], id[MAX_N], timeStamp = 0;
inline void dfs2(int u, int num) {
bel[u] = num, id[u] = ++timeStamp;

int Max = 0, idx = 0;
for (int i = 0; i < (int)G[u].size(); ++i) {
int v = G[u][i];
if (v != fa[u] && size[v] > Max) {
Max = size[v]; idx = v;
}
}

if (Max == 0) return;
dfs2(idx, num);

for (int i = 0; i < (int)G[u].size(); ++i) {
int v = G[u][i];
if (v != fa[u] && v != idx) dfs2(v, v);
}
}
}
using namespace heavyLightDecomposition;

namespace solve {
inline int solveSum(int a, int b) {
int sum = 0;
while (bel[a] != bel[b]) {
if (depth[bel[a]] < depth[bel[b]]) swap(a, b);
sum += querySum(1, id[bel[a]], id[a]);
a = fa[bel[a]];
}
if (depth[a] > depth[b]) swap(a, b);
sum += querySum(1, id[a], id[b]);
return sum;
}

inline int solveMax(int a, int b) {
int maxVal = -INF;
while (bel[a] != bel[b]) {
if (depth[bel[a]] < depth[bel[b]]) swap(a, b);
maxVal = max(maxVal, queryMax(1, id[bel[a]], id[a]));
a = fa[bel[a]];
}
if (depth[a] > depth[b]) swap(a, b);
maxVal = max(maxVal, queryMax(1, id[a], id[b]));
return maxVal;
}
}
using namespace solve;

char opt[10];
int main()
{
n = readIn<int>();
for (int i = 1; i <= n - 1; ++i) {
int u = readIn<int>(), v = readIn<int>();
addEdge(u, v);
}
for (int i = 1; i <= n; ++i) w[i] = readIn<int>();

fa[1] = 0;
dfs1(1, 0);
dfs2(1, 1);
// for (int i = 1; i <= n; ++i) printf("%d %d\n", i, bel[i]);
build(1, 1, n);
for (int i = 1; i <= n; ++i) modify(1, id[i], w[i]);


q = readIn<int>(); // printf("%d\n", q);
while (q--) {
scanf("%s", opt);
// printf("%s\n", opt);

if (opt[0] == 'C') {
int u = readIn<int>(), val = readIn<int>();
w[u] = val;
modify(1, id[u], val);
}
else {
int u = readIn<int>(), v = readIn<int>();
if (opt[1] == 'M') printf("%d\n", solveMax(u, v));
else printf("%d\n", solveSum(u, v));
}
}

return 0;
}
Compartir