$Link$
给定一个长度为 $n$ 的排列 $a$,定义一个序列是合法的当且仅当被标记的数单调递增。
一开始,$a$ 中的每个数均未被标记,每次可以等概率选择一个未被标记过,且标记后序列合法的数标记,当没有可以选择的数时停止操作,求期望的标记个数,对 $10^9+7$ 取模。
$1\le n\le2000$。
考虑我们只考虑区间 $[l,r]$ 的期望,并且已经选了 $a_l,a_r$ 两个数,且 $a_l<a_r$,设还可以选的数为 $s_1,\cdots s_k$,满足 $l<s_i<r,a_l<a_{s_i}<a_r$,那期望就是 $f_{l,r}=1+\frac{1}{k}\sum_{i=1}^kf_{l,s_i}+f_{s_i,r}$。
直接使用树状数组优化这个 $\sum$ 即可,时间复杂度 $O(n^2\log n)$。
$code$
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
| #include <iostream> #include <cstdlib> #include <cstdio> using namespace std; inline int read() { int f = 1, x = 0; char ch;
do{ ch = getchar(); if (ch == '-') f = -1; }while(ch < '0' || ch > '9'); do{ x = x * 10 + ch - '0'; ch = getchar(); }while(ch >= '0' && ch <= '9'); return f * x; } const int N = 3e3; const int mod = 1e9 + 7;
int n; int a[N + 1], inv[N + 1]; int f[N + 1][N + 1]; struct BIT { int tr[N + 2]; inline int lowbit(int x) { return x & (-x); } inline void update(int pos, int val) { while (pos <= n) { tr[pos] = (tr[pos] + val) % mod; pos += lowbit(pos); } return; } inline int query(int pos) { int ans = 0;
while (pos) { ans = (ans + tr[pos]) % mod; pos -= lowbit(pos); } return ans; } } bitl[N + 1], bitr[N + 1], bitsum[N + 1];
inline void init() { inv[1] = 1; for (int i = 2; i <= N; i++) inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod; return; } inline void dp() { for (int i = 1; i <= n; i++) f[i][i] = 1; for (int len = 2; len <= n; len++) { for (int l = 1, r; l + len - 1 <= n; l++) { r = l + len - 1; if (a[l] > a[r]) continue; f[l][r] = 1ll * ((1ll * bitl[l].query(a[r]) - bitl[l].query(a[l]) + bitr[r].query(a[r]) - bitr[r].query(a[l]) + bitsum[l].query(a[r]) - bitsum[l].query(a[l])) + mod) % mod * inv[(bitsum[l].query(a[r]) - bitsum[l].query(a[l]) + mod) % mod] % mod; bitl[l].update(a[r], f[l][r]); bitr[r].update(a[l], f[l][r]); bitsum[l].update(a[r], 1); } } return; } int main() { init(); n = read(); a[1] = 1; for (int i = 1; i <= n; i++) a[i + 1] = read() + 1; a[n + 2] = n + 2; n += 2;
dp();
printf("%d\n", f[1][n]); return 0; }
|