题意
给定 $n$ 个非负整数,选一个子集,使得子集的平均值减去中位数最大。
求这个最大值。
$\rm{Subtask}$ $1-\mathcal{O}(n^2)$
因为需要考虑中位数,容易想到按照子集大小的奇偶性分类讨论。
子集大小为奇数
先来考虑奇数的情况。
容易想到 $\mathcal{O}(n)$ 枚举中位数。
如上图,对数列按照升序排序。选定中位数后,我们每次从中位数左侧添加一个点,再从中位数右侧添加一个点。又因为我们想要子集的平均值最大,所以贪心地,每次分别从两侧取一个最大的数,也就是取最靠右的点,直到有一个区间被取完。
比如上图中,$(4, 10)$,$(3, 9)$,$(2, 8)$ 两两为一组,每次分别被取出。
于是对于每个中位数,只需要 $\mathcal{O}(n)$ 枚举两侧添加点的个数,就可以确定出选取的情况,求出平均值与中位数的差。
由于方案已经确定,直接输出即可。
这样可以得到整体复杂度为 $\mathcal{O}(n^2)$ 的算法,期望得分未知。
子集大小为偶数
结论
对于大小为偶数的子集,观察可以得到结论:对于任意大小为偶数的子集,一定存在大小为奇数的子集比其更优,下面进行证明。
分析
按照上面的贪心策略任取一个大小为奇数的子集,将其升序排序,设为 $\{a_1, a_2, \cdots a_{2n-1}\}$.
现在向子集中添加一个数,使其大小变为偶数。根据上面提到的贪心策略,新加入的数 $x$ 有以下两种情况:
- $x$ 在 $a_1$ 左侧,即 $x\leq a_1$,
- 紧邻中位数右侧,即 $a_n\leq x\leq a_{n+1}$.
在上面的图里,可以形象地看到,$x$ 的第一种情况是选择 $1$,第二种情况是在 $6, 7$ 中选一个。
第一种情况,显然不会使答案变大,因为 $x$ 对平均值的影响太大了。
对于第二种情况,下面进行详细证明。
证明
为了方便,我们把加入 $x$ 后的数列重新升序排序,不妨设为 $\{b_1, b_2, \cdots b_{2n}\}$.
加入 $x$ 前的答案为:
A=\frac{\sum_{i=1}^{2n}b_i-b_{n+1}}{2n-1}-b_{n}
$$
加入后的答案为:
B=\frac{\sum_{i=1}^{2n}b_i}{2n}-\frac{b_{n} + b_{n + 1}}{2}
$$
两者作差并化简后得到:
A-B=\frac{\sum_{i=1}^{2n}b_i}{2n(2n-1)}+\frac{b_{n+1}-b_n}{2}-\frac{b_{n+1}}{2n-1}
$$
又因为 $B\geq 0$,所以:
\frac{\sum_{i=1}^{2n}b_i}{2n}\geq \frac{b_n+b_{n+1}}{2}
$$
即:
\frac{\sum_{i=1}^{2n}b_i}{2n(2n-1)}\geq \frac{b_n+b_{n+1}}{2(2n-1)}
$$
代入 $A-B$ 中得到:
\begin{aligned}
A-B&\geq \frac{b_n+b_{n+1}}{2(2n-1)}+\frac{b_{n+1}-b_n}{2}-\frac{b_{n+1}}{2n-1}\\
&=\frac{b_n+b_{n+1}+(2n-1)(b_{n+1}-b_n)-2b_{n+1}}{2(2n-1)}\\
&=\frac{b_n+b_{n+1}+2n\cdot b_{n+1}-2n\cdot b_n-b_{n+1}+b_n-2b_{n+1}}{2(2n-1)}\\
&=\frac{(n-1)(b_{n+1}-b_n)}{(2n-1)}\geq 0
\end{aligned}
$$
于是我们证明了:加入 $x$ 前的答案一定不小于加入 $x$ 后的答案,也就是:对于任意大小为偶数的子集,一定存在大小为奇数的子集比其更优。
$\mathcal{O}(n^2)$ 代码
以下代码输出的是:平均值与中位数的差的最大值
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int MAXN = 2e5 + 10;
const int inf = 0x3f3f3f3f3f3f3f3f;
inline int read() {}//快读
int n, a[MAXN], sum[MAXN];
signed main() {
n = read();
for(int i = 1; i <= n; i++) a[i] = read();
sort(a + 1, a + n + 1);
for(int i = 1; i <= n; i++) sum[i] = sum[i - 1] + a[i];
double ans = (-1.0) * inf;
for(int i = 1; i <= n; i++) {//枚举中位数,两侧同时添加点,尽可能靠右
double mmax = (double)(a[i]);//只取中位数
for(int j = 1; j <= min(i - 1, n - i); j++) {//subtask1 暴力枚举两侧加2*j个点
int add_l = sum[i - 1] - sum[i - j - 1];
int add_r = sum[n] - sum[n - j];
mmax = max(mmax, (double)(add_l + add_r + a[i]) / (double)(2 * j + 1));
}
ans = max(ans, mmax - (double)(a[i]));
}
printf("%.2lf\n", ans);
return 0;
}
$\rm{Subtask}$ $2-\mathcal{O}(n\log n)$
考虑如何把时间复杂度优化到 $\mathcal{O}(n\log n)$ 级别。
$\mathcal{O}(n^2)$ 的做法中,我们先枚举中位数 $mid$,再枚举两侧添加的元素个数 $len$.
在此基础上,设第 $i$ 次添加的两个元素下标分别为 $x_i$ 和 $y_i$.
当确定中位数时,答案只与平均值有关,考虑计算添加 $(x_i, y_i)$ 对平均值的贡献。
贡献,也就是添加这两个数之后,平均值的变化量,若平均值变大,则贡献为正,反之为负。
显然地,随着 $i$ 的增长,$a_{x_i}+a_{y_i}$ 会持续下降,对平均值的贡献也会越来越少。
我们最后一次选取的两个元素 $x_{len}, y_{len}$,必定满足贡献大于 $0$,并且是最接近 $0$ 的。
因此我们可以通过二分,找到贡献紧邻最接近0的数。
这样就确定了选择哪些数,前缀和计算答案即可。
具体实现时,我们不需要准确地计算出每两个元素的贡献。若加入两个元素后的平均值大于原数列的平均值,则贡献为正,否则贡献为负。这部分需要利用前缀和求出数列平均值。
对于输出方案,已经确定,直接输出即可。
这样就得到了 $\mathcal{O}(n\log n)$ 的算法,期望得分 $\rm{100pts}$.
$\rm{100pts}$ 代码
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int MAXN = 2e5 + 10;
const int inf = 0x3f3f3f3f3f3f3f3f;
inline int read() {
int ret = 0; char ch = getchar();
while(ch < '0' || ch > '9') ch = getchar();
while(ch <= '9' && ch >= '0') {
ret = ret * 10 + ch - '0';
ch = getchar();
}
return ret;
}
int n;
int a[MAXN], sum[MAXN];
signed main() {
n = read();
for(int i = 1; i <= n; i++) a[i] = read();
sort(a + 1, a + n + 1);
for(int i = 1; i <= n; i++) sum[i] = sum[i - 1] + a[i];
double ans = (-1.0) * inf; int ans_cnt = 0, ans_mid = 0;
for(int i = 1; i <= n; i++) {//枚举中位数,两侧同时添加点,尽可能靠右
int l = 1, r = min(i - 1, n - i), mid;
double mmax = (double)(a[i]); int cnt = 0;
while(l <= r) {
mid = (l + r) >> 1;
int sum1 = sum[i - 1] - sum[i - mid] + sum[n] - sum[n - mid + 1];
//j = mid - 1,即加入前的平均值
int sum2 = sum[i - 1] - sum[i - mid - 1] + sum[n] - sum[n - mid];
//j = mid,加入两个元素后的平均值
if((double)(sum1 + a[i]) / (double)(2 * mid - 1) < (double)(sum2 + a[i]) / (double)(2 * mid + 1)) {
if((double)(sum2 + a[i]) / (double)(2 * mid + 1) > mmax) {
mmax = (double)(sum2 + a[i]) / (double)(2 * mid + 1);
cnt = mid;
}
l = mid + 1;
} else {
r = mid - 1;
}
}
if(mmax - double(a[i]) > ans) {
ans = mmax - (double)(a[i]);
ans_cnt = cnt;
ans_mid = i;
}
}
cout << ans_cnt * 2 + 1 << endl;
for(int i = ans_mid - ans_cnt; i <= ans_mid; i++) cout << a[i] << " ";
for(int i = n - ans_cnt + 1; i <= n; i++) cout << a[i] << " ";
return 0;
}```