题意#
给定 个非负整数,选一个子集,使得子集的平均值减去中位数最大。
求这个最大值。
#
因为需要考虑中位数,容易想到按照子集大小的奇偶性分类讨论。
子集大小为奇数#
先来考虑奇数的情况。
容易想到 枚举中位数。
如上图,对数列按照升序排序。选定中位数后,我们每次从中位数左侧添加一个点,再从中位数右侧添加一个点。又因为我们想要子集的平均值最大,所以贪心地,每次分别从两侧取一个最大的数,也就是取最靠右的点,直到有一个区间被取完。
比如上图中,,, 两两为一组,每次分别被取出。
于是对于每个中位数,只需要 枚举两侧添加点的个数,就可以确定出选取的情况,求出平均值与中位数的差。
由于方案已经确定,直接输出即可。
这样可以得到整体复杂度为 的算法,期望得分未知。
子集大小为偶数#
结论#
对于大小为偶数的子集,观察可以得到结论:对于任意大小为偶数的子集,一定存在大小为奇数的子集比其更优,下面进行证明。
分析#
按照上面的贪心策略任取一个大小为奇数的子集,将其升序排序,设为 .
现在向子集中添加一个数,使其大小变为偶数。根据上面提到的贪心策略,新加入的数 有以下两种情况:
- 在 左侧,即 ,
- 紧邻中位数右侧,即 .
在上面的图里,可以形象地看到, 的第一种情况是选择 ,第二种情况是在 中选一个。
第一种情况,显然不会使答案变大,因为 对平均值的影响太大了。
对于第二种情况,下面进行详细证明。
证明#
为了方便,我们把加入 后的数列重新升序排序,不妨设为 .
加入 前的答案为:
加入后的答案为:
两者作差并化简后得到:
又因为 ,所以:
即:
代入 中得到:
于是我们证明了:加入 前的答案一定不小于加入 后的答案,也就是:对于任意大小为偶数的子集,一定存在大小为奇数的子集比其更优。
代码#
以下代码输出的是:平均值与中位数的差的最大值
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int MAXN = 2e5 + 10;
const int inf = 0x3f3f3f3f3f3f3f3f;
inline int read() {}//快读
int n, a[MAXN], sum[MAXN];
signed main() {
n = read();
for(int i = 1; i <= n; i++) a[i] = read();
sort(a + 1, a + n + 1);
for(int i = 1; i <= n; i++) sum[i] = sum[i - 1] + a[i];
double ans = (-1.0) * inf;
for(int i = 1; i <= n; i++) {//枚举中位数,两侧同时添加点,尽可能靠右
double mmax = (double)(a[i]);//只取中位数
for(int j = 1; j <= min(i - 1, n - i); j++) {//subtask1 暴力枚举两侧加2*j个点
int add_l = sum[i - 1] - sum[i - j - 1];
int add_r = sum[n] - sum[n - j];
mmax = max(mmax, (double)(add_l + add_r + a[i]) / (double)(2 * j + 1));
}
ans = max(ans, mmax - (double)(a[i]));
}
printf("%.2lf\n", ans);
return 0;
}
cpp#
考虑如何把时间复杂度优化到 级别。
的做法中,我们先枚举中位数 ,再枚举两侧添加的元素个数 .
在此基础上,设第 次添加的两个元素下标分别为 和 .
当确定中位数时,答案只与平均值有关,考虑计算添加 对平均值的贡献。
贡献,也就是添加这两个数之后,平均值的变化量,若平均值变大,则贡献为正,反之为负。
显然地,随着 的增长, 会持续下降,对平均值的贡献也会越来越少。
我们最后一次选取的两个元素 ,必定满足贡献大于 ,并且是最接近 的。
因此我们可以通过二分,找到贡献紧邻最接近0的数。
这样就确定了选择哪些数,前缀和计算答案即可。
具体实现时,我们不需要准确地计算出每两个元素的贡献。若加入两个元素后的平均值大于原数列的平均值,则贡献为正,否则贡献为负。这部分需要利用前缀和求出数列平均值。
对于输出方案,已经确定,直接输出即可。
这样就得到了 的算法,期望得分 .
代码#
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int MAXN = 2e5 + 10;
const int inf = 0x3f3f3f3f3f3f3f3f;
inline int read() {
int ret = 0; char ch = getchar();
while(ch < '0' || ch > '9') ch = getchar();
while(ch <= '9' && ch >= '0') {
ret = ret * 10 + ch - '0';
ch = getchar();
}
return ret;
}
int n;
int a[MAXN], sum[MAXN];
signed main() {
n = read();
for(int i = 1; i <= n; i++) a[i] = read();
sort(a + 1, a + n + 1);
for(int i = 1; i <= n; i++) sum[i] = sum[i - 1] + a[i];
double ans = (-1.0) * inf; int ans_cnt = 0, ans_mid = 0;
for(int i = 1; i <= n; i++) {//枚举中位数,两侧同时添加点,尽可能靠右
int l = 1, r = min(i - 1, n - i), mid;
double mmax = (double)(a[i]); int cnt = 0;
while(l <= r) {
mid = (l + r) >> 1;
int sum1 = sum[i - 1] - sum[i - mid] + sum[n] - sum[n - mid + 1];
//j = mid - 1,即加入前的平均值
int sum2 = sum[i - 1] - sum[i - mid - 1] + sum[n] - sum[n - mid];
//j = mid,加入两个元素后的平均值
if((double)(sum1 + a[i]) / (double)(2 * mid - 1) < (double)(sum2 + a[i]) / (double)(2 * mid + 1)) {
if((double)(sum2 + a[i]) / (double)(2 * mid + 1) > mmax) {
mmax = (double)(sum2 + a[i]) / (double)(2 * mid + 1);
cnt = mid;
}
l = mid + 1;
} else {
r = mid - 1;
}
}
if(mmax - double(a[i]) > ans) {
ans = mmax - (double)(a[i]);
ans_cnt = cnt;
ans_mid = i;
}
}
cout << ans_cnt * 2 + 1 << endl;
for(int i = ans_mid - ans_cnt; i <= ans_mid; i++) cout << a[i] << " ";
for(int i = n - ans_cnt + 1; i <= n; i++) cout << a[i] << " ";
return 0;
}```
cpp