Processing math: 100%
(Analysis by David Hu)

Note that it is optimal for Farmer John to milk his cows such that the cow with the $i$-th smallest milk production value spends $i$ minutes on the milking machine. Indeed, if there are two cows $i$ and $j$ such that $a_i > a_j$ but cow $i$ spends fewer minutes than cow $j$ on the milking machine, the total amount of milk Farmer John produces could be increased by swapping the amounts of time cows $i$ and $j$ spend on the milking machine.

So the maximum amount of milk Farmer John can produce is $G(a)=\sum_{i=1}^{N} i \cdot a'_i$, where $a'$ is the array that results upon sorting $a$ in increasing order.

Let's first suppose that a is originally sorted, and let the value of G(a) initially be S.

Now let's see what happens when we replace some $a_i$ with some other value $v$. First suppose $v \ge a_i$. Then, in the sorted version of $a$, $v$ will belong in some position $p \ge i$, which we can find by binary search. Furthermore, all numbers originally in positions $i+1, i+2, \dots, p$ will shift down by one position to positions $i, i+1, \dots, p-1$. As a result, $G(a)$ will now become $S - i \cdot a_i - \sum_{j=i+1}^{p} a_j + p \cdot v$. We can use prefix sums to compute $\sum_{j=i+1}^{p} a_j$ in $O(1)$ per query.

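The case when $v < a_i$ is similar: $v$ moves to some position $p \le i$, the numbers originally in positions $p, p+1, \dots, i-1$ each shift up by one position, and $G(a)$ becomes $S - i \cdot a_i + \sum_{j=p}^{i-1} a_j + p \cdot v$.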

Now we must handle what happens when $a$ is not originally sorted. If we figure out, for all $i$, the position $p_i$ that $a_i$ would occupy in the sorted version of $a$, then we can simply sort $a$ (getting an array $a'$) and view every query changing $a_i$ to $j$ as a query changing $a'_{p_i}$ to $j$. There are a number of ways to find $p$: one way is to sort a list $c$ of the numbers from $1$ to $N$ by the value $a_i$; then if $c_j$ is the $j$-th number in the list, $p_{c_j} = j$.

We must also remember to compute $S$ and the prefix sums using the sorted array $a'$, not the original $a$.

Overall time complexity is $O((N+Q)\log N)$ due to sorting and binary search.

My C++ code is below. Using the built-in C++ `lower_bound` function can greatly simplify our implementation.

#include <bits/stdc++.h>
 
using namespace std;
 
// Capacity of the fixed-size global arrays declared below.
const int MAXN = 1.5e5 + 13;
// Wide integer type used for the values, prefix sums and totals.
typedef long long ll;
 
// N: number of values read by main; Q: number of queries read by main.
int N, Q;
// ord: original indices ordered by increasing value.
// pos: inverse of ord, i.e. pos[i] is the sorted slot of original index i.
int ord[MAXN], pos[MAXN];
// arr: the input values (main sorts this in place after building pos).
// pref[k]: sum of the first k entries of the sorted arr, with pref[0] == 0.
ll arr[MAXN], pref[MAXN];
// Sum over the sorted arr of (1-based slot) * value.
ll tot;
 
int main()
{
    ios_base::sync_with_stdio(false); cin.tie(0);
    cin >> N;
    for (int i = 0; i < N; i++)
    {
        cin >> arr[i];
    }
    iota(ord, ord + N, 0);
    sort(ord, ord + N, [&](int i, int j)
    {
        return arr[i] < arr[j];
    });
    for (int i = 0; i < N; i++)
    {
        pos[ord[i]] = i;
    }
    sort(arr, arr + N);
    for (int i = 0; i < N; i++)
    {
        pref[i + 1] = pref[i] + arr[i];
        tot += (i + 1) * arr[i];
    }
    cin >> Q;
    while(Q--)
    {
        int idx; ll val;
        cin >> idx >> val; idx--;
        idx = pos[idx];
        ll ans = tot;
        //index that val would be at in the new array
        int newidx = lower_bound(arr, arr + N, val) - arr - (bool) (val > arr[idx]);
        ans -= (idx + 1) * arr[idx];
        if (newidx >= idx) 
        {
            ans -= (pref[newidx + 1] - pref[idx + 1]);
        }
        else 
        {
            ans += (pref[idx] - pref[newidx]);
        }
        ans += (newidx + 1) * val;
        cout << ans << '\n';
    }
    return 0;
}

My Python Code:

N = int(input())
arr = list(map(int, input().split()))
# order[r] is the original index of the r-th smallest value (ties keep input order).
# It is called `order` rather than `ord` so the builtin ord() is not shadowed.
order = sorted(range(N), key=lambda x: arr[x])
# pos inverts order: pos[i] is the slot original element i occupies once arr is sorted.
pos = [0] * N
for rank, original in enumerate(order):
    pos[original] = rank
arr.sort()

def binary_search(x):
    """Return how many entries of the sorted global ``arr`` are < x.

    Equivalently, the smallest index i with arr[i] >= x, or N when no entry
    qualifies. Only the first N entries of ``arr`` are examined.
    """
    left, right = 0, N
    while left < right:
        middle = (left + right) // 2
        if arr[middle] < x:
            left = middle + 1
        else:
            right = middle
    return left

# pref[k] is the sum of the k smallest values; tot is the sum of
# (1-based sorted slot) * value over the whole sorted array.
pref = [0] * (N + 1)
tot = 0
for rank in range(1, N + 1):
    smallest = arr[rank - 1]
    pref[rank] = pref[rank - 1] + smallest
    tot += rank * smallest
Q = int(input())
for _ in range(Q):
    cow, val = map(int, input().split())
    src = pos[cow - 1]  # sorted slot of the element being replaced
    # Slot val ends up in after the old value is removed from the sorted order.
    dst = binary_search(val) - (1 if val > arr[src] else 0)
    delta = (dst + 1) * val - (src + 1) * arr[src]
    if dst >= src:
        # Entries in slots src+1..dst each move down by one.
        delta -= pref[dst + 1] - pref[src + 1]
    else:
        # Entries in slots dst..src-1 each move up by one.
        delta += pref[src] - pref[dst]
    print(tot + delta)

Slightly shorter if bisect is used:

import bisect
 
N = int(input())
arr = list(map(int, input().split()))
# order holds the original indices arranged by increasing value (stable on ties).
# The name `order` is used instead of `ord` to avoid shadowing the builtin ord().
order = list(range(N))
order.sort(key=lambda x: arr[x])
# pos[i] is the slot original element i occupies after arr is sorted.
pos = [0] * N
for slot, original in enumerate(order):
    pos[original] = slot
arr.sort()

# Prefix sums over the sorted values (pref[k] = sum of the k smallest) and the
# baseline total of (1-based sorted slot) * value.
pref = [0]
tot = 0
for rank in range(N):
    pref.append(pref[rank] + arr[rank])
    tot += (rank + 1) * arr[rank]
Q = int(input())
for _ in range(Q):
    idx, val = (int(tok) for tok in input().split())
    src = pos[idx - 1]  # sorted slot of the element being replaced
    old = arr[src]
    # Number of sorted values below val, less one if the removed value is among them.
    dst = bisect.bisect_left(arr, val)
    if val > old:
        dst -= 1
    if dst >= src:
        # Entries in slots src+1..dst each move down by one.
        shifted = -(pref[dst + 1] - pref[src + 1])
    else:
        # Entries in slots dst..src-1 each move up by one.
        shifted = pref[src] - pref[dst]
    print(tot - (src + 1) * old + shifted + (dst + 1) * val)

Danny Mittal's Java code:

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.StringTokenizer;
import java.util.TreeMap;
 
/**
 * Reads n values and a list of "replace element j with v" queries from standard
 * input and, for each query independently, prints the sum of
 * (1-based sorted position) * value the array would have after that replacement.
 */
public class ArrayQueriesSilver {

    /**
     * Input format: a line with n, a line with n whitespace-separated values,
     * a line with the query count, then one "j v" line per query (j is 1-based).
     *
     * @param args unused
     * @throws IOException if reading standard input fails
     */
    public static void main(String[] args) throws IOException {
        BufferedReader in = new BufferedReader(new InputStreamReader(System.in));
        int n = Integer.parseInt(in.readLine().trim());
        // Tokenize on any run of whitespace, exactly as the query lines are parsed
        // below; split(" ") would yield empty tokens on a leading or doubled space
        // and crash in parseLong.
        StringTokenizer values = new StringTokenizer(in.readLine());
        // Boxed on purpose: Arrays.sort on an object array is a merge sort with
        // guaranteed O(n log n) behaviour.
        Long[] xs = new Long[n];
        for (int j = 0; j < n; j++) {
            xs[j] = Long.parseLong(values.nextToken());
        }
        Long[] sorted = xs.clone();
        Arrays.sort(sorted);
        long base = 0;                    // sum of (j + 1) * sorted[j]
        long[] sums = new long[n + 1];    // sums[k] = sum of the k smallest values
        // Maps each value to the last index at which it occurs in sorted.
        TreeMap<Long, Integer> treeMap = new TreeMap<>();
        for (int j = 0; j < n; j++) {
            sums[j + 1] = sums[j] + sorted[j];
            base += ((long) (j + 1)) * sorted[j];
            treeMap.put(sorted[j], j);
        }
        // Sentinel so lowerEntry always finds something; makes "no smaller value" count as 0.
        treeMap.put(Long.MIN_VALUE, -1);
        StringBuilder out = new StringBuilder();
        for (int q = Integer.parseInt(in.readLine().trim()); q > 0; q--) {
            StringTokenizer tokenizer = new StringTokenizer(in.readLine());
            int j = Integer.parseInt(tokenizer.nextToken()) - 1;
            long prev = xs[j];
            long next = Long.parseLong(tokenizer.nextToken());
            int prevIndex = treeMap.get(prev);
            // Number of stored values strictly smaller than next.
            int nextIndex = treeMap.lowerEntry(next).getValue() + 1;
            // True exactly when next > prev, i.e. the replaced element moves right.
            boolean movesRight = nextIndex > prevIndex;
            long answer = base
                    - (((long) (prevIndex + 1)) * prev)
                    + (((long) (nextIndex + (movesRight ? 0 : 1))) * next)
                    // Elements between the old and new slot each shift by one position.
                    - (sums[nextIndex] - sums[prevIndex + (movesRight ? 1 : 0)]);
            out.append(answer).append('\n');
        }
        System.out.print(out);
    }
}