「NOI2021」轻重边 (edge)

$Link$

有一棵 $n$ 个节点的树,每条边都被染成黑色或白色,一开始所有的边都是白色。

有 $m$ 次操作,每次操作为以下两种操作之一:

  • 给定两个点 $u,v$,对于 $u,v$ 间的路径上的所有的点 $x$(包括 $u$ 和 $v$),把所有和 $x$ 相邻的边染成白色,然后再把 $u,v$ 间的路径上所有的边染成黑色。
  • 给定两个点 $u,v$,问 $u,v$ 间的路径上有几条黑色的边。

$T$ 组数据。

$T\le3,1\le n,m\le10^5$。

本题有很精妙的做法(比如 $LCT$),这里提供一个复杂度正确且对于这类有关毛毛虫的问题十分有效的做法,思路也并不复杂。

看到题目第一反应肯定是树剖,但是本题的特殊点在于它不是链的修改而是毛毛虫的修改。

重链剖分的关键想法,就是把树划分为多条重链,对于重链上的点,对他们重新编号使得其编号连续,再用线段树等数据结构维护。

那对于毛毛虫,我们思考是否有办法在重链剖分的基础上,不仅让重链的编号连续,还让重链毛毛虫的编号也连续?

很显然这是不可能的,因为两个重链毛毛虫会在一条重链的链顶形成交点,考虑特别维护这个点?

方法已经很显然了:

  • 对于树上的每条重链,除了链顶之外从浅到深编号。
  • 对于所有的毛毛虫节点,同样从浅到深编号。

在操作一条重链时,由于链顶的编号特殊,因此我们存储每条重链的第二个节点,把链分成两段处理即可。

同时我们需要快速查询重链上一段毛毛虫的编号区间,对每个点维护:

  • 从链顶到这个点之前,毛毛虫节点编号的最大值。
  • 这个点的毛毛虫节点编号后,毛毛虫节点编号的最大值。

修改时先改重链,如果重链涉及到了链顶需要把它拆成一个点加一条链,然后查询该段重链对应的毛毛虫节点的编号区间并修改之。

注意在轻边跳到重边上时,不要漏掉重边的下一条边,这条边同样也是一条毛毛虫边。

对于本题,再加上支持一个区间对 $0$ 取 $\operatorname{min}$ 和区间对 $1$ 取 $\operatorname{max}$ 的线段树即可。

本题支持重复的修改。如果是毛毛虫的加或乘等需要不重不漏的修改的操作,需要进一步分类讨论容斥。

时间复杂度 $O(m\log^2n)$。

$code$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
#include <iostream>
#include <cstdlib>
#include <cstdio>
#include <queue>
#include <cstring>
#define lson (now << 1)
#define rson ((now << 1) + 1)
#define mid ((l + r) >> 1)
using namespace std;
inline int read()
{
int f = 1, x = 0;
char ch;

do{
ch = getchar();
if (ch == '-')
f = -1;
}while(ch < '0' || ch > '9');
do{
x = x * 10 + ch - '0';
ch = getchar();
}while(ch >= '0' && ch <= '9');
return f * x;
}
const int N = 1e5;

int tt;
int n, m;
struct Edge {
int to, next;
} edge[N * 2 + 1];
int start[N + 1], tot;
struct Seg {
int tr[N * 4 + 1], tag[N * 4 + 1];
inline void pushup(int now)
{
tr[now] = tr[lson] + tr[rson];
return;
}
inline void build(int now, int l, int r)
{
tr[now] = 0;
tag[now] = -1;
if (l == r)
return;
build(lson, l, mid);
build(rson, mid + 1, r);
return;
}
inline void pushdown(int now, int l, int r)
{
if (tag[now] != -1) {
tag[lson] = tag[rson] = tag[now];
tr[lson] = (mid - l + 1) * tag[now];
tr[rson] = (r - mid) * tag[now];
tag[now] = -1;
}
return;
}
inline void update(int now, int l, int r, int L, int R, int type)
{
if (L > R)
return;
if (l >= L && r <= R) {
tag[now] = type;
tr[now] = (r - l + 1) * tag[now];
return;
}
pushdown(now, l, r);
if (mid >= L)
update(lson, l, mid, L, R, type);
if (mid < R)
update(rson, mid + 1, r, L, R, type);
pushup(now);
return;
}
inline int query(int now, int l, int r, int L, int R)
{
if (L > R)
return 0;
if (l >= L && r <= R)
return tr[now];
pushdown(now, l, r);
int sum = 0;

if (mid >= L)
sum += query(lson, l, mid, L, R);
if (mid < R)
sum += query(rson, mid + 1, r, L, R);
pushup(now);
return sum;
}
} seg;
int id[N + 1], sz[N + 1], wson[N + 1], dep[N + 1], fa[N + 1], top[N + 1], cnt;
vector<int> link[N + 1];
queue<int> q;
int st[N + 1], nxt[N + 1], be[N + 1], ed[N + 1];

inline void addedge(int u, int v)
{
edge[++tot] = { v, start[u] };
start[u] = tot;
edge[++tot] = { u, start[v] };
start[v] = tot;
return;
}
inline void dfs1(int u)
{
sz[u] = 1;
dep[u] = dep[fa[u]] + 1;
wson[u] = 0;
for (int i = start[u]; i; i = edge[i].next) {
int v = edge[i].to;
if (v != fa[u]) {
fa[v] = u;
dfs1(v);
sz[u] += sz[v];
if (sz[v] > sz[wson[u]])
wson[u] = v;
}
}
return;
}
inline void dfs2(int u)
{
link[top[u]].push_back(u);
if (wson[u]) {
top[wson[u]] = top[u];
dfs2(wson[u]);
}
for (int i = start[u]; i; i = edge[i].next) {
int v = edge[i].to;
if (v != fa[u] && v != wson[u]) {
top[v] = v;
dfs2(v);
}
}
return;
}
inline void update(int p, int q)
{
int u = p, v = q;

while (top[u] != top[v]) {
if (dep[top[u]] > dep[top[v]])
swap(u, v);
if (nxt[v])
seg.update(1, 1, n, id[nxt[v]], id[nxt[v]], 0);
seg.update(1, 1, n, be[top[v]], ed[v], 0);
v = fa[top[v]];
}
if (dep[u] > dep[v])
swap(u, v);
if (nxt[v])
seg.update(1, 1, n, id[nxt[v]], id[nxt[v]], 0);
seg.update(1, 1, n, be[u], ed[v], 0);
seg.update(1, 1, n, id[u], id[u], 0);
u = p;
v = q;
while (top[u] != top[v]) {
if (dep[top[u]] > dep[top[v]])
swap(u, v);
if (top[v] != v)
seg.update(1, 1, n, st[top[v]], id[v], 1);
seg.update(1, 1, n, id[top[v]], id[top[v]], 1);
v = fa[top[v]];
}
if (u == v)
return;
if (dep[u] > dep[v])
swap(u, v);
if (top[u] == u) {
if (st[u])
seg.update(1, 1, n, st[u], id[v], 1);
} else {
seg.update(1, 1, n, id[u] + 1, id[v], 1);
}
return;
}
inline int query(int u, int v)
{
int sum = 0;

while (top[u] != top[v]) {
if (dep[top[u]] > dep[top[v]])
swap(u, v);
if (top[v] != v)
sum += seg.query(1, 1, n, st[top[v]], id[v]);
sum += seg.query(1, 1, n, id[top[v]], id[top[v]]);
v = fa[top[v]];
}
if (u == v)
return sum;
if (dep[u] > dep[v])
swap(u, v);
if (top[u] == u) {
if (st[u])
sum += seg.query(1, 1, n, st[u], id[v]);
} else {
sum += seg.query(1, 1, n, id[u] + 1, id[v]);
}
return sum;
}
int main()
{
freopen("edge.in", "r", stdin);
freopen("edge.out", "w", stdout);
tt = read();

while (tt--) {
n = read();
m = read();
memset(start, 0, sizeof(int) * (N + 1));
tot = cnt = 0;
seg.build(1, 1, n);
memset(nxt, 0, sizeof(int) * (N + 1));
for (int i = 1; i <= n; i++)
link[i].clear();
for (int i = 1; i < n; i++)
addedge(read(), read());
dfs1(1);
top[1] = 1;
dfs2(1);
q.push(1);
id[1] = cnt = 1;
while (!q.empty()) {
int u = q.front();
q.pop();
if (link[u].size() > 1)
st[u] = cnt + 1;
for (int i = 0; i < link[u].size(); i++) {
int v = link[u][i];
if (i)
id[v] = ++cnt;
if (i != link[u].size() - 1)
nxt[v] = link[u][i + 1];
}
for (int i = 0; i < link[u].size(); i++) {
int v = link[u][i];
be[v] = cnt + 1;
for (int j = start[v]; j; j = edge[j].next) {
int w = edge[j].to;
if (top[w] != top[v] && w != fa[v]) {
q.push(w);
id[w] = ++cnt;
}
}
ed[v] = cnt;
}
}
for (int i = 1; i <= m; i++) {
int opt = read(), u = read(), v = read();
if (opt == 1)
update(u, v);
else
printf("%d\n", query(u, v));
}
}

return 0;
}