「TOKIOMARINE2020E」O(rand)

$Link$

给你一个长度为 $n$ 的非负整数序列 $a$,保证其中的数两两不同。求从中选择 $1\sim k$ 个数,使得这些数与起来等于 $s$,或起来等于 $t$ 的方案数。

$1\le n\le50,0\le a_i,s,t<2^{18},\forall i\not=j,a_i\not=a_j$。

设一个数的第 $i$ 个二进制位为 $x_i’$。

对于每一个二进制位:

  • 如果 $s’_i=0,t’_i=0$,那么所有 $a’_i=1$ 的数就没有用了。
  • 如果 $s’_i=1,t’_i=1$,那么所有 $a’_i=0$ 的数就没有用了。
  • 如果 $s’_i=1,t’_i=0$,问题无解。

剩下只要考虑 $s’_i=0,t’_i=1$ 的位了,这意味着必须分别选至少一个 $a’_i=0$ 和 $a’_i=1$ 的数。

转换一下,我们把这些位置标 $1$,其它位置标 $0$,得到状态 $S$,答案就是对于 $S$ 的子集这些位全选 $0$ 或者全选 $1$ 的方案数,乘上容斥系数求和。

对上述的方案数再转换一下,就变成从 $a_i&x$ 中选取 $1\sim k$ 个相等的数的方案数。

那只需要预处理组合数,开个桶做就好了。时间复杂度 $O(2^{18}n)$。

$code$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#include <iostream>
#include <cstdlib>
#include <cstdio>
#define ll long long
using namespace std;
inline int read()
{
int f = 1, x = 0;
char ch;

do{
ch = getchar();
if (ch == '-')
f = -1;
}while(ch < '0' || ch > '9');
do{
x = x * 10 + ch - '0';
ch = getchar();
}while(ch >= '0' && ch <= '9');
return f * x;
}
const int N = 50;

int n, k, s, t, S, sum;
int pow2[18];
int a[N + 1], buc[1 << 18];
ll f[N + 1][N + 1], C[N + 1];
ll ans;

inline void init()
{
f[0][0] = 1;
for (int i = 1; i <= N; i++) {
f[i][0] = 1;
for (int j = 1; j <= i; j++)
f[i][j] = f[i - 1][j] + f[i - 1][j - 1];
}
for (int i = 1; i <= N; i++)
for (int j = 1; j <= i && j <= k; j++)
C[i] += f[i][j];
for (int i = N; i >= 1; i--)
C[i] -= C[i - 1];
return;
}
int main()
{
n = read();
k = read();
s = read();
t = read();
init();
pow2[0] = 1;
for (int i = 1; i < 18; i++)
pow2[i] = 2 * pow2[i - 1];
for (int i = 0; i < 18; i++) {
if (t & pow2[i] && !(s & pow2[i])) {
S |= pow2[i];
sum++;
}
}
for (int i = 1; i <= n; i++) {
a[i] = read();
for (int j = 0; j < 18; j++) {
if (s & pow2[j] && !(a[i] & pow2[j])) {
i--;
n--;
continue;
}
if (!(t & pow2[j]) && a[i] & pow2[j]) {
i--;
n--;
continue;
}
}
}
if (!n) {
printf("0\n");
return 0;
}

while (true) {
ll per = 0;
for (int i = 1; i <= n; i++)
per += C[++buc[a[i] & S]];
if (sum & 1)
ans -= per;
else
ans += per;
for (int i = 1; i <= n; i++)
buc[a[i] & S] = 0;
if (!S)
break;
S = (S - 1) & (s ^ t);
sum = 0;
for (int i = 0; i < 18; i++)
if (S & pow2[i])
sum++;
}

printf("%lld\n", ans);
return 0;
}