Processing math: 100%
(Analysis by Larry Xing, Benjamin Qi)

Subtask 1:

Consider solving the subproblem for the whole array.

First suppose $N$ is odd. Then if we try a few examples, it seems that the answer is almost always the maximum over the array. Specifically, it holds for all odd $N \ge 5$. Proof sketch:

  1. By the 0-1 principle, it suffices to prove the conclusion for arrays consisting only of zeros and ones.
  2. First, we can use brute force to show the conclusion for 0-1 arrays of length N=5 (that is, for all arrays with at least two ones the answer is one).
  3. Then use induction to show the conclusion for 0-1 arrays of length $N > 5$ where $N$ is odd. Specifically, when $N$ is odd and $N > 5$, we can always select two consecutive zeros in the first min operation, and then any two elements in the max operation, to reduce to the case $N - 2$.

Now, suppose $N$ is even. Then, it seems that the answer is almost always the second maximum over the array. It is clear you cannot do better; the last operation is always a min operation, so you cannot achieve the maximum. We can prove this for even $N \ge 8$ similarly to the above.

Back to our original problem, we can simply iterate over all $\frac{N(N+1)}{2}$ subarrays.

  1. If the subarray length is odd and at least $5$, then we can directly compute the max, and similarly the second max if the length is even and at least $8$.
  2. Otherwise, we can try every possible combination of operations to find the best one.

Depending on the implementation, this can be $O(N^3)$ or $O(N^3 \log N)$.

Subtask 2:

From the previous subtask, we have reduced the problem to finding the sum of the maximums of all odd length subarrays and the sum of the second maximums of all even length subarrays.

To do this, let's fix the left endpoint of our subarray, and iterate over the right endpoint. We can keep track of the running maximum and second maximum, so when we compute the answer for a specific interval, we can simply check whether the length is even or odd and add the corresponding value to our answer.

This gives us an $O(N^2)$ solution.

#include <bits/stdc++.h>
using namespace std;

// Short type aliases used throughout this solution.
using ll = long long;
using pii = pair<int, int>;
using pll = pair<ll, ll>;
// Shorthand for the two members of a pair.
#define ff first
#define ss second
constexpr int MAX_N = 1'000'005;                 // capacity of a[] (1e6 + 5)
constexpr ll INF = 1'000'000'000'000'000'000LL;  // 1e18

int n;          // number of elements; read from stdin by main
ll a[MAX_N];    // the input array a[0..n-1]; read from stdin by main
bitset<64> dp;  // scratch bitset used by solve(), indexed by subset mask

// Brute-forces the exact answer for the length-k window a[i-k+1..i] and returns
// (exact answer) - (value assumed by the general rule), where the general rule
// is the window maximum for odd k and the second maximum for even k.
// Every subset mask of the window must fit in the 64-bit scratch bitset `dp`,
// so k must be at most 6.
ll solve(int k, int i){
    // Copy the window (stored right-to-left) and track its two largest values.
    vector<int> win(k);
    int best = 0, second = 0;
    for(int pos = 0; pos < k; pos++){
        win[pos] = a[i-pos];
        if(win[pos] > best){
            second = best;
            best = win[pos];
        }else if(win[pos] > second){
            second = win[pos];
        }
    }

    // dp[mask] is set when the set of surviving positions `mask` is reachable.
    const int full = (1 << k)-1;
    dp.reset();
    dp[full] = 1;
    int reachable_max = 0;
    for(int mask = full; mask > 0; mask--){
        if(!dp[mask]) continue;
        const int alive = __builtin_popcount(mask);
        if(alive == 1){
            // A single survivor is a possible final value.
            for(int pos = 0; pos < k; pos++)
                if(mask&(1 << pos))
                    reachable_max = max(reachable_max, win[pos]);
            continue;
        }
        // Operations alternate min, max, min, ...: after an even number of
        // operations the next one is a min, otherwise a max.
        const bool next_is_max = (k-alive)&1;
        int prev = -1;
        for(int pos = 0; pos < k; pos++){
            if(!(mask&(1 << pos))) continue;
            if(prev != -1){
                // Merge the neighbouring survivors prev and pos by dropping
                // the one that the current operation discards.
                const bool prev_larger = win[prev] > win[pos];
                const int dropped = (next_is_max != prev_larger) ? prev : pos;
                dp[mask^(1 << dropped)] = 1;
            }
            prev = pos;
        }
    }
    return reachable_max-(k&1 ? best : second);
}

int main(int argc, const char * argv[]){
    cin.tie(0)->sync_with_stdio(0);
    cin.exceptions(cin.failbit);
    cin >> n;
    for(int i = 0; i < n; i++) cin >> a[i];
    ll ans = 0;
    for(int i = 0; i < n; i++){
        ll mx = 0, smx = 0;
        for(int j = i; j >= 0; j--){
            if(a[j] > mx) smx = mx, mx = a[j];
            else if(a[j] > smx) smx = a[j];
            ans += ((j-i)&1 ? smx : mx);
        }
        if(i >= 2) ans += solve(3, i);
        if(i >= 3) ans += solve(4, i);
        if(i >= 5) ans += solve(6, i);
    }
    cout << ans << '\n';
    return 0;
}

Full Solution:

For convenience let's assume all $a_i$ are distinct (if they are not, break ties arbitrarily).

One approach to optimizing the $O(N^2)$ portion is to count, for each $i$, the number of segments for which $a_i$ is the answer. Let's initialize each index $i$ as "active" and then iterate over each $a_i$ in increasing order of value, deactivating them as we go while maintaining the neighboring active indices to both the left and the right using a doubly linked list. The subarrays with answer $a_i$ are those

  1. Of odd length, containing i, and no other active indices.
  2. Of even length, containing i, and exactly one other active index.

This gives us an O(N)-time solution.

Benjamin Qi's code:

#include <bits/stdc++.h>
using namespace std;

// V<T> is shorthand for vector<T>.
template <class T> using V = vector<T>;
// all(x) expands to the iterator pair begin(x), end(x).
#define all(x) begin(x), end(x)

using ll = long long;
using vi = V<int>;

// Memo table for solve(): maps (array, parity) to the computed result.
map<pair<vi, int>, int> mem;

// Raises `a` to `b` when `b` is larger.
template <class T> void ckmax(T &a, T b) {
    if (a < b) a = b;
}
// Lowers `a` to `b` when `b` is smaller.
template <class T> void ckmin(T &a, T b) {
    if (b < a) a = b;
}

// Exhaustive search with memoisation.
// Repeatedly merges one adjacent pair of `v` into a single element until one
// element remains, and returns the largest final value that can be reached.
// The merge keeps the minimum of the pair when `parity` is 0 and the maximum
// otherwise; `parity` flips after every merge.  Results are cached in `mem`.
// `v` must be non-empty.  The running best starts at 0, so the result is only
// meaningful for non-negative values.
int solve(vi v, int parity) {
    assert(!v.empty());
    if (v.size() == 1) return v.at(0);  // nothing left to merge

    // One lookup instead of count() followed by at().
    const auto cached = mem.find({v, parity});
    if (cached != mem.end()) return cached->second;

    int res = 0;
    // Signed length avoids comparing a signed index with an unsigned size.
    const int len = static_cast<int>(v.size());
    for (int i = 0; i + 1 < len; ++i) {
        // Merge the pair (i, i + 1) into slot i, then remove slot i + 1.
        auto nv = v;
        if (parity == 0) ckmin(nv.at(i), nv.at(i + 1));
        else ckmax(nv.at(i), nv.at(i + 1));
        nv.erase(begin(nv) + i + 1);
        ckmax(res, solve(nv, parity ^ 1));
    }
    return mem[{v, parity}] = res;
}

// Returns the value the general rule predicts for `v`: the largest element
// when the length is odd, the second largest when the length is even.
int solve_fast(vi v) {
    const int rank_from_top = (v.size() % 2 == 1) ? 1 : 2;
    sort(v.rbegin(), v.rend());  // descending order
    return v[rank_from_top - 1];
}

// Full solution.  Reads N and A[1..N] from stdin and prints the total.
// Phase 1 adds, for every index, its value times the number of subarrays for
// which it is the maximum (odd length) or the second maximum (even length).
// Phase 2 replaces that tentative value by the exact one for every subarray
// of length 3, 4 or 6.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int N;
    cin >> N;
    vi A(N + 1);  // 1-indexed; A[0] is unused
    for (int i = 1; i <= N; ++i) cin >> A.at(i);
    // with_val[x] lists the indices i with A[i] == x, in increasing i.
    // The bounds-checked at() throws if some A[i] lies outside [0, N].
    V<vi> with_val(N + 1);
    for (int i = 1; i <= N; ++i) with_val.at(A.at(i)).push_back(i);

    // Doubly linked list over positions 0..N+1; 0 and N+1 act as sentinels
    // that are never deactivated.
    vi l(N + 2), r(N + 2);
    for (int i = 0; i <= N + 1; ++i) {
        // closest active neighbors to left and right
        l.at(i) = i - 1;
        r.at(i) = i + 1;
    }
    // Number of integers in [1, x] with parity p (p == 0: even, else odd).
    auto count_range_upto = [&](int x, int p) -> ll {
        assert(x >= 0);
        if (p == 0) return x / 2;
        return (x + 1) / 2;
    };
    auto count_range = [&](int lo, int hi,
                           int p) -> ll {  // num ints in [lo, hi] with parity p
        return count_range_upto(hi, p) - count_range_upto(lo - 1, p);
    };
    // Left endpoint in (l[m], m], right endpoint in [m, r[m]) and both
    // endpoints of the same parity, i.e. odd length with m the only active
    // index inside.
    auto count_odd_first = [&](int m) {  // odd ranges with answer A[m]
        ll res = 0;
        for (int p : {0, 1})
            res +=
                count_range(l.at(m) + 1, m, p) * count_range(m, r.at(m) - 1, p);
        return res;
    };
    // Left endpoint in (l[a], a], right endpoint in [b, r[b]) and endpoints of
    // opposite parity, i.e. even length.  Returns 0 when a or b is a sentinel.
    auto count_even_pair = [&](int a, int b) -> ll {
        if (a <= 0 || b >= N + 1) return 0;
        ll res = 0;
        for (int p : {0, 1})
            res += count_range(l.at(a) + 1, a, p) *
                   count_range(b, r.at(b) - 1, p ^ 1);
        return res;
    };
    // Pairs i with its active neighbour on the left, then with the one on the
    // right.
    auto count_even_second = [&](int i) {  // even ranges with answer A[i]
        return count_even_pair(l.at(i), i) + count_even_pair(i, r.at(i));
    };
    ll ans = 0;
    // Process indices in increasing order of value (ties: increasing index),
    // so the indices still active are exactly the ones not processed yet.
    for (int v = 1; v <= N; ++v) {
        for (int i : with_val.at(v)) {
            // count subarrays with (tentative) ans A[i]
            ans += (ll)v * (count_odd_first(i) + count_even_second(i));
            // deactivate i
            int lo = l.at(i), hi = r.at(i);
            r.at(lo) = hi;
            l.at(hi) = lo;
        }
    }
    // Every index 1..N must have been unlinked by now.
    assert(r.at(0) == N + 1);
    // corrections for small lengths
    for (int len : {3, 4, 6}) {
        // precalced[mask] is true when solve() returns a non-zero value for
        // the 0-1 array whose element j equals bit j of mask.
        vector<bool> precalced(1 << len);
        for (int i = 0; i < (1 << len); ++i) {
            vi v;
            for (int j = 0; j < len; ++j) {
                if (i & (1 << j)) v.push_back(1);
                else v.push_back(0);
            }
            precalced.at(i) = solve(v, 0);
        }
        auto solve_precalced =
            [&](const vi &v) {  // solve for v using precalculated answers for
                                // 0-1 arrays
                assert(v.size() == len);
                vi by_val(len);
                iota(all(by_val), 0);
                sort(all(by_val), [&](int x, int y) {
                    return v.at(x) > v.at(y);
                });  // sort descending
                // Switch positions on from the largest value downwards; the
                // value at which the 0-1 answer first becomes true is returned.
                int mask = 0;
                for (int i : by_val) {
                    mask ^= 1 << i;
                    if (precalced.at(mask)) return v.at(i);
                }
                assert(false);
            };
        // Swap the tentative value for the exact one in each window of this
        // length.
        for (int i = 0; i <= N - len; ++i) {
            vi v(begin(A) + i + 1, begin(A) + i + 1 + len);
            ans -= solve_fast(v);
            ans += solve_precalced(v);
        }
    }
    cout << ans << "\n";
}

Alternative Solution:

An alternative approach is to directly optimize our approach in subtask 2. For two consecutive left endpoints, the set of running maximums and second maximums does not change much between each other.

This inspires us to keep track of a vector of "segments" where the maximum and second maximum are the same. To move from one left endpoint to the next, we replace the suffix in which the new element would change either the max or second max, and recalculate the new segments to insert into the vector.

To analyse the time complexity, consider the potential function

$$\Phi = 2 \cdot (\#\text{ distinct maxes}) + (\#\text{ distinct second maxes that are not maxes})$$

The amount of work that we do in each step is equivalent to the number of segments we insert at that step.

Consider the segments we insert where $a_l$ is the max. Here, we are essentially modifying each of the segments such that the max becomes a second max. Analyzing the change to the potential function, we see that the addition of $a_l$ increases the potential function by 2, while the modifications decrease the potential function by at least $(\#\text{ max segments inserted})$.

In addition, there is at most one segment we insert where $a_l$ is the second max. This is because if there were two separate such segments, one of those intervals would contain two elements bigger than $a_l$, which makes $a_l$ not the second max. Thus, we insert at most one such segment, while increasing the potential function by at most 1.

In total, the number of segments that we can insert is at most $1 + (\#\text{ max segments inserted})$, so

$$\text{work} + \Delta\Phi \le (\#\text{ segments inserted}) - (\#\text{ max segments inserted}) + 3 \le 4$$

so

$$\text{total work} \le 4N - \Phi_f + \Phi_i \le 4N$$
and our time complexity is $O(N)$.

Larry Xing's code:

#include <bits/stdc++.h>
using namespace std;

// Short type aliases used throughout this solution.
using ll = long long;
using pii = pair<int, int>;
using pll = pair<ll, ll>;
// Shorthand for the two members of a pair.
#define ff first
#define ss second
constexpr int MAX_N = 1'000'005;                 // capacity of a[] (1e6 + 5)
constexpr ll INF = 1'000'000'000'000'000'000LL;  // 1e18

int n;            // number of elements; read from stdin by main
ll a[MAX_N];      // the input array a[0..n-1]; read from stdin by main
ll nums[7][64];   // nums[k][mask]: result of solve(k, mask), filled by main
ll mx[2];         // per list: sum of (max * count) over its segments
ll smx[2];        // per list: sum of (second max * count) over its segments
bitset<64> dp;    // scratch bitset used by solve(), indexed by subset mask
// v[t]: list of segments ((max, second max), count), maintained by mod()
vector<pair<pll, ll>> v[2];

// Extends every segment of list v[t] by the new value x.
// Segments are popped from the back while x is at least their second maximum;
// each popped segment gets the pair (max(x, old max), min(x, old max)), equal
// neighbouring pairs are merged by adding their counts, and the result is
// pushed back in the original order.  mx[t] and smx[t] are kept equal to the
// count-weighted sums of the maxima and second maxima stored in v[t].
void mod(ll x, int t){
    vector<pair<pll, ll>> &segs = v[t];
    vector<pair<pll, ll>> rebuilt;  // updated segments, in popped (reverse) order
    while(!segs.empty() && x >= segs.back().ff.ss){
        const pair<pll, ll> old = segs.back();
        segs.pop_back();
        // Remove the old contribution of this segment.
        mx[t] -= old.ff.ff*old.ss;
        smx[t] -= old.ff.ss*old.ss;
        const pll updated(max(x, old.ff.ff), min(x, old.ff.ff));
        if(!rebuilt.empty() && rebuilt.back().ff == updated)
            rebuilt.back().ss += old.ss;
        else
            rebuilt.push_back({updated, old.ss});
    }
    // Re-insert in the original order and add the new contributions.
    for(auto it = rebuilt.rbegin(); it != rebuilt.rend(); ++it){
        mx[t] += it->ff.ff*it->ss;
        smx[t] += it->ff.ss*it->ss;
        segs.push_back(*it);
    }
}

// Brute-forces the exact answer for the 0-1 array of length k whose element at
// position p is bit p of `num`, and returns it (0 or 1).
// Every subset mask must fit in the 64-bit scratch bitset `dp`, so k must be
// at most 6.
ll solve(int k, int num){
    vector<int> bits(k);
    for(int pos = 0; pos < k; pos++) bits[pos] = (num >> pos)&1;

    // dp[mask] is set when the set of surviving positions `mask` is reachable.
    const int full = (1 << k)-1;
    dp.reset();
    dp[full] = 1;
    int reachable_max = 0;
    for(int mask = full; mask > 0; mask--){
        if(!dp[mask]) continue;
        const int alive = __builtin_popcount(mask);
        if(alive == 1){
            // A single survivor is a possible final value.
            for(int pos = 0; pos < k; pos++)
                if(mask&(1 << pos))
                    reachable_max = max(reachable_max, bits[pos]);
            continue;
        }
        // Operations alternate min, max, min, ...: after an even number of
        // operations the next one is a min, otherwise a max.
        const bool next_is_max = (k-alive)&1;
        int prev = -1;
        for(int pos = 0; pos < k; pos++){
            if(!(mask&(1 << pos))) continue;
            if(prev != -1){
                // Merge the neighbouring survivors prev and pos by dropping
                // the one that the current operation discards.
                const bool prev_larger = bits[prev] > bits[pos];
                const int dropped = (next_is_max != prev_larger) ? prev : pos;
                dp[mask^(1 << dropped)] = 1;
            }
            prev = pos;
        }
    }
    return reachable_max;
}

int main(int argc, const char * argv[]){
    cin.tie(0)->sync_with_stdio(0);
    cin.exceptions(cin.failbit);
    cin >> n;
    for(int i = 0; i < n; i++) cin >> a[i];
    ll ans = 0;
    for(int i = 0; i < n; i++){
        mod(a[i], i&1);
        mod(a[i], 1^i&1);
        v[i&1].push_back({{a[i], -1}, 1});
        mx[i&1] += a[i], smx[i&1] -= 1;
        ans += mx[i&1] + smx[1^i&1];
    }
    for(int k : {3, 4, 6})
        for(int i = 0; i < (1 << k); i++)
            nums[k][i] = solve(k, i);
    for(int k : {3, 4, 6}){
        for(int i = k-1; i < n; i++){
            int mx = 0, smx = 0, tmx = 0, num = 0;
            for(int j = 0; j < k; j++){
                if(a[i-j] > mx) tmx = smx, smx = mx, mx = a[i-j];
                else if(a[i-j] > smx) tmx = smx, smx = a[i-j];
                else if(a[i-j] > tmx) tmx = a[i-j];
            }
            for(int j = 0; j < k; j++) num |= (a[i-j] >= (k&1 ? mx : smx)) << j;
            ans += nums[k][num] ? 0 : (k&1 ? smx-mx : tmx-smx);
        }
    }
    cout << ans << '\n';
    return 0;
}