Note that it is optimal for Farmer John to milk his cows such that the cow with the $i$th smallest milk production value spends $i$ minutes on the milking machine. Indeed, if there are two cows $i$ and $j$ such that $a_i > a_j$ but cow $i$ spends fewer minutes than cow $j$ on the milking machine, the total amount of milk Farmer John produces could be increased by swapping the amounts of time cows $i$ and $j$ spend on the milking machine.
So the maximum amount of milk Farmer John can produce is $G(a) = \sum_{i=1}^{n} i \cdot a'_i$, where $a'$ is the array that results upon sorting $a$.
Let's first suppose that $a$ is originally sorted, and let the value of $G(a)$ initially be $S$.
Now let's see what happens when we replace some $a_i$ with some other value $v$. First suppose $v \geq a_i$. Then, in the sorted version of the modified array, $v$ will belong in some position $p \geq i$, which we can find by binary search. Furthermore, all numbers originally in positions $i+1, i+2, \dots, p$ will each shift down by one position, to positions $i, i+1, \dots, p-1$. As a result, $G(a)$ will now become $S - i \cdot a_i - \sum_{j=i+1}^{p} a_j + p \cdot v$. We can use prefix sums to compute $\sum_{j=i+1}^{p} a_j$ in $O(1)$ per query.
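The case when $v < a_i$ is symmetric: now $v$ belongs in some position $p \leq i$, the numbers originally in positions $p, p+1, \dots, i-1$ each shift up by one position, and $G(a)$ becomes $S - i \cdot a_i + \sum_{j=p}^{i-1} a_j + p \cdot v$.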
Now we must handle what happens when $a$ is not originally sorted. If we figure out, for all $i$, the position $p_i$ that $a_i$ would occupy in the sorted version of $a$, then we can simply sort $a$ (getting an array $a'$) and view every query changing $a_i$ to $v$ as a query changing $a'_{p_i}$. There are a number of ways to find $p$: one way is to sort a list $c$ of the indices from $1$ to $N$ by the value $a_i$; then, if $c_j$ is the $j$th index in this sorted list, $p_{c_j} = j$.
We must also remember to compute $S$ and the prefix sums using $a'$.
Overall time complexity is $O((N + Q) \log N)$ due to sorting and binary search.
My C++ code is below. Using the built-in C++ `lower_bound` function greatly simplifies the implementation. Note that the code is $0$-indexed, so the element at sorted position $k$ is multiplied by $k+1$.
#include <algorithm>
#include <iostream>
#include <numeric>
#include <vector>

using ll = long long;

// Milk Sum: reads N milk values and Q independent queries "i v". For each query,
// prints the maximum of sum_{k=1}^{N} k * a'_k after setting a_i = v, where a'
// is the sorted array. Each query is O(log N): one binary search plus prefix sums.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int N = 0;
    std::cin >> N;
    // Sized to N (not a fixed global MAXN), so a larger N cannot overflow a buffer.
    std::vector<ll> arr(N);
    for (ll& x : arr) std::cin >> x;

    // ord lists original indices by increasing value; pos is its inverse,
    // so pos[i] is the position original element i occupies after sorting.
    std::vector<int> ord(N);
    std::iota(ord.begin(), ord.end(), 0);
    std::sort(ord.begin(), ord.end(), [&arr](int i, int j) { return arr[i] < arr[j]; });
    std::vector<int> pos(N);
    for (int i = 0; i < N; i++) pos[ord[i]] = i;

    std::sort(arr.begin(), arr.end());

    // pref[k] = arr[0] + ... + arr[k-1]; tot = answer with no modification.
    std::vector<ll> pref(N + 1, 0);
    ll tot = 0;
    for (int i = 0; i < N; i++) {
        pref[i + 1] = pref[i] + arr[i];
        tot += static_cast<ll>(i + 1) * arr[i];
    }

    int Q = 0;
    std::cin >> Q;
    while (Q-- > 0) {
        int idx = 0;
        ll val = 0;
        std::cin >> idx >> val;
        idx = pos[idx - 1];  // input index is 1-based; map it to its sorted position

        // lower_bound counts elements < val. If arr[idx] is among them it is being
        // removed, so val's slot in the new sorted array is one smaller.
        const int lb = static_cast<int>(std::lower_bound(arr.begin(), arr.end(), val) - arr.begin());
        const int newidx = lb - (val > arr[idx] ? 1 : 0);

        ll ans = tot - static_cast<ll>(idx + 1) * arr[idx];
        if (newidx >= idx) {
            // val moves right: arr[idx+1 .. newidx] each slide one slot left.
            ans -= pref[newidx + 1] - pref[idx + 1];
        } else {
            // val moves left: arr[newidx .. idx-1] each slide one slot right.
            ans += pref[idx] - pref[newidx];
        }
        ans += static_cast<ll>(newidx + 1) * val;
        std::cout << ans << '\n';
    }
    return 0;
}
My Python Code:
import sys


def solve(data):
    """Answer every Milk Sum query described by the raw input text.

    ``data`` contains N, then N milk values, then Q, then Q pairs ``i v``
    (1-indexed cow, new value). Queries are independent: each answer is
    sum_{k=1}^{N} k * a'_k after setting a_i = v, where a' is a sorted.

    Returns:
        A list with one integer answer per query, in input order.
    """
    tokens = data.split()
    n = int(tokens[0])
    arr = [int(t) for t in tokens[1:n + 1]]

    # order[k] is the original index of the k-th smallest value; pos inverts
    # it, so pos[i] is the position arr[i] occupies after sorting.
    # (Named "order" rather than "ord" to avoid shadowing the builtin.)
    order = sorted(range(n), key=lambda x: arr[x])
    pos = [0] * n
    for k, original_index in enumerate(order):
        pos[original_index] = k
    arr.sort()

    def binary_search(x):
        # Number of values < x, i.e. the minimum index i such that arr[i] >= x.
        lo, hi = 0, n
        while hi > lo:
            mid = (hi + lo) // 2
            if arr[mid] >= x:
                hi = mid
            else:
                lo = mid + 1
        return lo

    # pref[k] = arr[0] + ... + arr[k-1]; tot = answer with no modification.
    pref = [0] * (n + 1)
    tot = 0
    for i in range(n):
        pref[i + 1] = pref[i] + arr[i]
        tot += (i + 1) * arr[i]

    q = int(tokens[n + 1])
    answers = []
    ptr = n + 2
    for _ in range(q):
        idx = pos[int(tokens[ptr]) - 1]
        val = int(tokens[ptr + 1])
        ptr += 2
        newidx = binary_search(val)
        # arr[idx] is removed; if it was counted as "< val", val's slot drops by one.
        if val > arr[idx]:
            newidx -= 1
        ans = tot - (idx + 1) * arr[idx]
        if newidx >= idx:
            # val moves right: arr[idx+1 .. newidx] each slide one slot left.
            ans -= pref[newidx + 1] - pref[idx + 1]
        else:
            # val moves left: arr[newidx .. idx-1] each slide one slot right.
            ans += pref[idx] - pref[newidx]
        ans += (newidx + 1) * val
        answers.append(ans)
    return answers


def main():
    # Bulk read/write: per-line input()/print() is too slow for ~1.5e5 queries.
    answers = solve(sys.stdin.read())
    if answers:
        sys.stdout.write("\n".join(map(str, answers)) + "\n")


if __name__ == "__main__":
    main()
Slightly shorter if bisect is used:
import bisect
import sys


def solve(data):
    """Answer every Milk Sum query described by the raw input text.

    ``data`` contains N, then N milk values, then Q, then Q pairs ``i v``
    (1-indexed cow, new value). Queries are independent: each answer is
    sum_{k=1}^{N} k * a'_k after setting a_i = v, where a' is a sorted.

    Returns:
        A list with one integer answer per query, in input order.
    """
    tokens = data.split()
    n = int(tokens[0])
    arr = [int(t) for t in tokens[1:n + 1]]

    # order[k] is the original index of the k-th smallest value; pos inverts
    # it, so pos[i] is the position arr[i] occupies after sorting.
    # (Named "order" rather than "ord" to avoid shadowing the builtin.)
    order = sorted(range(n), key=lambda x: arr[x])
    pos = [0] * n
    for k, original_index in enumerate(order):
        pos[original_index] = k
    arr.sort()

    # pref[k] = arr[0] + ... + arr[k-1]; tot = answer with no modification.
    pref = [0] * (n + 1)
    tot = 0
    for i in range(n):
        pref[i + 1] = pref[i] + arr[i]
        tot += (i + 1) * arr[i]

    q = int(tokens[n + 1])
    answers = []
    ptr = n + 2
    for _ in range(q):
        idx = pos[int(tokens[ptr]) - 1]
        val = int(tokens[ptr + 1])
        ptr += 2
        # bisect_left counts values < val; arr[idx] is removed, so if it was
        # one of them, val's slot in the new sorted array drops by one.
        newidx = bisect.bisect_left(arr, val)
        if val > arr[idx]:
            newidx -= 1
        ans = tot - (idx + 1) * arr[idx]
        if newidx >= idx:
            # val moves right: arr[idx+1 .. newidx] each slide one slot left.
            ans -= pref[newidx + 1] - pref[idx + 1]
        else:
            # val moves left: arr[newidx .. idx-1] each slide one slot right.
            ans += pref[idx] - pref[newidx]
        ans += (newidx + 1) * val
        answers.append(ans)
    return answers


def main():
    # Bulk read/write: per-line input()/print() is too slow for ~1.5e5 queries.
    answers = solve(sys.stdin.read())
    if answers:
        sys.stdout.write("\n".join(map(str, answers)) + "\n")


if __name__ == "__main__":
    main()
Danny Mittal's Java code:
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.StringTokenizer;
import java.util.TreeMap;

/**
 * Milk Sum: for each independent query that replaces one value, prints
 * sum_{k=1}^{n} k * a'_k where a' is the modified array in sorted order.
 */
public class ArrayQueriesSilver {
    public static void main(String[] args) throws IOException {
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));

        int n = Integer.parseInt(reader.readLine());
        String[] fields = reader.readLine().split(" ");
        Long[] original = new Long[fields.length];
        for (int k = 0; k < fields.length; k++) {
            original[k] = Long.parseLong(fields[k]);
        }

        Long[] ascending = original.clone();
        Arrays.sort(ascending);

        // prefix[k] = ascending[0] + ... + ascending[k-1]; best = unmodified answer.
        long[] prefix = new long[n + 1];
        long best = 0;
        // Maps each value to the LAST index it occupies in ascending order.
        TreeMap<Long, Integer> lastIndex = new TreeMap<>();
        for (int k = 0; k < n; k++) {
            long value = ascending[k];
            prefix[k + 1] = prefix[k] + value;
            best += ((long) (k + 1)) * value;
            lastIndex.put(value, k);
        }
        // Sentinel so lowerEntry always finds something: "no smaller value" => index -1.
        lastIndex.put(Long.MIN_VALUE, -1);

        StringBuilder output = new StringBuilder();
        int remaining = Integer.parseInt(reader.readLine());
        while (remaining > 0) {
            StringTokenizer st = new StringTokenizer(reader.readLine());
            int at = Integer.parseInt(st.nextToken()) - 1;
            long oldValue = original[at];
            long newValue = Long.parseLong(st.nextToken());
            output.append(answer(best, prefix, lastIndex, oldValue, newValue)).append('\n');
            remaining--;
        }
        System.out.print(output);
    }

    /** Total after removing one copy of oldValue and inserting newValue. */
    private static long answer(long base, long[] prefix, TreeMap<Long, Integer> lastIndex,
                               long oldValue, long newValue) {
        int from = lastIndex.get(oldValue);
        // Count of values strictly less than newValue.
        int lessCount = lastIndex.lowerEntry(newValue).getValue() + 1;
        long removed = ((long) (from + 1)) * oldValue;
        if (lessCount > from) {
            // newValue lands at 0-based slot lessCount-1; ascending[from+1 .. lessCount-1] shift left.
            return base - removed + ((long) lessCount) * newValue
                    - (prefix[lessCount] - prefix[from + 1]);
        }
        // newValue lands at 0-based slot lessCount; ascending[lessCount .. from-1] shift right.
        return base - removed + ((long) (lessCount + 1)) * newValue
                - (prefix[lessCount] - prefix[from]);
    }
}