题目地址:P5664 Emiya 家今天的饭 – 洛谷 | 计算机科学教育新生态
题意
一共 $n$ 种烹饪方法和 $m$ 种主要食材,每道菜对应一种方法和一种食材。使用第 $i$ 种方法和 第 $j$ 种食材,可以做出 $a_{i, j}$ 道不同的菜。求有多少个集合满足如下条件:
- 非空
- 每道菜烹饪方法互不相同
- 集合中每种主要食材的菜数不超过集合大小的一半
数据范围:$n\leq 100, m\leq 2000$.
解法
容易想到DP,不容易想到状态。
期望复杂度 $n^2m$.
$\rm{Subtask}$ $\rm{1}$
对于 $32\%$ 的数据,暴搜即可,与正解思路无关…
$\rm{Subtask}$ $\rm{2}$
对于 $84\%$ 的数据,期望复杂度 $mn^3$ .
性质 $1, 2$ 很好满足,如果没有性质 $3$ 这题随便做,所以具体分析性质 $3$.
容易想到容斥。对于一个满足性质 $1, 2$ 但不满足性质 $3$ 的集合,显然有且仅有一种主要食材超过了集合大小的一半。所以答案等于:
\text{满足性质 }1, 2 \text{ 的方案数量}-\text{某一列超过了集合大小一半的方案数量}
$$
分开解决这两个问题。
第一部分:求满足性质 $1, 2$ 的方案数量
根据乘法原理,满足性质 $1, 2$ 的方案数量为:
\left[\prod_{i=1}^{n}(\sum_{j = 1}^{m}a_{i, j} + 1)\right]-1
$$
第二部分:求满足性质 $1, 2$,不满足性质 $3$ 的方案数量
由于仅有一种主要食材超过集合大小的一半,考虑枚举这种主要食材。
假设枚举到不合法的食材为 $t$,那么此时我们只关心:选择的方法是否对应食材 $t$.
接下来设计状态。$\operatorname{DP}_{i, j, k}$ 表示对于已经选定的食材 $t$,前 $i$ 种方法中,有 $j$ 种方法与食材 $t$ 对应,有 $k$ 种方法不与食材 $t$ 对应,那么有转移方程:
\operatorname{DP}_{i, j, k}=\operatorname{DP}_{i-1, j, k}+a_{i, t}\times\operatorname{DP}_{i-1, j-1, k}+(\sum_{r=1}^{m}a_{i, r}-a_{i, t})\times \operatorname{DP}_{i-1, j, k-1}
$$
这样我们可以很方便的求出不合法方案数:
\sum_{1\leq k<j\leq n} \operatorname{DP}_{n, j, k}
$$
第一部分复杂度为 $\mathcal{O}(mn)$,第二部分 $\sum_{r=1}^{m}a_{i, r}$ 可以 $\mathcal{O}(mn)$ 预处理,DP 复杂度 $\mathcal{O}(mn^3)$.
$\rm{Subtask}$ $\rm{3}$
考虑如何优化 $\rm{Subtask}$ $\rm{2}$.
上面的不合法方案数统计方法:
\sum_{1\leq k<j\leq n} \operatorname{DP}_{n, j, k}
$$
其中 $k, j$ 满足的条件是 $k<j$,移项可以得到 $j-k>0$. 所以判断一个方案是否合法,只需要判断 $j-k$ 的符号,于是将 $\rm{Subtask}$ $\rm{2}$ 状态中的 $j$ 和 $k$ 压为一维 $j-k$ 的值,复杂度为 $\mathcal{O}(n^2m)$,转移类似。
于是此时状态:$\operatorname{DP}_{i, j}$ 表示对于已经选定的食材 $t$,前 $i$ 种方法中,有 $a$ 种方法与食材 $t$ 对应,有 $b$ 种方法不与食材 $t$ 对应,$j = b – a$.
转移方程:
\operatorname{DP}_{i, j} = \operatorname{DP}_{i-1, j} + a_{i,t}\times \operatorname{DP}_{i-1, j-1}+(\sum_{r=1}^{m}a_{i, r}-a_{i, t})\times \operatorname{DP}_{i-1, j+1}
$$
代码
#include <bits/stdc++.h>
#define int long long
#define MAXN 108
#define MAXM 2008
using namespace std;
const int mod = 998244353;
int read() {
int ret = 0; char ch = getchar();
while(ch < '0' || ch > '9') ch = getchar();
while(ch <= '9' && ch >= '0') ret = ret * 10 + ch - '0', ch = getchar();
return ret;
}
int n, m;
int a[MAXN][MAXM], sum[MAXN];
int dp[MAXN][2 * MAXN];
signed main() {
n = read(); m = read();
int ans = 1;
for(int i = 1; i <= n; i++) {
for(int j = 1; j <= m; j++) {
a[i][j] = read();
sum[i] = (sum[i] + a[i][j]) % mod;
}
ans = ans * (sum[i] + 1) % mod;
}
for(int t = 1; t <= m; t++) {
memset(dp, 0, sizeof(dp));
dp[0][n] = 1;
for(int i = 1; i <= n; i++) {
for(int j = n - i; j <= n + i; j++) {
dp[i][j] = (dp[i - 1][j] + a[i][t] * dp[i - 1][j - 1] % mod + (sum[i] - a[i][t]) * dp[i - 1][j + 1] % mod) % mod;
}
}
for(int j = 1; j <= n; j++) {
ans = ((ans - dp[n][n + j]) % mod + mod) % mod;
}
}
cout << ((ans - 1) % mod + mod) % mod << endl;
return 0;
}