(Analysis by Nick Wu)

In this problem, we're given a set of paths on a tree and want to compute the maximum number of paths some node in the tree is on.

The naive solution to this problem would be, for each path, to manually increment the number of paths each node on the path is on by one. Each of the $K$ paths can contain $O(N)$ nodes, so this takes $O(NK)$ time and is therefore too slow.

Let's first solve the easier case where the tree is actually a linked list. In this case, we can solve the problem in linear time by maintaining a prefix sum on the list - for the parent node in the list we increment a counter by 1, for the child node we decrement a counter by 1, and then when we've done this for all the paths we can iterate over the list and accumulate the prefix sums, maintaining the maximum.

Using this as inspiration, we can apply this prefix sum technique to a generic tree as follows - start by rooting the tree arbitrarily. For a path from node X to node Y, increment a "prefix" value at X and Y by 1. Decrement the prefix of the lowest common ancestor (LCA) of X and Y by 1, and decrement the prefix of the parent of the LCA by 1. Note that for a given node, the sum of the prefix values for every node in the given node's subtree is the number of paths that that node is on. We can compute the sum of the prefix values for all subtrees in linear time by processing the nodes in decreasing order of depth. Now, it remains to figure out how to compute the LCA of two nodes in sublinear time. There are several ways to do this --- for example we could use Tarjan's off-line LCA algorithm. The approach we use in this solution is as follows:

Given that the tree is rooted, let the "depth" of a node be the distance from the root to that node, and assume for simplicity that we want to find the LCA of two nodes X and Y which are the same depth. The following linear-time algorithm will find the LCA:

while X and Y are different:
  X = parent(X)
  Y = parent(Y)

If we could binary search on the number of parent calls we need to make, we could make this run in $O(\log N)$. We will simulate binary search as follows:

Let $f(N, S) = Z$ be the ancestor of $N$ such that $depth(Z) + 2^S = depth(N)$. If we root the tree, then we can compute $f(N, 0)$ directly for every node in the tree, and then $f(N, S) = f(f(N, S-1), S-1)$. Since the depth of the tree is $O(N)$, $S$ can only take on $O(\log N)$ values and therefore there are $O(N \log N)$ distinct tuples for $f$ that need to be computed.

We can now use $f$ to search for the highest ancestors of $X$ and $Y$ that are not common ancestors of both $X$ and $Y$ in $O(\log N)$ time as follows: find the largest $S$ such that $f(X, S) \neq f(Y, S)$, replace $X$ and $Y$ with $f(X, S)$ and $f(Y, S)$, and repeat this process until no such $S$ exists. Once that's the case, it must be true that either $X = Y$ or the parents of $X$ and $Y$ are identical.

Note that as we iterate on $S$, the values for $S$ must be monotonically decreasing - they clearly cannot be increasing, and if we pick the same value for $S$ twice in a row, we should have used $S+1$ instead. Therefore, if we iterate over $S$ in decreasing order, this procedure takes $O(\log N)$ time.

Therefore, after we root the tree, we can pre-compute $f$ in $O(N \log N)$ time, and then for each path, we do a single LCA query, so updating the prefix sums runs in $O(K \log N)$. The final accumulation of prefix sums by iterating on the nodes in decreasing order of depth takes $O(N)$ time, so this solution is $O((N+K) \log N)$.

Here is my code demonstrating this solution:

import java.io.*;
import java.util.*;
/**
 * Solution to "Max Flow": given a tree on N nodes (numbered 1..N) and K paths,
 * prints the largest number of paths that pass through any single node.
 *
 * <p>Each path (x, y) adds +1 at x and at y, and -1 at lca(x, y) and at the
 * parent of that LCA. The sum of these values over a node's subtree is then
 * the number of paths the node lies on. LCAs are found with binary lifting.
 */
public class maxflow {
	/** Number of binary-lifting levels stored per node (jumps of 2^0 .. 2^16). */
	private static final int LEVELS = 17;

	/** p[v] is the parent of v once the tree is rooted at node 1; -1 while unvisited. p[1] = p[0] = 0 is a sentinel. */
	static int[] p;
	/** depth[v] is the distance from the root (node 1) to v. */
	static int[] depth;
	/** anc[v][k] is the ancestor 2^k levels above v, or the sentinel 0 when that is above the root. */
	static int[][] anc;
	/** Per-node "prefix" values; after {@link #compute()} amt[v] is the number of paths through v. */
	static int[] amt;
	/** Adjacency lists of the tree. */
	static LinkedList<Integer>[] edges;
	/** Nodes in reverse BFS order, i.e. every node appears before its parent. */
	static LinkedList<Integer> revOrder;

	/**
	 * Reads the tree and the paths from maxflow.in and writes the answer to maxflow.out.
	 *
	 * @throws IOException if either file cannot be read or written
	 */
	@SuppressWarnings("unchecked") // generic array creation for the adjacency lists
	public static void main(String[] args) throws IOException {
		BufferedReader br = new BufferedReader(new FileReader("maxflow.in"));
		PrintWriter pw = new PrintWriter(new BufferedWriter(new FileWriter("maxflow.out")));
		StringTokenizer st = new StringTokenizer(br.readLine());
		int n = Integer.parseInt(st.nextToken());
		int k = Integer.parseInt(st.nextToken());
		p = new int[n+1];
		amt = new int[n+1];
		depth = new int[n+1];
		Arrays.fill(p, -1);
		// Node 0 is a sentinel parent of the root so that "parent of the LCA" always exists.
		p[0] = p[1] = 0;
		anc = new int[n+1][LEVELS];
		edges = new LinkedList[n+1];
		revOrder = new LinkedList<Integer>();
		for(int i = 0; i < edges.length; i++) {
			edges[i] = new LinkedList<Integer>();
		}
		for(int a = 1; a < n; a++) {
			st = new StringTokenizer(br.readLine());
			int x = Integer.parseInt(st.nextToken());
			int y = Integer.parseInt(st.nextToken());
			edges[x].add(y);
			edges[y].add(x);
		}
		bfs();
		genLCA();
		for(int i = 0; i < k; i++) {
			st = new StringTokenizer(br.readLine());
			int x = Integer.parseInt(st.nextToken());
			int y = Integer.parseInt(st.nextToken());
			int lca = lca(x, y);
			// +1 at both endpoints, -1 at the LCA and -1 at its parent.
			amt[x]++;
			amt[y]++;
			amt[lca]--;
			amt[p[lca]]--;
		}
		compute();
		int ret = 0;
		for(int i = 1; i <= n; i++) {
			ret = Math.max(ret, amt[i]);
		}
		pw.println(ret);
		pw.close();
		br.close();
	}

	/**
	 * Turns the per-node values in {@code amt} into subtree sums by pushing each
	 * node's total into its parent, deepest nodes first. Empties {@code revOrder}.
	 */
	public static void compute() {
		while(!revOrder.isEmpty()) {
			int curr = revOrder.removeFirst();
			amt[p[curr]] += amt[curr];
		}
	}

	/**
	 * Returns the lowest common ancestor of nodes {@code a} and {@code b}.
	 * Requires {@link #bfs()} and {@link #genLCA()} to have been run.
	 */
	public static int lca(int a, int b) {
		if(depth[a] > depth[b]) {
			return lca(b, a);
		}
		if(depth[a] < depth[b]) {
			// Lift the deeper node so that both are at the same depth.
			b = getP(b, depth[a]);
		}
		// Jump both nodes up while their 2^k-th ancestors still differ.
		for(int k = anc[0].length - 1; k > 0; k--) {
			while(anc[a][k] != anc[b][k]) {
				a = anc[a][k];
				b = anc[b][k];
			}
		}
		// Finish with single parent steps.
		while(a != b) {
			a = p[a];
			b = p[b];
		}
		return a;
	}

	/**
	 * Returns the ancestor of {@code curr} whose depth is {@code wantedD}.
	 * {@code wantedD} must not exceed {@code depth[curr]}.
	 */
	public static int getP(int curr, int wantedD) {
		for(int k = anc[0].length - 1; depth[curr] != wantedD; k--) {
			while(depth[curr] - (1<<k) >= wantedD) {
				curr = anc[curr][k];
			}
		}
		return curr;
	}

	/** Fills the binary-lifting table {@code anc} from the parent array {@code p}. */
	public static void genLCA() {
		for(int i = 1; i < p.length; i++) {
			anc[i][0] = p[i];
		}
		for(int j = 1; j < anc[0].length; j++) {
			for(int i = 1; i < p.length; i++) {
				// Two jumps of 2^(j-1) make one jump of 2^j; row 0 stays all zero (sentinel).
				anc[i][j] = anc[anc[i][j-1]][j-1];
			}
		}
	}

	/**
	 * Roots the tree at node 1 with a breadth-first search, filling {@code p} and
	 * {@code depth} and recording the nodes in reverse visiting order in {@code revOrder}.
	 */
	public static void bfs() {
		LinkedList<Integer> q = new LinkedList<Integer>();
		q.add(1);
		while(!q.isEmpty()) {
			int curr = q.removeFirst();
			// Prepending yields reverse BFS order: children come before their parents.
			revOrder.addFirst(curr);
			for(int child: edges[curr]) {
				if(p[child] == -1) {
					p[child] = curr;
					depth[child] = 1 + depth[curr];
					q.add(child);
				}
			}
		}
	}
}

Additional analysis by Jesse van Dobben: Alternatively, the problem can be solved by using a two-pass depth first search along with a one-dimensional range query data structure. (For instance, you can use a Fenwick tree. I chose to reuse the tree I constructed in the solution to the third problem, Counting Haybales.) The algorithm works as follows: first we do a simple depth-first search where we label the vertices in the order in which we encounter them (that is, by starting time). Now a subtree corresponds to an interval of starting times. This is the most important idea behind this solution: the amount of flow going into a subtree can now be calculated by range query. Note that the answer does not change if we invert the direction of some of the $K$ flows. Thus, we may orient all the flows from lower to higher starting time. In the second DFS we will count for every vertex $v$ the amount of flow going into $v$, plus the amount of flow originating at $v$. We say a flow is "active" at a certain point during the second DFS if we have already started processing its first endpoint and have not yet completed processing its second endpoint. Whenever we go into a subtree during the second DFS, the amount of flow going into that subtree can be calculated as the number of currently active flows for which the destination lies within that subtree, so we query for the time interval of that subtree. Similarly, whenever we return from a subtree, the amount of flow coming back from that subtree can be calculated as the number of currently active flows for which the source lies within that subtree, which again corresponds to a time interval. This way we can compute the total amount of flow going into the current vertex using two one-dimensional range query data structures (one to list the active flows by their source and one to list the active flows by their destination). Don't forget to add the amount of flow originating at the current vertex, and we are done. 
I have rewritten the code for the sake of readability. (For instance, in my original solution I reused the data structure I wrote for the third problem, which was already quite a mess.)

#include <fstream>
#include <vector>
#include <algorithm>

using namespace std;

// Size of the fixed per-node arrays declared in this file.
const int MAX_N = 50000;

// Input is read from maxflow.in and the answer is written to maxflow.out.
ifstream fin("maxflow.in");
ofstream fout("maxflow.out");

struct partialSumTree {
  // Data structure for changing and querying partial sums of a sequence
  // consisting of R - L elements, indexed L through R - 1.
  // Every node covers the half-open index range [L, R) and stores the sum of
  // that range; its children cover [L, half) and [half, R). A node with
  // half == L has no children. Child nodes are never freed.
  int L, R, half, sum;
  partialSumTree *left, *right;

  // Builds the tree for indices l through r - 1 with all sums equal to 0.
  partialSumTree(int l, int r) : L(l), R(r), half((L + R)/2), sum(0) {
    if (half == L) {
      left = right = NULL;
    }
    else {
      left = new partialSumTree(L, half);
      right = new partialSumTree(half, R);
    }
  }

  // Adds delta to the element at index idx; indices outside [L, R) are ignored.
  void updateValue(int idx, int delta) {
    if (idx < L || idx >= R) return;
    sum += delta;
    if (half != L) {
      (idx < half ? left : right)->updateValue(idx, delta);
    }
  }

  // get partial sum from A to B - 1
  // (an empty range, A >= B, yields 0)
  int getSum(int A, int B) {
    if (A >= R || B <= L) return 0;
    if (A <= L && B >= R) return sum;
    // Partial overlap can only happen on a node that has children.
    return left->getSum(A, B) + right->getSum(A, B);
  }
};

// Adjacency lists of the tree, indexed by node.
// NOTE(review): presumably filled in main from the input edges (0-based) — confirm.
vector<int> neighbours[MAX_N];
// startTime[v]: value of the time counter when firstPassDFS first reaches v
// (main initialises it to -1, which firstPassDFS treats as "not visited yet").
int startTime[MAX_N];
// endTime[v]: value of the time counter when firstPassDFS finishes v, so the
// nodes numbered while processing v have start times in [startTime[v], endTime[v]).
int endTime[MAX_N];

// First pass: numbers the nodes in the order in which the depth-first search
// reaches them.
//   curNode - node to visit
//   curTime - next unused start time
// Sets startTime[curNode] and endTime[curNode] and returns the next unused
// start time after curNode's subtree has been numbered. A node that already
// has a start time (the parent, seen again through its adjacency list) is
// skipped and curTime is returned unchanged.
int firstPassDFS(int curNode, int curTime) {
  if (startTime[curNode] != -1) return curTime;
  startTime[curNode] = curTime++;
  for (vector<int>::iterator it = neighbours[curNode].begin(); it != neighbours[curNode].end(); it++) {
    curTime = firstPassDFS(*it, curTime);
  }
  return endTime[curNode] = curTime;
}

// secondPassDFS reads beginPath[v] as the destinations of the paths that start
// at v, and endPath[v] as the sources of the paths that end at v.
vector<int> beginPath[MAX_N], endPath[MAX_N];
// Largest passingThrough value seen by secondPassDFS so far.
int ans = -1;
// Both trees are indexed by start time: bySource is updated at the start time
// of a path's source, byDestination at the start time of its destination.
partialSumTree *bySource, *byDestination;

// Second pass: visits the nodes in the same order as firstPassDFS and, for
// every node, counts the flows passing through it; the maximum goes into ans.
//   curNode - node to visit
//   curTime - start time of the next node in DFS order
// Returns the start time of the next node still to be visited. A neighbour
// whose start time is not curTime (the parent) is skipped.
// A flow is "active" from the moment its source starts being processed until
// its destination has been completely processed.
int secondPassDFS(int curNode, int curTime) {
  if (startTime[curNode] != curTime) return curTime;
  // Consume this node's start time, mirroring firstPassDFS.
  curTime++;
  // Active flows whose destination lies in this subtree enter through curNode.
  int passingThrough = byDestination->getSum(startTime[curNode], endTime[curNode]);
  for (vector<int>::iterator it = beginPath[curNode].begin(); it != beginPath[curNode].end(); it++) {
    // unpack all paths starting at curNode
    passingThrough++;
    bySource->updateValue(startTime[curNode], 1);
    byDestination->updateValue(startTime[*it], 1);
  }
  for (vector<int>::iterator it = neighbours[curNode].begin(); it != neighbours[curNode].end(); it++) {
    int prevTime = curTime;
    curTime = secondPassDFS(*it, curTime);
    // add all paths that were started but not stopped in the subtree rooted at *it
    // (the range is empty, and contributes 0, when *it was skipped)
    passingThrough += bySource->getSum(prevTime, curTime);
  }
  for (vector<int>::iterator it = endPath[curNode].begin(); it != endPath[curNode].end(); it++) {
    // remove all paths ending at curNode
    bySource->updateValue(startTime[*it], -1);
    byDestination->updateValue(startTime[curNode], -1);
  }
  ans = max(ans, passingThrough);
  return curTime;
}

int main() {
  int N, K;
  fin >> N >> K;
  for (int i = 1; i < N; i++) {
    int x, y;
    fin >> x >> y;
    x--; y--;
  fill(startTime, startTime + N, -1);
  fill(endTime, endTime + N, -1);
  firstPassDFS(0, 0);
  bySource = new partialSumTree(0, N);
  byDestination = new partialSumTree(0, N);
  for (int i = 0; i < K; i++) {
    int s, t;
    fin >> s >> t;
    s--; t--;
    if (startTime[s] > startTime[t]) swap(s, t);
  secondPassDFS(0, 0);
  fout << ans << endl;
  return 0;