「ARC108E」Random IS

$Link$

给定一个长度为 $n$ 的排列 $a$,定义一个序列是合法的当且仅当被标记的数单调递增。

一开始,$a$ 中的每个数均未被标记,每次可以等概率选择一个未被标记过,且标记后序列合法的数标记,当没有可以选择的数时停止操作,求期望的标记个数,对 $10^9+7$ 取模。

$1\le n\le2000$。

考虑我们只考虑区间 $[l,r]$ 的期望,并且已经选了 $a_l,a_r$ 两个数,且 $a_l<a_r$,设还可以选的数为 $s_1,\cdots s_k$,满足 $l<s_i<r,a_l<a_{s_i}<a_r$,那期望就是 $f_{l,r}=1+\frac{1}{k}\sum_{i=1}^kf_{l,s_i}+f_{s_i,r}$。

直接使用树状数组优化这个 $\sum$ 即可,时间复杂度 $O(n^2\log n)$。

$code$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#include <iostream>
#include <cstdlib>
#include <cstdio>
using namespace std;
inline int read()
{
int f = 1, x = 0;
char ch;

do{
ch = getchar();
if (ch == '-')
f = -1;
}while(ch < '0' || ch > '9');
do{
x = x * 10 + ch - '0';
ch = getchar();
}while(ch >= '0' && ch <= '9');
return f * x;
}
const int N = 3e3;
const int mod = 1e9 + 7;

int n;
int a[N + 1], inv[N + 1];
int f[N + 1][N + 1];
struct BIT {
int tr[N + 2];
inline int lowbit(int x)
{
return x & (-x);
}
inline void update(int pos, int val)
{
while (pos <= n) {
tr[pos] = (tr[pos] + val) % mod;
pos += lowbit(pos);
}
return;
}
inline int query(int pos)
{
int ans = 0;

while (pos) {
ans = (ans + tr[pos]) % mod;
pos -= lowbit(pos);
}
return ans;
}
} bitl[N + 1], bitr[N + 1], bitsum[N + 1];

inline void init()
{
inv[1] = 1;
for (int i = 2; i <= N; i++)
inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod;
return;
}
inline void dp()
{
for (int i = 1; i <= n; i++)
f[i][i] = 1;
for (int len = 2; len <= n; len++) {
for (int l = 1, r; l + len - 1 <= n; l++) {
r = l + len - 1;
if (a[l] > a[r])
continue;
f[l][r] = 1ll * ((1ll * bitl[l].query(a[r]) - bitl[l].query(a[l]) + bitr[r].query(a[r]) - bitr[r].query(a[l]) + bitsum[l].query(a[r]) - bitsum[l].query(a[l])) + mod) % mod * inv[(bitsum[l].query(a[r]) - bitsum[l].query(a[l]) + mod) % mod] % mod;
bitl[l].update(a[r], f[l][r]);
bitr[r].update(a[l], f[l][r]);
bitsum[l].update(a[r], 1);
}
}
return;
}
int main()
{
init();
n = read();
a[1] = 1;
for (int i = 1; i <= n; i++)
a[i + 1] = read() + 1;
a[n + 2] = n + 2;
n += 2;

dp();

printf("%d\n", f[1][n]);
return 0;
}