Consider solving the subproblem for the whole array.
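First suppose $N$ is odd. Then if we try a few examples, it seems that the answer is almost always the maximum over the array. Specifically, this holds for $N \ge 5$. Proof sketch: apply every min operation to two adjacent elements that are both different from the (chosen) maximum, and every max operation to the maximum and one of its neighbors, so that the maximum is never discarded. While at least five elements remain, a pair avoiding the maximum always exists, and with five elements left we can pick the moves so that the maximum sits at an end of the array once three elements remain, leaving an adjacent pair for the final min operation. (For $N = 3$ this fails when the maximum is the middle element.)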
Now, suppose $N$ is even. Then, it seems that the answer is almost always the second maximum over the array. It is clear you cannot do better; the last operation is always a min operation, so you cannot achieve the maximum. We can prove this for $N \ge 8$ similarly as above. (The code below handles the remaining small lengths $3$, $4$ and $6$ by brute force.)
Back to our original problem, we can simply iterate over all $\frac{N(N+1)}{2}$ subarrays.
Depending on the implementation, this can be $O(N^3)$ or $O(N^3 \log N)$.
To do this, let's fix the left endpoint of our subarray, and iterate over the right endpoint. We can keep track of the running maximum and second maximum, so when we compute the answer for a specific interval, we can simply check whether the length is even or odd and add the corresponding value to our answer.
This gives us an $O(N^2)$ solution.
#include <algorithm>
#include <bitset>
#include <iostream>
#include <vector>
using namespace std;

typedef long long ll;

const int MAX_N = 1e6 + 5;

int n;
ll a[MAX_N];
// dp[mask] == 1 iff the set of surviving positions `mask` is reachable in the
// brute force below. 64 bits are enough because solve() is only used with k <= 6.
bitset<64> dp;

/**
 * Brute-forces the k-element window a[i-k+1 .. i] and returns the correction
 * that has to be added to the "generic" answer for that window.
 *
 * The generic answer is the maximum of the window when k is odd and the second
 * maximum when k is even. The brute force tries every sequence of merges of
 * adjacent survivors, where merge number 0, 2, 4, ... keeps the smaller element
 * and merge number 1, 3, 5, ... keeps the larger one, and records the best
 * value that can be left at the end.
 *
 * @param k window length; must satisfy 1 <= k <= 6 so that every mask fits in dp
 * @param i index of the last element of the window; the caller guarantees i >= k-1
 * @return (best achievable final value) - (generic answer)
 *
 * NOTE(review): mx/smx/best start at 0, so this assumes the values are positive.
 */
ll solve(int k, int i) {
    vector<ll> v(k);
    ll mx = 0, smx = 0;
    for (int j = 0; j < k; j++) {
        v[j] = a[i - j];  // window stored right-to-left; adjacency is unaffected
        if (v[j] > mx) {
            smx = mx;
            mx = v[j];
        } else if (v[j] > smx) {
            smx = v[j];
        }
    }

    dp.reset();
    dp[(1 << k) - 1] = 1;  // initially every position is still alive
    ll best = 0;
    // Removing an element only ever clears a bit, so successors are numerically
    // smaller; scanning masks in decreasing order visits them after their parents.
    for (int mask = (1 << k) - 1; mask > 0; mask--) {
        if (!dp[mask]) continue;
        const int alive = __builtin_popcount(mask);
        if (alive == 1) {
            for (int j = 0; j < k; j++)
                if (mask & (1 << j)) best = max(best, v[j]);
            continue;
        }
        // k - alive merges have been performed so far; odd-numbered merges keep the max.
        const bool keep_larger = (k - alive) & 1;
        int lst = -1;  // previous surviving position, i.e. the neighbor of j
        for (int j = 0; j < k; j++) {
            if (!(mask & (1 << j))) continue;
            if (lst != -1) {
                // Drop `lst` when it is the element the current operation discards,
                // otherwise drop `j` (ties drop `lst` for max, `j` for min).
                const bool drop_lst = keep_larger != (v[lst] > v[j]);
                dp[drop_lst ? (mask ^ (1 << lst)) : (mask ^ (1 << j))] = 1;
            }
            lst = j;
        }
    }
    return best - ((k & 1) ? mx : smx);
}

/**
 * Reads n and the array from stdin and prints the sum of answers over all
 * subarrays. Every subarray contributes its running maximum (odd length) or
 * running second maximum (even length); windows of length 3, 4 and 6 are then
 * corrected with the brute force in solve().
 */
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin.exceptions(cin.failbit);

    cin >> n;
    for (int i = 0; i < n; i++) cin >> a[i];

    ll ans = 0;
    for (int i = 0; i < n; i++) {
        // Fix the right endpoint i and extend the left endpoint j leftwards,
        // maintaining the two largest values seen so far.
        ll mx = 0, smx = 0;
        for (int j = i; j >= 0; j--) {
            if (a[j] > mx) {
                smx = mx;
                mx = a[j];
            } else if (a[j] > smx) {
                smx = a[j];
            }
            // (i - j) odd  <=>  the subarray a[j..i] has even length
            ans += ((i - j) & 1) ? smx : mx;
        }
        if (i >= 2) ans += solve(3, i);
        if (i >= 3) ans += solve(4, i);
        if (i >= 5) ans += solve(6, i);
    }
    cout << ans << '\n';
    return 0;
}
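One approach to optimizing the $O(N^2)$ portion is to consider the number of segments for which $a_i$ is an answer, for each $i$. Let's initialize each index $i$ as "active" and then iterate over each $a_i$ in increasing order of value, deactivating them as we go while maintaining the neighboring active indices to both the left and the right using a doubly linked list. The subarrays with answer $a_i$ are those of odd length that contain $i$ and no other active index (so $a_i$ is their maximum), together with those of even length that contain $i$ and exactly one other active index, which must be the nearest active neighbor on the left or on the right (so $a_i$ is their second maximum).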
This gives us an O(N)-time solution.
Benjamin Qi's code:
#include <algorithm>
#include <cassert>
#include <initializer_list>
#include <iostream>
#include <iterator>
#include <map>
#include <numeric>
#include <utility>
#include <vector>
using namespace std;

template <class T> using V = vector<T>;
#define all(x) begin(x), end(x)

using ll = long long;
using vi = V<int>;

// Memo for solve(): (array, parity of the next operation) -> best final value.
map<pair<vi, int>, int> mem;

template <class T> void ckmax(T &a, T b) { a = max(a, b); }
template <class T> void ckmin(T &a, T b) { a = min(a, b); }

/**
 * Exhaustive search: best value that can remain after repeatedly merging two
 * adjacent elements of v, where the operations alternate and the next one is a
 * min when parity == 0 and a max otherwise.
 *
 * Only used on tiny arrays (length <= 6 in this program); results are memoized.
 * NOTE(review): res starts at 0, so values are assumed to be non-negative.
 */
int solve(vi v, int parity) {
    assert(!v.empty());
    if (v.size() == 1) return v.front();
    const pair<vi, int> key{v, parity};
    if (auto it = mem.find(key); it != mem.end()) return it->second;

    int res = 0;
    const int len = static_cast<int>(v.size());
    for (int i = 0; i + 1 < len; ++i) {
        vi nv = v;
        // Merge positions i and i+1 into position i.
        if (parity == 0) ckmin(nv.at(i), nv.at(i + 1));
        else ckmax(nv.at(i), nv.at(i + 1));
        nv.erase(begin(nv) + i + 1);
        ckmax(res, solve(nv, parity ^ 1));
    }
    return mem[key] = res;
}

/**
 * "Generic" answer for v: the maximum if v has odd length, otherwise the
 * second maximum. Requires v.size() >= 2 when the length is even.
 */
int solve_fast(vi v) {
    sort(all(v));
    if (v.size() & 1) return end(v)[-1];
    return end(v)[-2];
}

/**
 * Reads N and A[1..N] from stdin and prints the sum of answers over all
 * subarrays: first the sum of generic answers via a linked list of active
 * indices, then corrections for every window of length 3, 4 and 6.
 *
 * NOTE(review): with_val is indexed by value, so this assumes 0 <= A[i] <= N
 * (out-of-range values make .at() throw); values equal to 0 are never processed
 * by the main loop, so presumably A[i] >= 1 - confirm against the statement.
 */
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int N;
    cin >> N;
    vi A(N + 1);
    for (int i = 1; i <= N; ++i) cin >> A.at(i);

    V<vi> with_val(N + 1);  // with_val[x] = indices i with A[i] == x, increasing
    for (int i = 1; i <= N; ++i) with_val.at(A.at(i)).push_back(i);

    // closest active neighbors to left and right (0 and N+1 act as sentinels)
    vi l(N + 2), r(N + 2);
    for (int i = 0; i <= N + 1; ++i) {
        l.at(i) = i - 1;
        r.at(i) = i + 1;
    }

    // num ints in [1, x] with parity p
    auto count_range_upto = [&](int x, int p) -> ll {
        assert(x >= 0);
        if (p == 0) return x / 2;
        return (x + 1) / 2;
    };
    // num ints in [lo, hi] with parity p (requires lo >= 1)
    auto count_range = [&](int lo, int hi, int p) -> ll {
        return count_range_upto(hi, p) - count_range_upto(lo - 1, p);
    };
    // odd ranges with answer A[m]: left end in (l[m], m], right end in [m, r[m]),
    // both ends of the same parity
    auto count_odd_first = [&](int m) {
        ll res = 0;
        for (int p : {0, 1})
            res += count_range(l.at(m) + 1, m, p) * count_range(m, r.at(m) - 1, p);
        return res;
    };
    // even ranges whose only active indices are the adjacent active pair (a, b):
    // left end in (l[a], a], right end in [b, r[b]), ends of opposite parity
    auto count_even_pair = [&](int a, int b) -> ll {
        if (a <= 0 || b >= N + 1) return 0;  // one side is a sentinel
        ll res = 0;
        for (int p : {0, 1})
            res += count_range(l.at(a) + 1, a, p) * count_range(b, r.at(b) - 1, p ^ 1);
        return res;
    };
    // even ranges with answer A[i]
    auto count_even_second = [&](int i) {
        return count_even_pair(l.at(i), i) + count_even_pair(i, r.at(i));
    };

    ll ans = 0;
    for (int v = 1; v <= N; ++v) {
        for (int i : with_val.at(v)) {
            // count subarrays with (tentative) ans A[i]
            ans += (ll)v * (count_odd_first(i) + count_even_second(i));
            // deactivate i
            const int lo = l.at(i), hi = r.at(i);
            r.at(lo) = hi;
            l.at(hi) = lo;
        }
    }
    assert(r.at(0) == N + 1);

    // corrections for small lengths
    for (int len : {3, 4, 6}) {
        // precalced[mask] = answer of the 0-1 array whose j-th entry is bit j of mask
        vector<bool> precalced(1 << len);
        for (int i = 0; i < (1 << len); ++i) {
            vi v;
            for (int j = 0; j < len; ++j) v.push_back((i & (1 << j)) ? 1 : 0);
            precalced.at(i) = solve(v, 0);
        }

        // solve for v using precalculated answers for 0-1 arrays: switch elements
        // on from largest to smallest and return the first value at which the
        // 0-1 answer becomes 1
        auto solve_precalced = [&](const vi &v) -> int {
            assert(static_cast<int>(v.size()) == len);
            vi by_val(len);
            iota(all(by_val), 0);
            sort(all(by_val), [&](int x, int y) { return v.at(x) > v.at(y); });  // sort descending
            int mask = 0;
            for (int i : by_val) {
                mask ^= 1 << i;
                if (precalced.at(mask)) return v.at(i);
            }
            // Unreachable: the all-ones array has answer 1. Return a value anyway
            // so that builds with NDEBUG do not run off the end of the lambda.
            assert(false);
            return v.at(by_val.back());
        };

        for (int i = 0; i <= N - len; ++i) {
            vi v(begin(A) + i + 1, begin(A) + i + 1 + len);  // A[i+1 .. i+len]
            ans -= solve_fast(v);
            ans += solve_precalced(v);
        }
    }
    cout << ans << "\n";
}
An alternative approach is to directly optimize our approach from subtask 2. Between two consecutive left endpoints, the running maximums and second maximums change very little.
This inspires us to keep track of a vector of "segments" in which the maximum and second maximum are the same. To move from one left endpoint to the next, we replace the suffix in which the new element would change either the max or second max, and recalculate the new segments to insert into the vector.
To analyse the time complexity, consider the potential function
The amount of work that we do in each step is equivalent to the number of segments we insert at that step.
Consider the segments we insert where $a_l$ is the max. Here, we are essentially modifying each of the segments such that the max becomes a second max. Analyzing the change to the potential function, we see that the addition of $a_l$ increases the potential function by $2$, while the modifications decrease the potential function by at least the number of max segments inserted.
In addition, there is at most one segment we insert where $a_l$ is the second max. This is because if there were two separate such segments, one of those intervals would contain two elements bigger than $a_l$, which makes $a_l$ not the second max. Thus, we insert at most one such segment, while increasing the potential function by at most $1$.
In total, the number of segments that we can insert is at most $1 + (\text{number of max segments inserted})$, so
so
Larry Xing's code:
#include <algorithm>
#include <bitset>
#include <initializer_list>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;

typedef long long ll;
typedef pair<ll, ll> pll;

const int MAX_N = 1e6 + 5;

int n;
ll a[MAX_N];
// nums[k][mask]: brute-force answer of the length-k 0-1 array whose j-th entry
// is bit j of mask (filled for k in {3, 4, 6}).
ll nums[7][64];
// For each class t (0 or 1): sum over all stored segments of max * count and of
// second max * count.
ll mx[2], smx[2];
bitset<64> dp;
// v[t]: stack of segments ((max, second max), count) for class t. main() puts
// each index i into class i & 1 as a fresh one-element segment.
vector<pair<pll, ll>> v[2];

/**
 * Appends value x to every segment of class t.
 *
 * Segments are popped from the back of v[t] for as long as x is at least their
 * second max; each popped segment (m, s) becomes (max(x, m), min(x, m)).
 * Consecutive popped segments that end up with the same pair are merged, and
 * the rebuilt segments are pushed back in their original order. mx[t] and
 * smx[t] are kept in sync.
 */
void mod(ll x, int t) {
    vector<pair<pll, ll>> rebuilt;
    while (!v[t].empty() && x >= v[t].back().first.second) {
        const pll old = v[t].back().first;
        const ll cnt = v[t].back().second;
        const pll updated(max(x, old.first), min(x, old.first));
        if (rebuilt.empty() || rebuilt.back().first != updated)
            rebuilt.push_back({updated, cnt});
        else
            rebuilt.back().second += cnt;
        mx[t] -= old.first * cnt;
        smx[t] -= old.second * cnt;
        v[t].pop_back();
    }
    // `rebuilt` holds the segments in reverse order; popping restores the order.
    while (!rebuilt.empty()) {
        mx[t] += rebuilt.back().first.first * rebuilt.back().second;
        smx[t] += rebuilt.back().first.second * rebuilt.back().second;
        v[t].push_back(rebuilt.back());
        rebuilt.pop_back();
    }
}

/**
 * Brute-forces the length-k 0-1 array whose j-th entry is bit j of num and
 * returns the best value (0 or 1) that can be left after k-1 merges of adjacent
 * survivors, where merge number 0, 2, 4, ... keeps the smaller element and merge
 * number 1, 3, 5, ... keeps the larger one.
 *
 * @param k array length; must satisfy 1 <= k <= 6 so that every mask fits in dp
 */
ll solve(int k, int num) {
    vector<int> bits(k);
    for (int i = 0; i < k; i++) bits[i] = (num >> i) & 1;

    dp.reset();
    dp[(1 << k) - 1] = 1;  // initially every position is still alive
    int best = 0;
    // Successor masks are numerically smaller, so a decreasing scan is a valid order.
    for (int mask = (1 << k) - 1; mask > 0; mask--) {
        if (!dp[mask]) continue;
        const int alive = __builtin_popcount(mask);
        if (alive == 1) {
            for (int j = 0; j < k; j++)
                if (mask & (1 << j)) best = max(best, bits[j]);
            continue;
        }
        const bool keep_larger = (k - alive) & 1;  // k - alive merges done so far
        int lst = -1;                              // previous surviving position
        for (int j = 0; j < k; j++) {
            if (!(mask & (1 << j))) continue;
            if (lst != -1) {
                const bool drop_lst = keep_larger != (bits[lst] > bits[j]);
                dp[drop_lst ? (mask ^ (1 << lst)) : (mask ^ (1 << j))] = 1;
            }
            lst = j;
        }
    }
    return best;
}

/**
 * Reads n and the array from stdin and prints the sum of answers over all
 * subarrays: first the sum of maxima (odd lengths) and second maxima (even
 * lengths) using the segment stacks, then corrections for windows of length
 * 3, 4 and 6.
 *
 * NOTE(review): the running top-3 values start at 0 and the fresh one-element
 * segment uses -1 as its second max, so values are assumed to be positive.
 */
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin.exceptions(cin.failbit);

    cin >> n;
    for (int i = 0; i < n; i++) cin >> a[i];

    ll ans = 0;
    for (int i = 0; i < n; i++) {
        const int same = i & 1;    // left endpoints giving odd-length subarrays
        const int other = same ^ 1;  // left endpoints giving even-length subarrays
        mod(a[i], same);
        mod(a[i], other);
        // The one-element subarray [i, i]; -1 is a placeholder second max and is
        // cancelled out of smx[same] the next time this segment is rebuilt.
        v[same].push_back({{a[i], -1}, 1});
        mx[same] += a[i];
        smx[same] -= 1;
        ans += mx[same] + smx[other];
    }

    for (int k : {3, 4, 6})
        for (int mask = 0; mask < (1 << k); mask++) nums[k][mask] = solve(k, mask);

    for (int k : {3, 4, 6}) {
        for (int i = k - 1; i < n; i++) {
            // Three largest values of the window a[i-k+1 .. i].
            ll top1 = 0, top2 = 0, top3 = 0;
            for (int j = 0; j < k; j++) {
                const ll x = a[i - j];
                if (x > top1) {
                    top3 = top2;
                    top2 = top1;
                    top1 = x;
                } else if (x > top2) {
                    top3 = top2;
                    top2 = x;
                } else if (x > top3) {
                    top3 = x;
                }
            }
            // Mark the elements that reach the generic answer (max for odd k,
            // second max for even k) and look up whether it is achievable.
            const ll target = (k & 1) ? top1 : top2;
            int num = 0;
            for (int j = 0; j < k; j++)
                if (a[i - j] >= target) num |= 1 << j;
            // If it is not achievable, fall back to the next largest value.
            if (!nums[k][num]) ans += (k & 1) ? (top2 - top1) : (top3 - top2);
        }
    }
    cout << ans << '\n';
    return 0;
}