Let $\operatorname{med}(a,b,c)$ denote the median of $a$, $b$, and $c$.
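The solution code below computes this median with an XOR identity instead of sorting. Here is a minimal sketch of why it works: $x \oplus y \oplus z$ contains the minimum, middle, and maximum once each, and XOR-ing the minimum and maximum in again cancels them. `medBySorting` and `medAgreesOnSmallRange` are hypothetical helpers added only to check the identity against a sort-based median.

```cpp
#include <algorithm>
#include <cassert>

// Median of three via XOR, as in the solution code: x ^ y ^ z holds the
// minimum, middle, and maximum once each; XOR-ing the minimum and maximum
// in again cancels them and leaves the middle value.
int med(int x, int y, int z) {
    return x ^ y ^ z ^ std::min({x, y, z}) ^ std::max({x, y, z});
}

// Hypothetical reference: sort the three values and take the middle one.
int medBySorting(int x, int y, int z) {
    int v[3] = {x, y, z};
    std::sort(v, v + 3);
    return v[1];
}

// Hypothetical check: compare both versions on every triple in [lo, hi]^3.
bool medAgreesOnSmallRange(int lo, int hi) {
    for (int x = lo; x <= hi; ++x)
        for (int y = lo; y <= hi; ++y)
            for (int z = lo; z <= hi; ++z)
                if (med(x, y, z) != medBySorting(x, y, z)) return false;
    return true;
}
```

The identity also holds with repeated values, for example $\operatorname{med}(7,7,1) = 7$.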
Even though we only care about the median approximation of the root, we can also define a median approximation for every node: the value that ends up on the node after it is replaced by the median of its own value and the median approximations of its two children. A node with zero children keeps its own value.
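To illustrate the definition, here is a sketch that computes every node's median approximation bottom-up. It assumes the same 0-indexed heap layout as the code below (the children of $p$ are $2p+1$ and $2p+2$) and that every internal node has exactly two children; `medianApprox` is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

int med(int x, int y, int z) {
    return x ^ y ^ z ^ std::min({x, y, z}) ^ std::max({x, y, z});
}

// Median approximation of every node of a 0-indexed heap. Assumes every
// internal node has two children. Children have larger indices than their
// parent, so a reverse sweep sees both children before the parent.
std::vector<int> medianApprox(const std::vector<int>& a) {
    int n = (int)a.size();
    std::vector<int> approx(n);
    for (int p = n - 1; p >= 0; --p) {
        int l = 2 * p + 1, r = 2 * p + 2;
        if (l >= n) approx[p] = a[p];                          // leaf keeps its value
        else approx[p] = med(a[p], approx[l], approx[r]);      // internal node
    }
    return approx;
}
```

For example, with values `{4, 8, 2, 1, 3, 7, 6}` node 1 becomes $\operatorname{med}(8,1,3)=3$, node 2 becomes $\operatorname{med}(2,7,6)=6$, and the root becomes $\operatorname{med}(4,3,6)=4$.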
Let $\text{cost}_i(x)$ be the cost to change node $i$ to value $x$: $0$ if $x = a_i$, and $c_i$ otherwise.
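Let's first coordinate-compress the initial values and query values into $O(n+q)$ distinct values. Let $dp[i][j]$ be the minimum cost to make the median approximation of node $i$ equal to (compressed) value $j$. Our answer for a given query $m$ is $dp[1][m]$ (nodes are 1-indexed here; the code below is 0-indexed, so its root is `dp[0]`).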
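One way to do the compression is sketched below with the hypothetical helpers `compress` and `idOf`. The solution code uses `sort` + `unique` and then a `std::map` from value to index; this sketch looks the index up with `std::lower_bound` on the sorted vector instead, which yields the same index. Because compression preserves order, the median of three compressed indices is the index of the median value.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Collect initial and query values, sort them, and drop duplicates.
// The position of a value in the result is its compressed id.
std::vector<int> compress(const std::vector<int>& a, const std::vector<int>& queries) {
    std::vector<int> pos(a);
    pos.insert(pos.end(), queries.begin(), queries.end());
    std::sort(pos.begin(), pos.end());
    pos.erase(std::unique(pos.begin(), pos.end()), pos.end());
    return pos;
}

// Compressed id of v; assumes v is present in pos.
int idOf(const std::vector<int>& pos, int v) {
    return int(std::lower_bound(pos.begin(), pos.end(), v) - pos.begin());
}
```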
Our base cases are the leaves (nodes with zero children), whose median approximation is just their own value:
$$dp[i][j] = \text{cost}_i(j)$$
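Our transitions are for nodes with two children, $2i$ and $2i+1$:
$$dp[i][j] = \min_{\operatorname{med}(x,y,z) = j} \left( \text{cost}_i(x) + dp[2i][y] + dp[2i+1][z] \right)$$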
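A sketch of a single transition, written as a standalone function over one node's cost row and its children's dp rows (compressed values $0..V-1$). The names `Row`, `costRow`, and `combine` are hypothetical; the loop body is the same relaxation the solution code performs.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

using Row = std::vector<long long>;
const long long INF = 1000000000000000LL;  // same 1e15 sentinel as the solution code

int med(int x, int y, int z) {
    return x ^ y ^ z ^ std::min({x, y, z}) ^ std::max({x, y, z});
}

// Cost row of a node with compressed value ai and change cost c:
// keeping ai is free, every other compressed value costs c.
// For a leaf this row is also its dp row (the base case).
Row costRow(int V, int ai, long long c) {
    Row row(V, c);
    row[ai] = 0;
    return row;
}

// One transition: the node takes value x, its children have approximations
// y and z, so its own approximation is med(x, y, z). Relax that entry.
Row combine(const Row& cost, const Row& left, const Row& right) {
    int V = (int)cost.size();
    Row res(V, INF);
    for (int x = 0; x < V; ++x)
        for (int y = 0; y < V; ++y)
            for (int z = 0; z < V; ++z) {
                int j = med(x, y, z);
                res[j] = std::min(res[j], cost[x] + left[y] + right[z]);
            }
    return res;
}
```

For a root with compressed value 0 and cost 5 over leaves with values 1 (cost 2) and 2 (cost 3), keeping everything gives $\operatorname{med}(0,1,2)=1$ for free, while reaching 0 or 2 costs 2 (change the left leaf to 0 or to 2, respectively).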
There are $O(n(n+q))$ states, which is $O((n+q)^2)$, and each node's transition enumerates $O((n+q)^3)$ triples $(x, y, z)$, for a total complexity of $O(n(n+q)^3)$, which is $O((n+q)^4)$. This should pass for $n, q \le 50$.
Alex Fan's C++ code:
```cpp
#include <iostream>
#include <vector>
#include <algorithm>
#include <map>
using namespace std;

#define MAXN 101

int N, Q, a[MAXN], b[MAXN];
long long c[MAXN], dp[MAXN][MAXN];
map<int, int> m;
vector<int> pos;

// Median of three via XOR: min and max cancel, leaving the middle value
long long med(int x, int y, int z) {
	return x ^ y ^ z ^ min({x, y, z}) ^ max({x, y, z});
}

int main() {
	ios_base::sync_with_stdio(0);
	cin.tie(0);
	cin >> N;
	for (int i = 0; i < N; ++i) {
		cin >> a[i] >> c[i];
		pos.push_back(a[i]);
	}
	cin >> Q;
	for (int i = 0; i < Q; ++i) {
		cin >> b[i];
		pos.push_back(b[i]);
	}
	// Coordinate compress all initial and query values
	sort(pos.begin(), pos.end());
	pos.resize(unique(pos.begin(), pos.end()) - pos.begin());
	for (int i = 0; i < pos.size(); ++i) {
		m[pos[i]] = i;
	}
	for (int i = N - 1; i >= 0; --i) {
		// Calculate for node i
		if (2 * i + 1 >= N) {
			// Leaf: keeping a[i] is free, any other value costs c[i]
			for (int j = 0; j < pos.size(); ++j)
				dp[i][j] = pos[j] == a[i] ? 0 : c[i];
			continue;
		}
		for (int j = 0; j < pos.size(); ++j)
			dp[i][j] = 1e15;
		int l = 2 * i + 1;
		int r = 2 * i + 2;
		// Assuming node i's value is j, and its children's approximations are x and y
		for (int j = 0; j < pos.size(); ++j) {
			long long cost = pos[j] == a[i] ? 0 : c[i];
			for (int x = 0; x < pos.size(); ++x) {
				for (int y = 0; y < pos.size(); ++y) {
					int median = med(x, y, j);
					dp[i][median] = min(dp[i][median], cost + dp[l][x] + dp[r][y]);
				}
			}
		}
	}
	for (int i = 0; i < Q; ++i) {
		cout << dp[0][m[b[i]]] << endl;
	}
	return 0;
}
```
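Suppose $m$ is the current queried target value, and let $p_m(x)$ be a function that "partitions" $x$ relative to $m$:
$$p_m(x) = \begin{cases} 0 & \text{if } x < m \\ 1 & \text{if } x = m \\ 2 & \text{if } x > m \end{cases}$$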
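Then, for all $x$, $y$, $z$,
$$p_m(\operatorname{med}(x,y,z)) = \operatorname{med}(p_m(x), p_m(y), p_m(z))$$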
Essentially, whether a value is less than, equal to, or greater than $m$ is preserved when taking the median. This means we don't care about the exact median approximation of each node, only how that approximation compares with $m$. We define $dp[i][j]$ to be the minimum cost to turn the median approximation of node $i$ into some $x$ with $p_m(x) = j$ (the dp array has size $n \times 3$). Our final answer is now $dp[1][p_m(m)] = dp[1][1]$.
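The identity holds because $p_m$ is non-decreasing, and the median of three commutes with any non-decreasing function. A quick exhaustive check over a small range is sketched below; `pm` and `partitionCommutesWithMedian` are hypothetical names, and the range bounds are arbitrary.

```cpp
#include <algorithm>
#include <cassert>

int med(int x, int y, int z) {
    return x ^ y ^ z ^ std::min({x, y, z}) ^ std::max({x, y, z});
}

// p_m(x): 0 if x < m, 1 if x == m, 2 if x > m.
int pm(int x, int m) { return x < m ? 0 : (x == m ? 1 : 2); }

// Check p_m(med(x, y, z)) == med(p_m(x), p_m(y), p_m(z)) for every
// m, x, y, z in [lo, hi].
bool partitionCommutesWithMedian(int lo, int hi) {
    for (int m = lo; m <= hi; ++m)
        for (int x = lo; x <= hi; ++x)
            for (int y = lo; y <= hi; ++y)
                for (int z = lo; z <= hi; ++z)
                    if (pm(med(x, y, z), m) != med(pm(x, m), pm(y, m), pm(z, m)))
                        return false;
    return true;
}
```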
Our dp remains mostly unchanged; only $\text{cost}_i$ and the range of $j$ differ.
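We redefine $\text{cost}_i$ for $j \in \{0, 1, 2\}$ as follows:
$$\text{cost}_i(j) = \begin{cases} 0 & \text{if } p_m(a_i) = j \\ c_i & \text{otherwise} \end{cases}$$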
Base case (leaf $i$):
$$dp[i][j] = \text{cost}_i(j)$$
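Transition (node $i$ with children $2i$ and $2i+1$):
$$dp[i][j] = \min_{\substack{x, y, z \in \{0,1,2\} \\ \operatorname{med}(x,y,z) = j}} \left( \text{cost}_i(x) + dp[2i][y] + dp[2i+1][z] \right)$$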
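To gain some confidence in the three-state recurrence, the sketch below implements it for a single query (`solveQuery`, 0-indexed heap, every internal node assumed to have two children) and compares it against a hypothetical exhaustive search `bruteForce` on a tiny tree. The brute force assumes each node either keeps $a_i$ for free or pays $c_i$ to become $m-1$, $m$, or $m+1$; one representative per class of $p_m$ suffices because only $p_m$ of each value affects whether the root ends at $m$.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <vector>

const long long INF = 1000000000000000LL;

int med(int x, int y, int z) {
    return x ^ y ^ z ^ std::min({x, y, z}) ^ std::max({x, y, z});
}

int pm(int x, int m) { return x < m ? 0 : (x == m ? 1 : 2); }

// Three-state DP for one query m; returns dp[root][1].
long long solveQuery(const std::vector<int>& a, const std::vector<long long>& c, int m) {
    int n = (int)a.size();
    std::vector<std::array<long long, 3>> dp(n);
    for (int p = n - 1; p >= 0; --p) {
        std::array<long long, 3> cost = {c[p], c[p], c[p]};
        cost[pm(a[p], m)] = 0;  // keeping a_p is free
        int l = 2 * p + 1, r = 2 * p + 2;
        if (l >= n) {  // leaf: base case
            dp[p] = cost;
            continue;
        }
        dp[p].fill(INF);
        for (int x = 0; x < 3; ++x)
            for (int y = 0; y < 3; ++y)
                for (int z = 0; z < 3; ++z) {
                    int j = med(x, y, z);
                    dp[p][j] = std::min(dp[p][j], cost[x] + dp[l][y] + dp[r][z]);
                }
    }
    return dp[0][1];
}

// Exhaustive search over 4^n assignments; only for tiny trees.
long long bruteForce(const std::vector<int>& a, const std::vector<long long>& c, int m) {
    int n = (int)a.size();
    int total = 1;
    for (int i = 0; i < n; ++i) total *= 4;
    long long best = INF;
    std::vector<int> v(n), approx(n);
    for (int mask = 0; mask < total; ++mask) {
        long long spent = 0;
        for (int i = 0, t = mask; i < n; ++i, t /= 4) {
            int choice = t % 4;  // 0 = keep a_i, 1..3 = m-1, m, m+1
            if (choice == 0) {
                v[i] = a[i];
            } else {
                v[i] = m - 2 + choice;
                spent += c[i];
            }
        }
        // Median approximations bottom-up, as in the definition
        for (int p = n - 1; p >= 0; --p) {
            int l = 2 * p + 1, r = 2 * p + 2;
            approx[p] = l >= n ? v[p] : med(v[p], approx[l], approx[r]);
        }
        if (approx[0] == m) best = std::min(best, spent);
    }
    return best;
}
```

On the tree `a = {4, 8, 2, 1, 3, 7, 6}` with costs `{5, 2, 7, 1, 3, 4, 2}`, the root's approximation is already 4, so the query $m = 4$ costs 0; for $m = 3$ the cheapest fix is changing node 6 to a value below 3, which costs 2.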
Now we have $O(n)$ states ($3n$) and $3^3 = 27$ transitions per node for each query, resulting in a final complexity of $O(3^3 nq)$. This should pass with $n, q \le 1000$.
Alex Fan's C++ code:
```cpp
#include <iostream>
#include <vector>
#include <algorithm>
#include <map>
using namespace std;

#define MAXN 2005

int N, Q, a[MAXN], f[MAXN];
long long c[MAXN], dp[MAXN][3], ans[MAXN];

long long med(int x, int y, int z) {
	return x ^ y ^ z ^ min({x, y, z}) ^ max({x, y, z});
}

// Recompute dp[p][0..2] from node p's cost row and its children's dp rows
void merge(int p) {
	long long cost[3] = {c[p], c[p], c[p]};
	cost[f[p]] = 0;
	if (2 * p + 1 >= N) {
		for (int i = 0; i < 3; ++i) dp[p][i] = cost[i];
		return;
	}
	int l = 2 * p + 1;
	int r = 2 * p + 2;
	dp[p][0] = dp[p][1] = dp[p][2] = 1e15;
	for (int i = 0; i < 3; ++i) {
		for (int j = 0; j < 3; ++j) {
			for (int k = 0; k < 3; ++k) {
				dp[p][med(i, j, k)] = min(dp[p][med(i, j, k)], cost[i] + dp[l][j] + dp[r][k]);
			}
		}
	}
	return;
}

int main() {
	ios_base::sync_with_stdio(0);
	cin.tie(0);
	cin >> N;
	for (int i = 0; i < N; ++i) {
		cin >> a[i] >> c[i];
	}
	cin >> Q;
	for (int i = 0; i < Q; ++i) {
		int uwu;
		cin >> uwu;
		for (int j = N - 1; j >= 0; --j) {
			// f[j] labels a[j] > m as 0 and a[j] < m as 2, the mirror image of p_m;
			// med commutes with this relabeling, so dp[0][1] is unaffected
			f[j] = uwu == a[j] ? 1 : (uwu < a[j] ? 0 : 2);
			merge(j);
		}
		cout << dp[0][1] << endl;
	}
	return 0;
}
```
Notice that our dp is determined by the $\text{cost}_i$'s, which are in turn determined by $p_m(a_i)$ for each node. Additionally, if we change a single $\text{cost}_i$, only $O(\log n)$ dp values change: only the dp rows of node $i$ and its ancestors are affected, and the tree has depth $O(\log n)$.
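A small sketch of which dp rows need recomputing after node $p$'s cost changes, using the 0-indexed heap layout of the code (the parent of $q$ is $(q-1)/2$). `rowsToRecompute` is a hypothetical name; its loop mirrors `update` in the final code.

```cpp
#include <cassert>
#include <vector>

// Rows of dp that must be recomputed after node p's cost changes:
// p itself and each of its ancestors up to the root.
std::vector<int> rowsToRecompute(int p) {
    std::vector<int> path;
    while (true) {
        path.push_back(p);
        if (p == 0) break;
        p = (p - 1) / 2;
    }
    return path;
}
```

For $N = 1000$, the deepest node, 999, has the path 999, 499, 249, 124, 61, 30, 14, 6, 2, 0, which is 10 rows, i.e. $\lfloor \log_2 1000 \rfloor + 1$.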
We can process the queries offline. Let $m_j$ be the $j$-th smallest distinct query value. Between answering $m_j$ and $m_{j+1}$, some of the $\text{cost}_i$'s may change, because $p_{m_j}(a_i)$ may differ from $p_{m_{j+1}}(a_i)$. This happens in exactly three cases:
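- If $a_i = m_j$, then $p_{m_j}(a_i) = 1$ but $p_{m_{j+1}}(a_i) = 0$.
- If $m_j < a_i < m_{j+1}$, then $p_{m_j}(a_i) = 2$ but $p_{m_{j+1}}(a_i) = 0$.
- If $a_i = m_{j+1}$, then $p_{m_j}(a_i) = 2$ but $p_{m_{j+1}}(a_i) = 1$.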
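The three cases imply that $p_m(a_i)$ never increases as $m$ sweeps upward. The sketch below, with hypothetical names `pm` and `countStateChanges`, walks the sorted distinct query values for one $a_i$ and counts how often its class changes; it assumes `ms` is sorted in ascending order.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

int pm(int x, int m) { return x < m ? 0 : (x == m ? 1 : 2); }

// Count how often p_m(ai) changes between consecutive sorted query values.
// The assert checks that the class never increases along the sweep.
int countStateChanges(int ai, const std::vector<int>& ms) {
    int changes = 0;
    for (std::size_t j = 0; j + 1 < ms.size(); ++j) {
        int before = pm(ai, ms[j]);
        int after = pm(ai, ms[j + 1]);
        assert(after <= before);
        if (after != before) ++changes;
    }
    return changes;
}
```

Note that the final code sweeps over all compressed values, which include every $a_i$, so there each node is updated exactly twice: once when the sweep reaches $a_i$ and once right after answering the queries at that value.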
We use these cases to update the $\text{cost}_i$ values. As $m$ increases, every node's $p_m(a_i)$ starts at $2$ and never increases, moving at most from $2$ to $1$ to $0$. This means each node causes at most $2$ updates throughout answering the offline queries. As each update recomputes $O(\log n)$ dp values, the total complexity of all updates is $O(n \log n)$.
Putting it all together, our final complexity is $O(q \log q + n \log n)$. This should pass for $n, q \le 2 \cdot 10^5$.
Alex Fan's C++ code:
```cpp
#include <iostream>
#include <vector>
#include <algorithm>
#include <map>
using namespace std;

#define MAXN 200005

int N, Q, a[MAXN], f[MAXN];
long long c[MAXN], dp[MAXN][3], ans[MAXN];
vector<int> pos;
map<int, vector<int>> nodes, qs;

long long med(int x, int y, int z) {
	return x ^ y ^ z ^ min({x, y, z}) ^ max({x, y, z});
}

// Recompute dp[p][0..2] from node p's cost row and its children's dp rows
void merge(int p) {
	long long cost[3] = {c[p], c[p], c[p]};
	// f[p] is p_m(a_p) with 0 and 2 swapped (0: a_p > m, 2: a_p < m);
	// med commutes with this swap, so dp[0][1] is unaffected
	cost[f[p]] = 0;
	if (2 * p + 1 >= N) {
		for (int i = 0; i < 3; ++i) dp[p][i] = cost[i];
		return;
	}
	int l = 2 * p + 1;
	int r = 2 * p + 2;
	dp[p][0] = dp[p][1] = dp[p][2] = 1e15;
	for (int i = 0; i < 3; ++i) {
		for (int j = 0; j < 3; ++j) {
			for (int k = 0; k < 3; ++k) {
				dp[p][med(i, j, k)] = min(dp[p][med(i, j, k)], cost[i] + dp[l][j] + dp[r][k]);
			}
		}
	}
	return;
}

// Set node p's state, then recompute dp for p and all of its ancestors
void update(int p, int state) {
	f[p] = state;
	while (true) {
		merge(p);
		if (!p) break;
		p = (p - 1) / 2;
	}
	return;
}

int main() {
	ios_base::sync_with_stdio(0);
	cin.tie(0);
	cin >> N;
	for (int i = 0; i < N; ++i) {
		cin >> a[i] >> c[i];
		nodes[a[i]].push_back(i);
		pos.push_back(a[i]);
	}
	cin >> Q;
	for (int i = 0; i < Q; ++i) {
		int uwu;
		cin >> uwu;
		qs[uwu].push_back(i);
		pos.push_back(uwu);
	}
	sort(pos.begin(), pos.end());
	pos.resize(unique(pos.begin(), pos.end()) - pos.begin());
	// f[] is zero-initialized: before the sweep every node is in state 0 (a[i] > m)
	for (int i = N - 1; i >= 0; --i) merge(i);
	// Sweep all compressed values in ascending order
	for (int uwu : pos) {
		// Update all the states to equality first
		reverse(nodes[uwu].begin(), nodes[uwu].end());
		for (auto owo : nodes[uwu]) {
			update(owo, 1);
		}
		// Answer the queries
		for (auto xwx : qs[uwu]) {
			ans[xwx] = dp[0][1];
		}
		// These nodes are below every later value: move them to state 2
		for (auto owo : nodes[uwu]) {
			update(owo, 2);
		}
	}
	for (int i = 0; i < Q; ++i) {
		cout << ans[i] << endl;
	}
	return 0;
}
```