Note that it is optimal for Farmer John to milk his cows so that the cow with the $i$-th smallest milk production value spends $i$ minutes on the milking machine. Indeed, suppose there are two cows $i$ and $j$ such that $a_i > a_j$, but cow $i$ spends fewer minutes ($t_i$) than cow $j$ ($t_j > t_i$) on the milking machine. Then swapping the amounts of time cows $i$ and $j$ spend on the machine changes the total amount of milk by $(t_j - t_i)(a_i - a_j) > 0$. So the total could be increased.
So the maximum amount of milk Farmer John can produce is $G(a) = \sum_{i=1}^{N} i \cdot a'_i$, where $a'$ is the array that results from sorting $a$ in increasing order.
Let's first suppose that a is originally sorted, and let the value of G(a) initially be S.
Now let's see what happens when we replace some $a_i$ with another value $v$. First suppose $v \ge a_i$. Then, in the sorted version of $a$, $v$ will belong in some position $p \ge i$, which we can find by binary search. Furthermore, all numbers originally in positions $i+1, i+2, \ldots, p$ will shift down by one position, to positions $i, i+1, \ldots, p-1$. As a result, $G(a)$ will now become $S - i \cdot a_i - \sum_{j=i+1}^{p} a_j + p \cdot v$. We can use prefix sums to compute $\sum_{j=i+1}^{p} a_j$ in $O(1)$ per query.
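The case when $v < a_i$ is similar. Now $v$ belongs in some position $p \le i$, and the numbers originally in positions $p, p+1, \ldots, i-1$ each shift up by one position. So $G(a)$ becomes $S - i \cdot a_i + \sum_{j=p}^{i-1} a_j + p \cdot v$.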
Now we must handle the case where $a$ is not originally sorted. Suppose we find, for every $i$, the position $p_i$ that $a_i$ would occupy in the sorted version of $a$. Then we can simply sort $a$ (getting an array $a'$) and view every query changing $a_i$ to $j$ as a query changing $a'_{p_i}$ to $j$. There are a number of ways to find $p$. One way is to sort a list $c$ of the numbers from $1$ to $N$ by the value $a_i$. Then, if $c_j$ is the $j$-th number in the sorted list, $p_{c_j} = j$.
We must also remember to compute S and the prefix sums using a′.
The overall time complexity is $O((N+Q)\log N)$, due to sorting and binary search.
My C++ code is below. Using the built-in C++ `lower_bound` function greatly simplifies the implementation.
#include <bits/stdc++.h> using namespace std; const int MAXN = 1.5e5 + 13; typedef long long ll; int N, Q; int ord[MAXN], pos[MAXN]; ll arr[MAXN], pref[MAXN]; ll tot; int main() { ios_base::sync_with_stdio(false); cin.tie(0); cin >> N; for (int i = 0; i < N; i++) { cin >> arr[i]; } iota(ord, ord + N, 0); sort(ord, ord + N, [&](int i, int j) { return arr[i] < arr[j]; }); for (int i = 0; i < N; i++) { pos[ord[i]] = i; } sort(arr, arr + N); for (int i = 0; i < N; i++) { pref[i + 1] = pref[i] + arr[i]; tot += (i + 1) * arr[i]; } cin >> Q; while(Q--) { int idx; ll val; cin >> idx >> val; idx--; idx = pos[idx]; ll ans = tot; //index that val would be at in the new array int newidx = lower_bound(arr, arr + N, val) - arr - (bool) (val > arr[idx]); ans -= (idx + 1) * arr[idx]; if (newidx >= idx) { ans -= (pref[newidx + 1] - pref[idx + 1]); } else { ans += (pref[idx] - pref[newidx]); } ans += (newidx + 1) * val; cout << ans << '\n'; } return 0; }
My Python Code:
import sys


def binary_search(arr, x):
    """Return how many elements of the sorted list `arr` are strictly less than `x`.

    Equivalently, the smallest index i with arr[i] >= x, or len(arr) if there is
    none (the same result as C++ lower_bound).
    """
    lo = 0
    hi = len(arr)
    while hi > lo:
        mid = (hi + lo) // 2
        if arr[mid] >= x:
            hi = mid
        else:
            lo = mid + 1
    return lo


def solve(data):
    """Answer every query contained in the input text `data`; return the output text.

    `data` holds N, then a_1..a_N, then Q, then Q pairs "i v" meaning
    "set a_i to v". Tokens may be separated by any whitespace, including
    newlines. Each query is evaluated against the original array, and the
    result has one line per query holding sum((k + 1) * s[k]) for the
    modified array sorted ascending as s.
    """
    tokens = iter(data.split())
    n = int(next(tokens))
    values = [int(next(tokens)) for _ in range(n)]

    # order[k] = original index of the k-th smallest value (not named `ord`,
    # which would shadow the builtin); pos[i] = sorted rank of original index i.
    order = sorted(range(n), key=lambda i: values[i])
    pos = [0] * n
    for rank, i in enumerate(order):
        pos[i] = rank

    arr = sorted(values)
    # pref[k] = arr[0] + ... + arr[k - 1]; tot = sum of (i + 1) * arr[i].
    pref = [0] * (n + 1)
    tot = 0
    for i in range(n):
        pref[i + 1] = pref[i] + arr[i]
        tot += (i + 1) * arr[i]

    q = int(next(tokens))
    out = []
    for _ in range(q):
        idx = pos[int(next(tokens)) - 1]  # 1-based original index -> sorted position
        val = int(next(tokens))
        # Final 0-based position of val after removing arr[idx]: when val > arr[idx]
        # the count of smaller values includes arr[idx] itself, so subtract one.
        newidx = binary_search(arr, val)
        if val > arr[idx]:
            newidx -= 1
        ans = tot - (idx + 1) * arr[idx]
        if newidx >= idx:
            ans -= pref[newidx + 1] - pref[idx + 1]  # arr[idx+1..newidx] drop one rank
        else:
            ans += pref[idx] - pref[newidx]  # arr[newidx..idx-1] rise one rank
        ans += (newidx + 1) * val
        out.append(str(ans) + "\n")
    # One buffered write instead of a print() per query.
    return "".join(out)


if __name__ == "__main__":
    sys.stdout.write(solve(sys.stdin.read()))
A slightly shorter version using the `bisect` module:
import bisect
import sys


def solve(data):
    """Answer every query contained in the input text `data`; return the output text.

    `data` holds N, then a_1..a_N, then Q, then Q pairs "i v" meaning
    "set a_i to v". Tokens may be separated by any whitespace, including
    newlines. Each query is evaluated against the original array, and the
    result has one line per query holding sum((k + 1) * s[k]) for the
    modified array sorted ascending as s.
    """
    tokens = iter(data.split())
    n = int(next(tokens))
    values = [int(next(tokens)) for _ in range(n)]

    # order[k] = original index of the k-th smallest value (not named `ord`,
    # which would shadow the builtin); pos[i] = sorted rank of original index i.
    order = sorted(range(n), key=lambda i: values[i])
    pos = [0] * n
    for rank, i in enumerate(order):
        pos[i] = rank

    arr = sorted(values)
    # pref[k] = arr[0] + ... + arr[k - 1]; tot = sum of (i + 1) * arr[i].
    pref = [0] * (n + 1)
    tot = 0
    for i in range(n):
        pref[i + 1] = pref[i] + arr[i]
        tot += (i + 1) * arr[i]

    q = int(next(tokens))
    out = []
    for _ in range(q):
        idx = pos[int(next(tokens)) - 1]  # 1-based original index -> sorted position
        val = int(next(tokens))
        # bisect_left counts values < val; when val > arr[idx] that count includes
        # arr[idx] itself, which is being removed, so subtract one.
        newidx = bisect.bisect_left(arr, val)
        if val > arr[idx]:
            newidx -= 1
        ans = tot - (idx + 1) * arr[idx]
        if newidx >= idx:
            ans -= pref[newidx + 1] - pref[idx + 1]  # arr[idx+1..newidx] drop one rank
        else:
            ans += pref[idx] - pref[newidx]  # arr[newidx..idx-1] rise one rank
        ans += (newidx + 1) * val
        out.append(str(ans) + "\n")
    # One buffered write instead of a print() per query.
    return "".join(out)


if __name__ == "__main__":
    sys.stdout.write(solve(sys.stdin.read()))
Danny Mittal's Java code:
import java.io.BufferedReader;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Map;
import java.util.StringTokenizer;
import java.util.TreeMap;

/**
 * Answers independent "replace one element" queries. For each query {@code (j, v)} it prints
 * {@code sum((k + 1) * s[k])}, where {@code s} is the original array with its j-th element
 * (1-based) replaced by {@code v} and then sorted ascending.
 */
public class ArrayQueriesSilver {

    public static void main(String[] args) throws IOException {
        Tokens in = new Tokens(
                new BufferedReader(new InputStreamReader(System.in, StandardCharsets.UTF_8)));
        int n = in.nextInt();
        long[] xs = new long[n];
        for (int j = 0; j < n; j++) {
            xs[j] = in.nextLong();
        }
        long[] sorted = xs.clone();
        Arrays.sort(sorted);

        // base = sum of (j + 1) * sorted[j]; sums[j] = sorted[0] + ... + sorted[j - 1].
        long base = 0;
        long[] sums = new long[n + 1];
        // Maps each distinct value to the LAST sorted index holding it. Equal values are
        // interchangeable, so replacing that occurrence yields the same answer as any other.
        TreeMap<Long, Integer> treeMap = new TreeMap<>();
        for (int j = 0; j < n; j++) {
            sums[j + 1] = sums[j] + sorted[j];
            base += ((long) (j + 1)) * sorted[j];
            treeMap.put(sorted[j], j);
        }

        StringBuilder out = new StringBuilder();
        for (int q = in.nextInt(); q > 0; q--) {
            int j = in.nextInt() - 1;
            long prev = xs[j];
            long next = in.nextLong();
            int prevIndex = treeMap.get(prev);
            // Number of sorted values strictly less than next. Handling the null case directly
            // (instead of a Long.MIN_VALUE sentinel key) means a real Long.MIN_VALUE element can
            // no longer have its index overwritten, and lowerEntry(Long.MIN_VALUE) cannot NPE.
            Map.Entry<Long, Integer> below = treeMap.lowerEntry(next);
            int nextIndex = below == null ? 0 : below.getValue() + 1;
            boolean movesRight = nextIndex > prevIndex;
            // Remove prev's contribution, add next at its new 1-based rank, and shift every value
            // strictly between the two positions by one rank (down if movesRight, up otherwise).
            long answer = base
                    - ((long) (prevIndex + 1)) * prev
                    + ((long) (nextIndex + (movesRight ? 0 : 1))) * next
                    - (sums[nextIndex] - sums[prevIndex + (movesRight ? 1 : 0)]);
            out.append(answer).append('\n');
        }
        System.out.print(out);
    }

    /**
     * Reads whitespace-separated tokens regardless of how they are split into lines, so repeated
     * spaces, tabs, trailing blanks, or a wrapped array line do not break parsing.
     */
    private static final class Tokens {
        private final BufferedReader reader;
        private StringTokenizer line = new StringTokenizer("");

        Tokens(BufferedReader reader) {
            this.reader = reader;
        }

        /**
         * Returns the next token, reading further lines as needed.
         *
         * @throws EOFException if the input ends before another token is found
         */
        String next() throws IOException {
            while (!line.hasMoreTokens()) {
                String text = reader.readLine();
                if (text == null) {
                    throw new EOFException("unexpected end of input");
                }
                line = new StringTokenizer(text);
            }
            return line.nextToken();
        }

        int nextInt() throws IOException {
            return Integer.parseInt(next());
        }

        long nextLong() throws IOException {
            return Long.parseLong(next());
        }
    }
}