I will refer to contiguous sequences as intervals. Define the value of an interval to be the minimum possible final number it can be converted into.
Subtask 1: Similar to 248, we can apply dynamic programming on ranges. Specifically, if $dp[i][j]$ denotes the value of interval $i\ldots j$, then $dp[i][i]=A_i$ and $dp[i][j]=\min_{i\le k<j}\left(\max(dp[i][k],dp[k+1][j])+1\right)$ for $i<j$.
The time complexity is $O(N^3)$.
#include <bits/stdc++.h>
using namespace std;
template <class T> using V = vector<T>;
#define all(x) begin(x), end(x)
// Lowers `a` to `b` when `b` is strictly smaller; otherwise leaves `a` untouched.
template <class T> void ckmin(T &a, const T &b) {
    if (b < a) a = b;
}
// Subtask 1: cubic range DP.
// Reads N and the sequence A from stdin, computes the value of every interval
// (the minimum final number it can be reduced to) and prints the sum of those
// values over all intervals.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int N;
    cin >> N;
    vector<int> A(N);
    for (int idx = 0; idx < N; ++idx) cin >> A[idx];

    // best[lo][hi] = value of the inclusive interval lo..hi.
    vector<vector<int>> best(N, vector<int>(N));
    int64_t total = 0;

    // Rows are filled from the bottom up so that best[split + 1][hi] is
    // already known when row `lo` needs it.
    for (int lo = N - 1; lo >= 0; --lo) {
        best[lo][lo] = A[lo];
        total += best[lo][lo];
        for (int hi = lo + 1; hi < N; ++hi) {
            int cheapest = INT_MAX;
            // Try every position of the final merge.
            for (int split = lo; split < hi; ++split) {
                const int merged = max(best[lo][split], best[split + 1][hi]) + 1;
                cheapest = min(cheapest, merged);
            }
            best[lo][hi] = cheapest;
            total += cheapest;
        }
    }

    cout << total << "\n";
}
Subtask 2: We can optimize the solution above by more quickly finding the maximum $k'$ such that $dp[i][k'] \le dp[k'+1][j]$. Then we only need to consider $k\in \{k',k'+1\}$ when computing $dp[i][j]$. Using the observation that $k'$ does not decrease as $j$ increases and $i$ is held fixed leads to a solution in $O(N^2)$:
#include <bits/stdc++.h>
using namespace std;
template <class T> using V = vector<T>;
#define all(x) begin(x), end(x)
// Replaces `a` by `b` if and only if `b` compares strictly less than `a`.
template <class T> void ckmin(T &a, const T &b) {
    if (b < a) {
        a = b;
    }
}
// Subtask 2: quadratic range DP.
// Same quantity as the cubic DP, but for a fixed left endpoint the best split
// position is tracked with a pointer that only moves right as the right
// endpoint grows, so each row costs O(N) amortised.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int N;
    cin >> N;
    vector<int> A(N);
    for (int idx = 0; idx < N; ++idx) cin >> A[idx];

    // best[lo][hi] = value of the inclusive interval lo..hi.
    vector<vector<int>> best(N, vector<int>(N));
    int64_t total = 0;

    for (int lo = N - 1; lo >= 0; --lo) {
        best[lo][lo] = A[lo];
        total += best[lo][lo];

        // `cut` is the first index NOT known to satisfy
        // best[lo][cut] <= best[cut + 1][hi]; it never moves left.
        int cut = lo;
        for (int hi = lo + 1; hi < N; ++hi) {
            while (cut < hi && best[lo][cut] <= best[cut + 1][hi]) ++cut;

            int cheapest = INT_MAX;
            // Split lo..cut-1 | cut..hi: the right part dominates.
            if (cut > lo) cheapest = min(cheapest, best[cut][hi]);
            // Split lo..cut | cut+1..hi: the left part dominates.
            if (cut < hi) cheapest = min(cheapest, best[lo][cut]);

            best[lo][hi] = cheapest + 1;
            total += best[lo][hi];
        }
    }

    cout << total << "\n";
}
Alternatively, finding $k'$ with binary search leads to a solution in $O(N^2\log N)$.
Subtask 3: Similar to 262144, we can use binary lifting.
For each of $v=1,2,3,\ldots$, I count the number of intervals with value at least $v$. The answer is the sum of this quantity over all $v$.
#include <bits/stdc++.h>
using namespace std;
template <class T> using V = vector<T>;
#define all(x) begin(x), end(x)
// Subtask 3: pointer doubling over values.
// For v = 1, 2, 3, ... the number of intervals whose value is at least v is
// added to the answer; the loop stops once the whole array has value below v.
// NOTE(review): termination relies on every A[i] being at least 1, since
// level 0 is never processed — confirm against the problem constraints.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int N;
    cin >> N;
    vector<int> A(N);
    for (int &a : A) cin >> a;

    // positions_of[v] lists, in increasing order, the indices i with A[i] == v.
    vector<vector<int>> positions_of;
    for (int i = 0; i < N; ++i) {
        const size_t value = static_cast<size_t>(A[i]);
        while (positions_of.size() <= value) positions_of.emplace_back();
        positions_of[value].push_back(i);
    }

    // reach[i] = largest r such that the half-open interval [i, r) has value
    // strictly below the current level (reach[i] == i means only the empty one).
    vector<int> reach(N + 1);
    iota(reach.begin(), reach.end(), 0);

    int64_t total = 0;
    int level = 1;
    while (reach[0] != N) {
        for (int i = 0; i <= N; ++i) {
            // Every r in (reach[i], N] gives an interval of value >= level.
            total += N - reach[i];
            // Gluing two intervals of value < level gives one of value <= level.
            // reach[i] >= i, so the entry read here has not been updated yet.
            reach[i] = reach[reach[i]];
        }
        if (static_cast<size_t>(level) < positions_of.size()) {
            // Single elements equal to `level` become available from now on.
            for (int i : positions_of[level]) reach[i] = i + 1;
        }
        ++level;
    }

    cout << total << "\n";
    return 0;
}
Subtask 4: For each $i$ from $1$ to $N$ in increasing order, consider the values of all intervals with right endpoint $i$. Note that the value $v$ of each such interval must satisfy $v\in [A_i, A_i+\lceil \log_2 i\rceil]$ due to $A$ being sorted. Thus, it suffices to be able to compute for each $v$ the minimum $l$ such that $dp[l][i]\le v$. To do this, we maintain a partition of $1\ldots i$ into contiguous subsequences such that every contiguous subsequence has value at most $A_i$ and is leftwards-maximal (extending any subsequence one element to the left would cause its value to exceed $A_i$). When transitioning from $i-1$ to $i$, we merge every two consecutive contiguous subsequences $A_i-A_{i-1}$ times and then add contiguous subsequence $[i,i]$ to the end of the partition.
#include <bits/stdc++.h>
using namespace std;
template <class T> using V = vector<T>;
#define all(x) begin(x), end(x)
// Subtask 4: input sorted in non-decreasing order (asserted).
// For each right endpoint the prefix is kept partitioned into leftwards-maximal
// pieces of value at most A[r]; the values of all intervals ending at r are
// then read off by taking 1, 2, 4, ... pieces from the right.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int N;
    cin >> N;
    vector<int> A(N);
    for (auto &a : A) cin >> a;
    assert(is_sorted(A.begin(), A.end()));

    // Left endpoints of the partition pieces, rightmost piece first.
    deque<int> starts;
    int64_t total = 0;

    for (int r = 0; r < N; ++r) {
        if (r > 0) {
            // One pairing round per unit increase of the value; stop early once
            // a single piece is left because further rounds change nothing.
            for (int level = A[r - 1]; level < A[r] && starts.size() != 1; ++level) {
                const int count = static_cast<int>(starts.size());
                deque<int> paired;
                // Pieces (0,1), (2,3), ... fuse; the fused piece keeps the
                // left endpoint of its second (further left) member.
                for (int j = 1; j < count; j += 2) paired.push_back(starts[j]);
                // An unpaired leftmost piece survives unchanged.
                if (count % 2 == 1) paired.push_back(starts[count - 1]);
                starts.swap(paired);
            }
        }

        starts.push_front(r);  // the new piece [r, r]

        // Left endpoints in [reached, r] have already been accounted for.
        int reached = r + 1;
        for (int extra = 0; reached != 0; ++extra) {
            const int last = static_cast<int>(starts.size()) - 1;
            const int pick = min(last, (1 << extra) - 1);
            const int next_reached = starts[pick];
            total += static_cast<int64_t>(reached - next_reached) * (A[r] + extra);
            reached = next_reached;
        }
    }

    cout << total << "\n";
}
Full Credit: Call an interval relevant if it is not possible to extend it to the left or to the right without increasing its value.
Claim: The number of relevant intervals is $O(N\log N)$.
Proof: See the end of the analysis.
We'll compute the same quantities as in subtask 3, but this time we'll transition from $v-1$ to $v$ in time proportional to the number of relevant intervals with value $v-1$ plus the number of relevant intervals with value $v$. By the claim above, this gives a solution in $O(N\log N)$.
For a fixed value of $v$, say that an interval $[l,r)$ is a segment if it contains no value greater than $v$, and $\min(A_{l-1},A_r)>v$. Say that an interval is maximal with respect to $v$ if it has value at most $v$ and extending it to the left or right would cause its value to exceed $v$. Note that a maximal interval $[l,r)$ must be relevant, and it must either have value equal to $v$ or be a segment.
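My code follows. $\texttt{ivals}[i]$ stores all maximal intervals contained within the segment with left endpoint $i$. The following steps are used to transition from $v-1$ to $v$: (1) for every segment containing more than one maximal interval, replace its list of maximal intervals by the intervals that are maximal with respect to $v$, obtained by joining each interval with the furthest interval that starts at or before its right endpoint and discarding any result that does not extend further right than the ones already produced; (2) for every $i$ with $A_i=v$, add the interval $[i,i+1)$ and merge the segment ending at $i$ with the segment starting at $i+1$ into a single segment.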
#include <bits/stdc++.h>
using namespace std;
template <class T> using V = vector<T>;
#define all(x) begin(x), end(x)
// Full-credit solution.
// Reads N and A from stdin and prints the sum, over v = 1, 2, 3, ..., of the
// number of intervals whose value is at least v. Intervals are half-open
// [l, r) throughout.
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int N;
cin >> N;
V<int> A(N);
// with_A[v] lists, in increasing order, the indices i with A[i] == v
V<V<int>> with_A;
for (int i = 0; i < N; ++i) {
cin >> A[i];
while ((int)size(with_A) <= A[i]) with_A.emplace_back();
with_A[A[i]].push_back(i);
}
// l + (l + 1) + ... + r, i.e. the sum of the integers in the inclusive range l..r
auto sum_arith = [&](int64_t l, int64_t r) {
return (r + l) * (r - l + 1) / 2;
};
// Total number of non-empty intervals covered by a list of maximal intervals.
// Every left endpoint l is charged to the first listed interval [x, y) that
// contains it and contributes y - l right endpoints.
// The list must be non-empty: the loop dereferences begin(L) unconditionally.
auto contrib = [&](const list<pair<int, int>> &L) {
int64_t ret = 0;
for (auto it = begin(L);; ++it) {
auto [x, y] = *it;
if (next(it) == end(L)) {
// last interval: all of its left endpoints x..y-1 belong to it
ret += sum_arith(0, y - x);
break;
} else {
// left endpoints x..next_x-1 belong to this interval
int next_x = next(it)->first;
ret += int64_t(next_x - x) * y - sum_arith(x, next_x - 1);
}
}
return ret;
};
// Starts at the number of non-empty intervals; decremented as intervals become
// covered, so at the top of iteration v it is the number added to ans for v.
int64_t num_at_least = (int64_t)N * (N + 1) / 2;
// Replaces the list L of one segment by the list for the next value: each
// interval [x, y) is joined with the furthest interval starting at or before y,
// and the result is kept only if it reaches further right than everything
// already emitted. num_at_least is reduced by the number of newly covered
// intervals. Lists with at most one element are left untouched.
auto halve = [&](list<pair<int, int>> &L) {
if (size(L) <= 1) return;
num_at_least += contrib(L);
int max_so_far = -1;
list<pair<int, int>> n_L;
// `it` only ever advances, so the whole pass is linear in size(L)
auto it = begin(L);
for (auto [x, y] : L) {
while (next(it) != end(L) && next(it)->first <= y) ++it;
if (it->second > max_so_far) {
n_L.push_back({x, max_so_far = it->second});
}
}
swap(L, n_L);
num_at_least -= contrib(L);
};
// doubly linked list to maintain segments: as used below, nex[l] is read as the
// right end of the segment starting at l and pre[r] as the left end of the
// segment ending at r; both start as the identity
V<int> pre(N + 1);
iota(all(pre), 0);
V<int> nex = pre;
int64_t ans = 0;
V<int> active; // active segments: left ends whose list had more than one interval
// maximal intervals within each segment
V<list<pair<int, int>>> ivals(N + 1);
for (int v = 1; num_at_least; ++v) {
ans += num_at_least; // # intervals with value >= v
V<int> n_active;
// step 1: advance every active segment's list to the next value
for (int l : active) {
halve(ivals[l]);
if (size(ivals[l]) > 1) n_active.push_back(l);
}
// step 2: elements equal to v become usable and join their neighbouring segments
if (v < (int)size(with_A)) {
for (int i : with_A[v]) {
int l = pre[i], r = nex[i + 1];
// only register l if its list grows from <= 1 to > 1 entries here,
// presumably so that l is not listed in n_active twice
bool should_add = size(ivals[l]) <= 1;
// i and i + 1 stop being segment boundaries
pre[i] = nex[i + 1] = -1;
nex[l] = r, pre[r] = l;
ivals[l].push_back({i, i + 1});
// the single-element interval [i, i + 1) is now covered
--num_at_least;
// move the right neighbour's intervals onto the end of l's list
ivals[l].splice(end(ivals[l]), ivals[i + 1]);
should_add &= size(ivals[l]) > 1;
if (should_add) n_active.push_back(l);
}
}
swap(active, n_active);
}
cout << ans << "\n";
}
Proof of Claim: Let $f(N)$ denote the maximum possible number of relevant subarrays for a sequence of size $N$. We can show that $f(N)\le O(\log N!)\le O(N\log N)$. This upper bound can be attained when all elements of the input sequence are equal.
The idea is to consider a Cartesian tree of $a$. Specifically, suppose that one of the maximum elements of $a$ is located at position $p$ ($1\le p\le N$). Then $f(N)\le f(p-1)+f(N-p)+\#(\text{relevant intervals containing }p)$.
WLOG we may assume $2p\le N$.
Lemma: $\#(\text{relevant intervals containing }p)\le O\left(p\log \frac{N}{p}\right)$.
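Proof of Lemma: We can in fact show that $\#(\text{relevant intervals containing }p\text{ with value }a_p+k)\le \min(p,2^k)$ for every $k\ge 0$.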
The $p$ comes from all relevant intervals with a fixed value having distinct left endpoints and the $2^k$ comes from the fact that to generate a relevant interval of value $a_p+k+1$ containing $p$, you must start with a relevant interval of value $a_p+k$ and choose to extend it either to the right or to the left.
To finish, observe that the summation $\sum_{k=0}^{\left\lceil\log_2N\right\rceil}\#(\text{relevant intervals containing }p\text{ with value }a_p+k)$ is dominated by the terms satisfying $\log_2p\le k\le \left\lceil\log_2N\right\rceil$. $\blacksquare$
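Since $O(p\log \frac{N}{p}) \le O\left(\log \binom{N}{p}\right)\le O(\log \frac{N!}{(p-1)!(N-p)!})$, the claim follows from the lemma by induction: $f(N)\le f(p-1)+f(N-p)+O\left(\log \frac{N!}{(p-1)!(N-p)!}\right)\le O\left(\log (p-1)!+\log (N-p)!+\log \frac{N!}{(p-1)!(N-p)!}\right)=O(\log N!)$.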
Here is an alternative approach by Danny Mittal that uses both the idea from the subtask 4 solution and the Cartesian tree. He repeatedly finds the index of the rightmost maximum $a_{mid}$ of the input sequence, solves the problem recursively on $a_{1\ldots mid-1}$ and $a_{mid+1 \ldots N}$, and then adds the contribution of all intervals containing $a_{mid}$. This also runs in $O(N\log N)$.
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;
/**
 * Divide-and-conquer solution over the Cartesian tree of the input.
 * Reads n and the sequence from standard input, accumulates the result in
 * {@link #answer} while recursing over the tree, and prints it.
 */
public class Revisited262144Array {
// number of elements
static int n;
// input values
static int[] xs;
// left[j] / right[j]: index of the left / right child of j in the Cartesian tree
// (entries stay 0 when there is no child; funStuff is then called on an empty range)
static int[] left;
static int[] right;
// Scratch arrays shared by the whole recursion; each call works on the slice
// starting at its own `from`.
// NOTE(review): they appear to hold breakpoint indices scanned left-to-right
// (forward) and right-to-left (reverse) — confirm against the editorial.
static int[] forward;
static int[] reverse;
// running total, printed by main
static long answer = 0;
/**
 * Thins out forward[start .. start+length-1] in place.
 * With factor = 2^min(lgFactor, 20), keeps the elements at offsets factor-1,
 * 2*factor-1, ... that lie strictly before the last element, then always keeps
 * the last element.
 *
 * @return the number of elements kept (length itself when lgFactor is 0)
 */
public static int reduceForward(int start, int length, int lgFactor) {
if (lgFactor == 0) {
return length;
}
// the shift is capped at 20
int factor = 1 << Math.min(lgFactor, 20);
int j = start;
for (int k = start + factor - 1; k < start + length - 1; k += factor) {
forward[j++] = forward[k];
}
forward[j++] = forward[start + length - 1];
return j - start;
}
/**
 * Thins out reverse[start .. start+length-1] in place with the same factor as
 * reduceForward. reverse[start] is never overwritten; when length exceeds
 * factor, the following slots are filled with every factor-th element starting
 * at offset 1 + ((length - factor - 1) % factor). Nothing is returned.
 */
public static void reduceReverse(int start, int length, int lgFactor) {
if (lgFactor == 0) {
return;
}
int factor = 1 << Math.min(lgFactor, 20);
if (length > factor) {
int j = start + 1;
for (int k = start + 1 + ((length - factor - 1) % factor); k < start + length; k += factor) {
reverse[j++] = reverse[k];
}
}
}
/**
 * Processes the subtree rooted at mid, which covers indices from..to.
 * Recurses into both children, adds to {@link #answer} the contribution of the
 * intervals handled at this node (each term is a product of two index gaps and
 * a value xs[mid] + d), then concatenates the children's slices around mid and
 * thins them by riseTo - xs[mid] levels.
 *
 * @param from   first index of the range
 * @param mid    root of the range
 * @param to     last index of the range
 * @param riseTo value of the parent node (Integer.MAX_VALUE for the root call)
 * @return the length of this range's slice of forward after thinning; 0 for an
 *         empty range
 */
public static int funStuff(int from, int mid, int to, int riseTo) {
if (from > to) {
return 0;
}
int leftLength = funStuff(from, left[mid], mid - 1, xs[mid]);
int rightLength = funStuff(mid + 1, right[mid], to, xs[mid]);
int leftStart = from;
int rightStart = mid + 1;
int last = mid - 1;
// j = 1 handles mid itself, j >= 2 walks the right child's forward entries
for (int j = 1; j <= rightLength + 1; j++) {
int frontier = j > 1 ? forward[rightStart + (j - 2)] : mid;
long weight = frontier - last;
last = frontier;
int lastInside = mid + 1;
int leftLast = leftLength == 0 ? mid : reverse[leftStart];
// NOTE(review): the bound 18 presumably reflects n <= 2^18 — confirm
for (int d = 0; d <= 18 && lastInside > leftLast; d++) {
if (1 << d >= j) {
int frontierInside;
if (1 << d == j) {
frontierInside = mid;
} else if (1 << d <= j + leftLength) {
frontierInside = reverse[leftStart + leftLength + j - (1 << d)];
} else {
frontierInside = leftLast;
}
long weightInside = lastInside - frontierInside;
lastInside = frontierInside;
answer += weight * weightInside * ((long) (xs[mid] + d));
}
}
}
// merge: left slice, then mid, then the right slice moved up next to it
forward[leftStart + leftLength] = mid;
System.arraycopy(forward, rightStart, forward, leftStart + leftLength + 1, rightLength);
reverse[leftStart + leftLength] = mid;
System.arraycopy(reverse, rightStart, reverse, leftStart + leftLength + 1, rightLength);
int length = reduceForward(leftStart, leftLength + 1 + rightLength, riseTo - xs[mid]);
reduceReverse(leftStart, leftLength + 1 + rightLength, riseTo - xs[mid]);
return length;
}
/**
 * Reads the input (n on the first line, the n values on the second), builds the
 * Cartesian tree with a stack, runs funStuff from the root and prints answer.
 */
public static void main(String[] args) throws IOException {
BufferedReader in = new BufferedReader(new InputStreamReader(System.in));
n = Integer.parseInt(in.readLine());
StringTokenizer tokenizer = new StringTokenizer(in.readLine());
xs = new int[n];
for (int j = 0; j < n; j++) {
xs[j] = Integer.parseInt(tokenizer.nextToken());
}
left = new int[n];
right = new int[n];
ArrayDeque<Integer> stack = new ArrayDeque<>();
for (int j = 0; j < n; j++) {
// elements <= xs[j] are popped, so among equal values the later one ends up
// higher in the tree; the last one popped becomes j's left child
while (!stack.isEmpty() && xs[stack.peek()] <= xs[j]) {
left[j] = stack.pop();
}
if (!stack.isEmpty()) {
right[stack.peek()] = j;
}
stack.push(j);
}
// the bottom of the stack is the root of the tree
while (stack.size() > 1) {
stack.pop();
}
forward = new int[n];
reverse = new int[n];
funStuff(0, stack.pop(), n - 1, Integer.MAX_VALUE);
System.out.println(answer);
}
}