(Analysis by David Hu)

Note that it is optimal for Farmer John to milk his cows such that the cow with the $i$th smallest milk production value spends $i$ minutes on the milking machine. Indeed, if there are two cows $i$ and $j$ such that $a_i > a_j$ but cow $i$ spends fewer minutes than cow $j$ on the milking machine, the total amount of milk Farmer John produces could be increased by swapping the amounts of time cows $i$ and $j$ spend on the milking machine.

So the maximum amount of milk Farmer John can produce is $G(a) = \sum_{i=1}^{n} i \cdot a'_i$, where $a'$ is the array that results upon sorting $a$.

Let's first suppose that $a$ is originally sorted, and let the value of $G(a)$ initially be $S$.

Now let's see what happens when we replace some $a_i$ with some other value $v$. First suppose $v \geq a_i$. Then, in the sorted version of $a$, $v$ will belong in some position $p \geq i$, which we can find by binary search. Furthermore, all numbers originally in positions $i+1, i+2, \dots, p$ will shift down by one position, to positions $i, i+1, \dots, p-1$. As a result, $G(a)$ will now become $S - i \cdot a_i - \sum_{j=i+1}^{p} a_j + p \cdot v$. We can use prefix sums to compute $\sum_{j=i+1}^{p} a_j$ in $O(1)$ per query.

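The case when $v < a_i$ is similar: $v$ will belong in some position $p \leq i$, all numbers originally in positions $p, p+1, \dots, i-1$ will shift up by one position, and $G(a)$ will become $S - i \cdot a_i + \sum_{j=p}^{i-1} a_j + p \cdot v$.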

Now we must handle what happens when $a$ is not originally sorted. If we figure out, for all $i$, the position $p_i$ that $a_i$ would occupy in the sorted version of $a$, then we can simply sort $a$ (getting an array $a'$) and view every query changing $a_i$ to $v$ as a query changing $a'_{p_i}$ to $v$. There are a number of ways to find $p$: one way is to sort a list $c$ of the numbers from $1$ to $N$ by the value $a_i$; then if $c_j$ is the $j$th number in the list, we have $p_{c_j} = j$.

We must also remember to compute $S$ and the prefix sums using $a'$.

Overall time complexity is $O((N + Q) \log N)$ due to sorting and binary search.

My C++ code is below. Using the built-in C++ lower_bound function can greatly simplify our implementation.

#include <bits/stdc++.h>
 
using namespace std;
 
using ll = long long;

/**
 * Reads N values and Q independent "replace a[i] with v" queries from stdin.
 *
 * For each query, prints the sum over k of k * (k-th smallest value), with k
 * counted from 1, of the array obtained by applying only that replacement.
 * Every query is answered against the original array; nothing is modified.
 *
 * All working storage is sized from N at run time, so there is no fixed
 * capacity limit on the input.
 */
int main()
{
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int N = 0;
    std::cin >> N;
    std::vector<ll> arr(N);
    for (ll& value : arr)
    {
        std::cin >> value;
    }

    // ord[r] is the original index of the element holding sorted rank r;
    // pos is the inverse permutation (original index -> sorted rank).
    std::vector<int> ord(N), pos(N);
    std::iota(ord.begin(), ord.end(), 0);
    std::sort(ord.begin(), ord.end(), [&](int i, int j)
    {
        return arr[i] < arr[j];
    });
    for (int rank = 0; rank < N; rank++)
    {
        pos[ord[rank]] = rank;
    }
    std::sort(arr.begin(), arr.end());

    // pref[k] = sum of the k smallest values; tot = answer for the unmodified array.
    std::vector<ll> pref(N + 1, 0);
    ll tot = 0;
    for (int i = 0; i < N; i++)
    {
        pref[i + 1] = pref[i] + arr[i];
        tot += (i + 1) * arr[i];
    }

    int Q = 0;
    std::cin >> Q;
    while (Q--)
    {
        int idx;
        ll val;
        std::cin >> idx >> val;
        idx = pos[idx - 1]; // 1-based original index -> 0-based sorted rank

        // Rank val would hold once the old value is taken out: the number of
        // values strictly below val, minus one if the removed value was among them.
        int newidx = static_cast<int>(std::lower_bound(arr.begin(), arr.end(), val) - arr.begin());
        if (val > arr[idx])
        {
            newidx--;
        }

        ll ans = tot - (idx + 1) * arr[idx] + (newidx + 1) * val;
        if (newidx >= idx)
        {
            // Values at ranks idx+1..newidx each move down one rank.
            ans -= pref[newidx + 1] - pref[idx + 1];
        }
        else
        {
            // Values at ranks newidx..idx-1 each move up one rank.
            ans += pref[idx] - pref[newidx];
        }
        std::cout << ans << '\n';
    }
    return 0;
}

My Python Code:

# Read the number of values and the values themselves, in input order.
N = int(input())
arr = [int(token) for token in input().split()]

# order[r] is the original index of the value that lands at sorted rank r;
# pos inverts it, mapping each original index to its 0-based sorted rank.
order = sorted(range(N), key=lambda original: arr[original])
pos = [0] * N
for rank, original in enumerate(order):
    pos[original] = rank

# From here on arr holds the values in ascending order.
arr.sort()

def binary_search(x):
    """Return the smallest index i with arr[i] >= x, or N if there is none.

    Equivalently, the number of entries of the sorted global list ``arr``
    that are strictly less than x.
    """
    left, right = 0, N
    while left < right:
        middle = (left + right) // 2
        if arr[middle] < x:
            # Everything up to and including middle is too small.
            left = middle + 1
        else:
            right = middle
    return left

# pref[k] is the sum of the k smallest values; tot is the answer for the
# unmodified array, i.e. the sum of rank * value with ranks counted from 1.
pref = [0] * (N + 1)
tot = 0
for rank in range(1, N + 1):
    milk = arr[rank - 1]
    pref[rank] = pref[rank - 1] + milk
    tot += rank * milk

Q = int(input())
for _ in range(Q):
    cow, val = map(int, input().split())
    idx = pos[cow - 1]  # 0-based sorted rank of the value being replaced
    old = arr[idx]

    # Rank val would hold once the old value is removed: the count of values
    # strictly below val, minus one when the removed value was one of them.
    newidx = binary_search(val)
    if val > old:
        newidx -= 1

    if newidx >= idx:
        # Values at ranks idx+1..newidx each slide down one rank.
        shift = pref[idx + 1] - pref[newidx + 1]
    else:
        # Values at ranks newidx..idx-1 each slide up one rank.
        shift = pref[idx] - pref[newidx]

    print(tot - (idx + 1) * old + shift + (newidx + 1) * val)

Slightly shorter if bisect is used:

import bisect
 
# Read the number of values and the values themselves, in input order.
N = int(input())
arr = [int(token) for token in input().split()]

# by_value[r] is the original index of the value at sorted rank r; pos is the
# inverse mapping from original index to 0-based sorted rank.
by_value = sorted(range(N), key=lambda original: arr[original])
pos = [0] * N
for rank, original in enumerate(by_value):
    pos[original] = rank
arr.sort()

# pref[k] is the sum of the k smallest values; tot is the answer for the
# unmodified array (sum of rank * value, ranks counted from 1).
pref = [0] * (N + 1)
tot = 0
for rank in range(1, N + 1):
    pref[rank] = pref[rank - 1] + arr[rank - 1]
    tot += rank * arr[rank - 1]

Q = int(input())
for _ in range(Q):
    cow, val = map(int, input().split())
    idx = pos[cow - 1]  # 0-based sorted rank of the value being replaced
    removed = arr[idx]

    # bisect_left counts the values strictly below val; if the removed value
    # was one of them, val ends up one rank lower than that count.
    newidx = bisect.bisect_left(arr, val) - (1 if val > removed else 0)

    ans = tot - (idx + 1) * removed + (newidx + 1) * val
    if newidx < idx:
        # Values at ranks newidx..idx-1 each move up one rank.
        ans += pref[idx] - pref[newidx]
    else:
        # Values at ranks idx+1..newidx each move down one rank.
        ans -= pref[newidx + 1] - pref[idx + 1]
    print(ans)

Danny Mittal's Java code:

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.StringTokenizer;
import java.util.TreeMap;
 
/**
 * Reads an array and a list of independent "what if element j were v?" queries from standard
 * input and, for each query, prints the sum over k of k * (k-th smallest value), with k counted
 * from 1, of the array with only that one replacement applied.
 *
 * Input layout: a line with n, a line with n integers, a line with the query count, then one
 * line per query holding a 1-based index and the replacement value. Leading, trailing and
 * repeated whitespace inside a line is tolerated.
 */
public class ArrayQueriesSilver {

    public static void main(String[] args) throws IOException {
        BufferedReader in = new BufferedReader(new InputStreamReader(System.in));
        // trim() so padding around the count does not make parseInt throw.
        int n = Integer.parseInt(in.readLine().trim());
        // Tokenize rather than split(" "): split yields empty strings for leading or repeated
        // spaces, which parseLong rejects. This also matches how the query lines are parsed.
        StringTokenizer values = new StringTokenizer(in.readLine());
        Long[] xs = new Long[n];
        for (int j = 0; j < n; j++) {
            xs[j] = Long.parseLong(values.nextToken());
        }
        Long[] sorted = xs.clone();
        Arrays.sort(sorted);

        // base: answer for the unmodified array. sums[k]: sum of the k smallest values.
        long base = 0;
        long[] sums = new long[n + 1];
        // Maps each distinct value to the LAST 0-based index it occupies in sorted order
        // (later puts overwrite earlier ones for duplicate values).
        TreeMap<Long, Integer> treeMap = new TreeMap<>();
        for (int j = 0; j < n; j++) {
            sums[j + 1] = sums[j] + sorted[j];
            base += ((long) (j + 1)) * sorted[j];
            treeMap.put(sorted[j], j);
        }
        // Sentinel so lowerEntry() below always finds an entry; -1 + 1 gives a count of zero.
        treeMap.put(Long.MIN_VALUE, -1);

        StringBuilder out = new StringBuilder();
        for (int q = Integer.parseInt(in.readLine().trim()); q > 0; q--) {
            StringTokenizer tokenizer = new StringTokenizer(in.readLine());
            int j = Integer.parseInt(tokenizer.nextToken()) - 1;
            long prev = xs[j];
            long next = Long.parseLong(tokenizer.nextToken());
            // Sorted index treated as the slot of the value being replaced.
            int prevIndex = treeMap.get(prev);
            // Number of original values strictly smaller than the replacement.
            int nextIndex = treeMap.lowerEntry(next).getValue() + 1;
            // True when the replacement lands above the removed slot; the removed value was
            // then counted in nextIndex, so the new 1-based rank is nextIndex, not nextIndex + 1.
            boolean movesRight = nextIndex > prevIndex;
            long answer = base
                    - (((long) (prevIndex + 1)) * prev)
                    + (((long) (movesRight ? nextIndex : nextIndex + 1)) * next)
                    // Values strictly between the two slots each shift by one rank: the
                    // difference is positive (subtracted) when moving right, negative otherwise.
                    - (sums[nextIndex] - sums[movesRight ? prevIndex + 1 : prevIndex]);
            out.append(answer).append('\n');
        }
        System.out.print(out);
    }
}