「AGC047D」Twin Binary Trees

$Link$

有两棵高度为 $h$ 的完全二叉树,每棵有 $2^h-1$ 个节点,称为红树和蓝树。它们的根节点编号为 $1$,并且 $x$ 的左儿子为 $2x$,右儿子为 $2x+1$。父亲和儿子之间有一条无向边连接。

红蓝树的叶子结点构成了一个双射,对应的节点间有一条无向边相连,称这些边为特殊边。

我们定义一个好的环,当且仅当这是一个简单环并且正好经过了两条特殊边。一个环的权值就是它经过的所有点的编号的乘积。

现在给你这个双射,问在这个由两棵完全二叉树和特殊边构成的图上的所有的好的环的权值和。

对 $10^9+7$ 取模。

$2\le h\le18$。

这里我们设红树的叶子结点 $x$ 向蓝树上的叶子结点 $x’$ 连边。

显然,一个好的环是红树上的两个叶子节点 $u$ 和 $v$ 间的路径,拼上蓝树上的两个叶子结点 $u’$ 和 $v’$ 间的路径,再加上两条特殊边 $u\to u’,v’\to v$。

对于一个好的环,它的权值可以分成红树和蓝树两部分计算。对于每部分,又可以再分成两段:$u$ 到 $lca$ 和 $v$ 到 $lca$。

于是我们可以枚举红树上两个点的 $lca$,对于其左子树的每个叶子结点 $u$,在蓝树上所有可能作为 $lca$ 的节点,也即从 $u’$ 到蓝树的根上打上标记。再对于右子树的每个叶子结点 $v$,遍历 $v’$ 到蓝树的根,统计答案。

计算一条路径上的权值直接预处理前缀积即可。

注意要稍微考虑一下细节,不要重复计算。

时间复杂度 $O(2^hh^2)$。

$code$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#include <iostream>
#include <cstdlib>
#include <cstdio>
#include <vector>
using namespace std;
inline int read()
{
int f = 1, x = 0;
char ch;

do{
ch = getchar();
if (ch == '-')
f = -1;
}while(ch < '0' || ch > '9');
do{
x = x * 10 + ch - '0';
ch = getchar();
}while(ch >= '0' && ch <= '9');
return f * x;
}
const int N = 1 << 18;
const int mod = 1e9 + 7;

int n;
vector<int> lleave[1 << 18], rleave[1 << 18];
int to[N + 1];
int val[N + 1], inv[N + 1], ssum[N + 1];
int stk[N + 1], top;
int ans;

inline int pow(int x, int y)
{
int sum = 1;

while (y) {
if (y & 1)
sum = 1ll * sum * x % mod;
x = 1ll * x * x % mod;
y >>= 1;
}
return sum;
}
int main()
{
n = read();
for (int i = 1 << (n - 1); i < 1 << n; i++) {
to[i] = read() + (1 << (n - 1)) - 1;
int u = i >> 1, v = i;
while (u) {
if ((u << 1) == v)
lleave[u].push_back(i);
else
rleave[u].push_back(i);
v = u;
u >>= 1;
}
}
val[1] = 1;
inv[0] = inv[1] = 1;
for (int i = 2; i < 1 << n; i++) {
val[i] = 1ll * val[i >> 1] * i % mod;
inv[i] = pow(val[i], mod - 2);
}

for (int w = 1; w < 1 << (n - 1); w++) {
for (vector<int>::iterator it = lleave[w].begin(); it != lleave[w].end(); it++) {
int u = *it, vval = 1ll * val[u] * inv[w] % mod, v = to[u];
while (v) {
if (!ssum[v])
stk[++top] = v;
ssum[v] = (ssum[v] + vval) % mod;
vval = 1ll * vval * v % mod;
v >>= 1;
}
}
for (vector<int>::iterator it = rleave[w].begin(); it != rleave[w].end(); it++) {
int u = *it, vval = 1ll * val[u] * inv[w >> 1] % mod, v = to[u], la = 0;
while (v) {
vval = 1ll * vval * v % mod;
ans = (ans + 1ll * vval * ((ssum[v] - 1ll * ssum[la] * la % mod + mod) % mod) % mod) % mod;
la = v;
v >>= 1;
}
}
while (top)
ssum[stk[top--]] = 0;
}

printf("%d\n", ans);
return 0;
}