Let's call a node where Bessie watches Cowflix active.
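For the $N \leq 5000$ subtask, we can compute the answer in $O(N)$ for each $k$ via tree DP, for an overall $\mathcal O(N^2)$ algorithm. For each node $a$, $dp(x, a)$ will be the minimum cost to handle watching Cowflix in all the necessary locations in $a$'s subtree, where:
- $x = 0$ means that $a$ is not part of any subscription's component (only possible if $a$ is not active);
- $x = 1$ means that $a$ is part of a component, whose full cost of $k$ plus $1$ for $a$ itself is charged at $a$.
These can then be computed as follows, where $b$ ranges over the children of $a$:
$$dp(0, a) = \begin{cases} \infty & \text{if } a \text{ is active} \\ \sum_{b} \min(dp(0, b), dp(1, b)) & \text{otherwise} \end{cases}$$
$$dp(1, a) = (k + 1) + \sum_{b} \min(dp(0, b), dp(1, b) - k)$$
The $-k$ term merges $b$'s component into $a$'s, so the subscription fee $k$ is paid only once. Rooting the tree at an active node $r$, the answer for a given $k$ is $dp(1, r)$.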
```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;

public class WatchingCowflix {
    static boolean[] active;
    static List<Integer>[] adj;
    static int[][][] dp;

    static void dfs(int k, int a, int parent) {
        dp[k][1][a] = k + 1;
        for (int b : adj[a]) {
            if (b != parent) {
                dfs(k, b, a);
                dp[k][0][a] += Math.min(dp[k][0][b], dp[k][1][b]);
                dp[k][1][a] += Math.min(dp[k][0][b], dp[k][1][b] - k);
            }
        }
        if (active[a]) {
            dp[k][0][a] = Integer.MAX_VALUE;
        }
    }

    public static void main(String[] args) throws IOException {
        BufferedReader in = new BufferedReader(new InputStreamReader(System.in));
        int n = Integer.parseInt(in.readLine());
        String places = in.readLine();
        active = new boolean[n + 1];
        adj = new List[n + 1];
        // choose root that is active for convenience
        int r = -1;
        for (int a = 1; a <= n; a++) {
            active[a] = places.charAt(a - 1) == '1';
            if (active[a]) {
                r = a;
            }
            adj[a] = new ArrayList<>();
        }
        for (int j = n - 1; j > 0; j--) {
            StringTokenizer tokenizer = new StringTokenizer(in.readLine());
            int a = Integer.parseInt(tokenizer.nextToken());
            int b = Integer.parseInt(tokenizer.nextToken());
            adj[a].add(b);
            adj[b].add(a);
        }
        dp = new int[n + 1][2][n + 1];
        for (int k = 1; k <= n; k++) {
            dfs(k, r, 0);
            int answer = dp[k][1][r];
            System.out.println(answer);
        }
    }
}
```
To solve for $N \leq 2 \cdot 10^5$, we can optimize the $\mathcal O(N^2)$ solution. Given a fixed $k$, notice that if any two active nodes are at distance at most $k$, then it's always beneficial to connect them into a single subscription, because we spend at most $k - 1$ dollars in connecting them but save $k$ for one less subscription.
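As a small worked check of this claim, consider the simplest case of two isolated active nodes at distance $d$ (an illustrative setup, not part of the original argument). Two separate subscriptions cost $2(k + 1)$, while a single subscription covering the path between them contains $d + 1$ nodes and costs $k + (d + 1)$:

$$k + (d + 1) < 2(k + 1) \iff d < k + 1 \iff d \leq k.$$

For instance, with $k = 4$ and $d = 4$, the joined subscription costs $4 + 5 = 9$, compared to $2 \cdot 5 = 10$ for two separate ones.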
This means that we can connect all pairs of active nodes at distance at most $k$ before running the DP. If we then treat each connected component of active nodes as a single node, at most $\frac {2N} k$ active nodes remain (assuming that there is more than one). To see this, note that every remaining active node is at a distance of at least $k + 1$ from all other active nodes, so each node in the graph is at distance $\leq \frac k 2$ from at most one active node. On the other hand, every active node has at least $\frac k 2$ other nodes within a distance of $\frac k 2$ (those on the path toward another active node). The number of nodes in the graph is therefore bounded from below by
$$N \geq (\text{number of active nodes}) \cdot \frac k 2,$$
which rearranges to the claimed bound.
This means that, if we consider rooting the tree, the number of additional nodes that "branch" to different active nodes (i.e., have multiple children with active nodes in their subtrees) is also at most $\frac {2N} k$, because each such "branching" node merges at least two active nodes into a single subtree, which can only be done $(\text{number of active nodes}) - 1$ times.
We can apply this by first repeatedly removing leaves that aren't active nodes, then compressing chains of nodes that are just paths (i.e., compressing chains of the form $a_0, a_1, \ldots, a_m$, where each interior node $a_i$ is inactive and adjacent only to $a_{i - 1}$ and $a_{i + 1}$ for $0 < i < m$, into a single edge of length $m$ between $a_0$ and $a_m$). Afterwards, the only remaining nodes are active nodes and nodes that "branch" to different active nodes. If we then also connect active nodes that are at distance at most $k$ apart, we are left with a tree of size $O(\frac N k)$ on which we can compute our tree DP in $O(\frac N k)$ time.
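The pruning and compression steps can be sketched in isolation as follows. This is a minimal illustration rather than the reference implementation: it peels leaves with a queue instead of a DFS, and the class name `PruneCompressSketch`, the method `compress`, and the `{u, v, length}` output format are all hypothetical. It assumes at least one node is active.

```java
import java.util.*;

class PruneCompressSketch {
    // Hypothetical helper: returns the compressed tree as {u, v, length} triples.
    // Nodes are 1-indexed; places.charAt(v - 1) == '1' marks node v as active.
    static List<int[]> compress(int n, int[][] edges, String places) {
        boolean[] active = new boolean[n + 1], removed = new boolean[n + 1];
        List<Map<Integer, Integer>> adj = new ArrayList<>(); // neighbor -> edge length
        for (int v = 0; v <= n; v++) adj.add(new HashMap<>());
        for (int v = 1; v <= n; v++) active[v] = places.charAt(v - 1) == '1';
        for (int[] e : edges) {
            adj.get(e[0]).put(e[1], 1);
            adj.get(e[1]).put(e[0], 1);
        }
        // Step 1: repeatedly peel leaves that are not active.
        Deque<Integer> queue = new ArrayDeque<>();
        for (int v = 1; v <= n; v++) {
            if (!active[v] && adj.get(v).size() <= 1) queue.add(v);
        }
        while (!queue.isEmpty()) {
            int v = queue.poll();
            if (removed[v]) continue;
            removed[v] = true;
            for (int u : adj.get(v).keySet()) {
                adj.get(u).remove(v);
                if (!active[u] && adj.get(u).size() <= 1) queue.add(u);
            }
            adj.get(v).clear();
        }
        // Step 2: splice out inactive nodes of degree 2, summing the edge lengths.
        for (int v = 1; v <= n; v++) {
            if (removed[v] || active[v] || adj.get(v).size() != 2) continue;
            Iterator<Map.Entry<Integer, Integer>> it = adj.get(v).entrySet().iterator();
            Map.Entry<Integer, Integer> ea = it.next(), eb = it.next();
            int a = ea.getKey(), b = eb.getKey(), len = ea.getValue() + eb.getValue();
            adj.get(a).remove(v);
            adj.get(b).remove(v);
            adj.get(a).put(b, len);
            adj.get(b).put(a, len);
            adj.get(v).clear();
            removed[v] = true;
        }
        // Collect each surviving edge once.
        List<int[]> result = new ArrayList<>();
        for (int u = 1; u <= n; u++) {
            for (Map.Entry<Integer, Integer> e : adj.get(u).entrySet()) {
                if (u < e.getKey()) result.add(new int[] {u, e.getKey(), e.getValue()});
            }
        }
        return result;
    }

    public static void main(String[] args) {
        // Path 1-2-3-4-5 with only the endpoints active.
        int[][] path = {{1, 2}, {2, 3}, {3, 4}, {4, 5}};
        for (int[] e : compress(5, path, "10001")) {
            System.out.println(e[0] + " " + e[1] + " length " + e[2]);
        }
    }
}
```

On a path whose only active nodes are its endpoints, nothing is peeled in the first step and the whole interior collapses into one weighted edge in the second.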
Therefore, we can write the following algorithm. On the initial tree, first repeatedly remove leaves that aren't active nodes, then compress paths. Following that, precompute with a DFS for each node $a$, the minimum $k$ such that $a$ is involved in a length $\leq k$ path between two active nodes. Finally, loop through $k$ from $1$ to $N$, removing the nodes that can be removed, then compute the answer for that $k$ using tree DP on the remaining nodes.
The precomputation is all $O(N)$, and the tree DPs are computed in total time $O(\sum_{k = 1}^N \frac N k) = O(N\lg N)$, yielding an $O(N\lg N)$ solution.
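The harmonic-sum bound can be checked numerically. The sketch below (the class name `HarmonicWorkSketch` and method `totalWork` are hypothetical) sums $\lfloor N / k \rfloor$ over all $k$, ignoring the constant factor from the $\frac{2N}{k}$ bound, and compares it with $N(\ln N + 1)$, which bounds $N$ times the $N$-th harmonic number from above.

```java
class HarmonicWorkSketch {
    // Sum over k = 1..n of floor(n / k): a proxy for the total compressed-tree size.
    static long totalWork(int n) {
        long total = 0;
        for (int k = 1; k <= n; k++) {
            total += n / k;
        }
        return total;
    }

    public static void main(String[] args) {
        int n = 200_000;
        long work = totalWork(n);
        long bound = (long) (n * (Math.log(n) + 1));
        System.out.println("sum of N/k      = " + work);
        System.out.println("N * (ln N + 1)  = " + bound);
    }
}
```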
```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;

public class WatchingCowflix {
    static int n;
    static boolean[] active;
    static Set<Integer>[] children;
    static boolean[] removed;
    static int[] parent;
    static int[] parentDist;
    static int[] depth;

    static void dfsRemoveInactiveLeaves(int a) {
        for (int b : children[a].toArray(new Integer[0])) {
            parent[b] = a;
            parentDist[b] = 1;
            depth[b] = depth[a] + 1;
            children[b].remove(a);
            dfsRemoveInactiveLeaves(b);
        }
        if (!active[a] && children[a].isEmpty()) {
            children[parent[a]].remove(a);
            removed[a] = true;
        }
    }

    static void compressPaths() {
        for (int a = 1; a <= n; a++) {
            if (!removed[a] && !active[a] && children[a].size() == 1) {
                int b = -1;
                for (int c : children[a]) {
                    b = c;
                }
                parent[b] = parent[a];
                parentDist[b] += parentDist[a];
                children[parent[a]].remove(a);
                children[parent[a]].add(b);
                removed[a] = true;
            }
        }
    }

    static int[] nearestActiveDist;
    static int[] secondNearestActiveDist;

    static void dfsCalculateNearestActiveDists(int a) {
        nearestActiveDist[a] = active[a] ? 0 : Integer.MAX_VALUE;
        secondNearestActiveDist[a] = nearestActiveDist[a];
        for (int b : children[a]) {
            dfsCalculateNearestActiveDists(b);
            int candidate = nearestActiveDist[b] + parentDist[b];
            if (candidate < nearestActiveDist[a]) {
                secondNearestActiveDist[a] = nearestActiveDist[a];
                nearestActiveDist[a] = candidate;
            } else if (candidate < secondNearestActiveDist[a]) {
                secondNearestActiveDist[a] = candidate;
            }
        }
    }

    static int[] whenEdgeToParentForced;

    static void dfsCalculateWhenForced(int a, int nearestActiveDistAbove) {
        whenEdgeToParentForced[a] = nearestActiveDistAbove + nearestActiveDist[a];
        for (int b : children[a]) {
            int otherNearestActiveDist =
                    nearestActiveDist[b] + parentDist[b] == nearestActiveDist[a]
                            ? secondNearestActiveDist[a]
                            : nearestActiveDist[a];
            dfsCalculateWhenForced(
                    b, Math.min(otherNearestActiveDist, nearestActiveDistAbove) + parentDist[b]);
        }
    }

    static int[] answers;

    static void computeAnswers() {
        List<Integer>[] buckets = new List[n + 2];
        for (int k = 0; k <= n + 1; k++) {
            buckets[k] = new ArrayList<>();
        }
        int[] whenNodeForced = new int[n + 1];
        int[] amtNodesForced = new int[n + 1];
        int[] amtEdgesForced = new int[n + 1];
        // deepest nodes first, so children are processed before their parents
        TreeSet<Integer> nodes = new TreeSet<>((a, b) -> {
            if (depth[a] != depth[b]) {
                return depth[b] - depth[a];
            } else {
                return a - b;
            }
        });
        for (int a = 1; a <= n; a++) {
            if (!removed[a]) {
                nodes.add(a);
                if (active[a]) {
                    whenNodeForced[a] = 0;
                } else {
                    whenNodeForced[a] = whenEdgeToParentForced[a];
                }
                int whenIrrelevant = whenEdgeToParentForced[a];
                for (int b : children[a]) {
                    whenIrrelevant = Math.max(whenIrrelevant, whenEdgeToParentForced[b]);
                    whenNodeForced[a] = Math.min(whenNodeForced[a], whenEdgeToParentForced[b]);
                }
                buckets[whenIrrelevant].add(a);
                amtNodesForced[whenNodeForced[a]]++;
                if (depth[a] != 0) {
                    amtNodesForced[whenEdgeToParentForced[a]] += parentDist[a] - 1;
                    amtEdgesForced[whenEdgeToParentForced[a]] += parentDist[a];
                }
            }
        }
        int currNodesForced = amtNodesForced[0];
        int currEdgesForced = 0;
        int[][] dp = new int[2][n + 1];
        for (int k = 1; k <= n; k++) {
            buckets[k].forEach(nodes::remove);
            currNodesForced += amtNodesForced[k];
            currEdgesForced += amtEdgesForced[k];
            int forcedComponents = currNodesForced - currEdgesForced;
            int answer = (k * forcedComponents) + currNodesForced;
            for (int a : nodes) {
                if (whenNodeForced[a] <= k) {
                    dp[0][a] = Integer.MAX_VALUE; // no possibility of not having node a
                    // don't have to add (k + 1) to dp[1][a] because already accounted for
                    // via forcedComponents and currNodesForced respectively
                } else {
                    dp[1][a] += k + 1;
                }
                if (whenEdgeToParentForced[a] <= k) {
                    answer += dp[1][a];
                } else {
                    dp[0][parent[a]] += Math.min(dp[0][a], dp[1][a]);
                    dp[1][parent[a]] += Math.min(dp[0][a],
                            Math.min(dp[1][a], dp[1][a] - k + (parentDist[a] - 1)));
                }
                dp[0][a] = 0;
                dp[1][a] = 0;
            }
            answers[k] = answer;
        }
    }

    public static void main(String[] args) throws IOException {
        BufferedReader in = new BufferedReader(new InputStreamReader(System.in));
        n = Integer.parseInt(in.readLine());
        String places = in.readLine();
        active = new boolean[n + 1];
        children = new Set[n + 1];
        // choose root that is active for convenience
        int r = -1;
        for (int a = 1; a <= n; a++) {
            active[a] = places.charAt(a - 1) == '1';
            if (active[a]) {
                r = a;
            }
            children[a] = new HashSet<>();
        }
        for (int j = n - 1; j > 0; j--) {
            StringTokenizer tokenizer = new StringTokenizer(in.readLine());
            int a = Integer.parseInt(tokenizer.nextToken());
            int b = Integer.parseInt(tokenizer.nextToken());
            children[a].add(b);
            children[b].add(a);
        }
        removed = new boolean[n + 1];
        parent = new int[n + 1];
        parentDist = new int[n + 1];
        depth = new int[n + 1];
        dfsRemoveInactiveLeaves(r);
        compressPaths();
        nearestActiveDist = new int[n + 1];
        secondNearestActiveDist = new int[n + 1];
        dfsCalculateNearestActiveDists(r);
        whenEdgeToParentForced = new int[n + 1];
        dfsCalculateWhenForced(r, 0);
        answers = new int[n + 1];
        computeAnswers();
        StringBuilder out = new StringBuilder();
        for (int k = 1; k <= n; k++) {
            out.append(answers[k]).append('\n');
        }
        System.out.print(out);
    }
}
```
Note: It was also possible to solve this problem in $O(N\sqrt N)$ time. Suppose we augment the tree DP described in the $N\le 5000$ subtask to return the minimum number of components used in an optimal solution; as $k$ increases, this number stays the same or goes down. Using the observation that the answers are concave down, it can be proven that the pseudocode below makes $O(\sqrt N)$ calls to the $O(N)$ time tree DP:
```
def divide_and_conquer(l, r):
    # compute solutions for k = l+1 ... r-1,
    # assuming solutions for k = l and k = r already computed
    if min number of components in solutions for l and r are the same:
        linearly interpolate answers between l and r
    else:
        m = (l+r) // 2
        compute {answer, min number of components} for k = m using tree DP
        divide_and_conquer(l, m)
        divide_and_conquer(m, r)
```
This was enough to receive full credit if the tree DP was properly optimized. One way to optimize the tree DP is to first relabel the tree so that the parent of each node has label less than the label of the node itself, so that the DP can be implemented using a single for loop (as opposed to DFS).
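One way the ideas in this note might fit together is sketched below. It is not the reference implementation: the class `CowflixSqrtSketch` and all of its method names are hypothetical, nodes are relabeled in BFS order from an active root (assumed to exist), and each DP value packs the pair (cost, number of components) into a single `long` so that `Math.min` compares cost first and breaks ties by fewer components. That packing is an assumption of this sketch and relies on $N < 2^{20}$.

```java
import java.util.*;

class CowflixSqrtSketch {
    static final long BIG = 1L << 20;           // packed value = cost * BIG + components
    static final long INF = Long.MAX_VALUE / 4; // "impossible" state
    static int n;
    static int[] par;      // par[i] < i after relabeling; par[0] = -1
    static boolean[] act;  // active flags under the new labels
    static long[] ans, comps;

    // Relabel nodes in BFS order from an active root, so parents get smaller labels.
    static void relabel(int nodes, int[][] edges, String places) {
        n = nodes;
        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        for (int[] e : edges) {
            adj.get(e[0] - 1).add(e[1] - 1);
            adj.get(e[1] - 1).add(e[0] - 1);
        }
        int[] order = new int[n], label = new int[n];
        Arrays.fill(label, -1);
        par = new int[n];
        act = new boolean[n];
        int root = places.indexOf('1'); // assumes at least one active node
        order[0] = root;
        label[root] = 0;
        par[0] = -1;
        for (int head = 0, tail = 1; head < tail; head++) {
            int u = order[head]; // u has new label `head`
            act[head] = places.charAt(u) == '1';
            for (int v : adj.get(u)) {
                if (label[v] < 0) {
                    label[v] = tail;
                    par[tail] = head;
                    order[tail++] = v;
                }
            }
        }
    }

    // Tree DP for one k as a single reverse loop; returns packed (cost, min components).
    static long solve(int k) {
        long[] in = new long[n], out = new long[n];
        long open = (k + 1) * BIG + 1; // node starts its own component
        long merge = k * BIG + 1;      // joining the parent's component refunds k and one component
        for (int a = n - 1; a >= 0; a--) {
            in[a] += open;
            if (act[a]) out[a] = INF;  // an active node must be covered
            if (a > 0) {
                int p = par[a];
                out[p] += Math.min(out[a], in[a]);
                in[p] += Math.min(out[a], in[a] - merge);
            }
        }
        return in[0]; // the root is active
    }

    static void eval(int k) {
        long packed = solve(k);
        ans[k] = packed / BIG;
        comps[k] = packed % BIG;
    }

    // Fills answers for l < k < r, given that l and r are already evaluated.
    static void rec(int l, int r) {
        if (r - l < 2) return;
        if (comps[l] == comps[r]) {
            for (int k = l + 1; k < r; k++) { // same optimal line: slope = component count
                comps[k] = comps[l];
                ans[k] = ans[l] + (k - l) * comps[l];
            }
        } else {
            int m = (l + r) / 2;
            eval(m);
            rec(l, m);
            rec(m, r);
        }
    }

    static long[] allAnswers(int nodes, int[][] edges, String places) {
        relabel(nodes, edges, places);
        ans = new long[n + 1];
        comps = new long[n + 1];
        eval(1);
        if (n > 1) {
            eval(n);
            rec(1, n);
        }
        return ans; // ans[k] for k = 1..n
    }

    public static void main(String[] args) {
        int[][] path = {{1, 2}, {2, 3}, {3, 4}, {4, 5}};
        long[] result = allAnswers(5, path, "10001");
        for (int k = 1; k <= 5; k++) System.out.println(result[k]);
    }
}
```

Interpolation is valid because, for a fixed number of components $c$, the best cost is a line in $k$ with slope $c$; if that same $c$ is optimal at both $l$ and $r$, the same line is optimal everywhere in between.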