(Analysis by Benjamin Qi)

Partial Credit: $O(N^3)$

Maintain an $N\times N$ boolean matrix storing which pairs of cows are friends, as well as the number of cows that each cow is friends with. After a cow leaves, update the matrix and the counts for all cows that are friends with that cow in $O(N^2)$ time. The answer can be derived from the counts in $O(N)$ time: a cow with $k$ friends is the middle cow of exactly $k(k-1)$ ordered triples, so the answer is the sum of $k(k-1)$ over all remaining cows.
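The simulation described above could look roughly like the following. This is a minimal sketch, not the official solution: the class name `BruteForceTriples`, the method `solve`, and the edge-list input format are assumptions made for illustration.

```java
import java.util.Arrays;

class BruteForceTriples {
    // Returns answers[i] = number of ordered triples just before cow i+1 leaves.
    // edges holds the N-1 initial friendships as 1-indexed pairs (assumed format).
    static long[] solve(int n, int[][] edges) {
        boolean[][] friends = new boolean[n + 1][n + 1];
        long[] count = new long[n + 1]; // count[b] = current number of friends of b
        for (int[] e : edges) {
            friends[e[0]][e[1]] = friends[e[1]][e[0]] = true;
            count[e[0]]++;
            count[e[1]]++;
        }
        long[] answers = new long[n];
        for (int leaving = 1; leaving <= n; leaving++) {
            // Cow b is the middle of count[b] * (count[b] - 1) ordered triples.
            long total = 0;
            for (int b = leaving; b <= n; b++) {
                total += count[b] * (count[b] - 1);
            }
            answers[leaving - 1] = total;
            // Collect the remaining friends of the leaving cow and detach it.
            int[] group = new int[n];
            int size = 0;
            for (int v = leaving + 1; v <= n; v++) {
                if (friends[leaving][v]) {
                    group[size++] = v;
                    friends[leaving][v] = friends[v][leaving] = false;
                    count[v]--;
                }
            }
            // Those friends all become friends with each other.
            for (int i = 0; i < size; i++) {
                for (int j = i + 1; j < size; j++) {
                    int u = group[i], v = group[j];
                    if (!friends[u][v]) {
                        friends[u][v] = friends[v][u] = true;
                        count[u]++;
                        count[v]++;
                    }
                }
            }
        }
        return answers;
    }

    public static void main(String[] args) {
        int[][] star = {{1, 2}, {1, 3}, {1, 4}};
        System.out.println(Arrays.toString(solve(4, star))); // prints [6, 6, 0, 0]
    }
}
```

Each departure costs $O(N^2)$ for the pairwise update, which gives the $O(N^3)$ total.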

Partial Credit: $O(N^2)$

Construct a rooted tree with $2N-1$ nodes: one for each cow and one for each pair of friends (that is, each edge) in the original tree. Color the nodes corresponding to cows black, and the nodes corresponding to pairs white. Connect each node corresponding to a pair to the nodes corresponding to the two cows in that pair. For convenience of implementation, root the new tree at the black node corresponding to cow $N$.

When a cow leaves, remove the black node corresponding to that cow and merge all of the white nodes adjacent to that cow into a single white node, so that the new graph remains a tree. Notice that two cows are friends if and only if their black nodes share a white node as a neighbor in the rooted tree. So at every point in time, the number of triples is the number of paths of five nodes $(a,b,c,d,e)$ in the rooted tree (where $b$ and $d$ are allowed to coincide) such that $a,c,e$ are black, $b,d$ are white, and all of $a,c,e$ are distinct.

For each white node $u$, define $mass[u]$ to be the number of black nodes immediately below it (equivalently, the number of neighbors of $u$ minus one). Given the new tree, we can count the total number of paths of the desired form by splitting into three cases:

  1. $b=d$. Then $a,c,e$ can be any three distinct neighbors of $b$. The number of paths of this form is equal to
    $$\sum_{b:b\text{ is white}}(mass[b]+1)\cdot mass[b]\cdot (mass[b]-1).$$
  2. $b\neq d$ and $c$ is above (closer to the root than) both $b$ and $d$. The number of paths of this form is equal to
    $$2\cdot \sum_{c:c\text{ is black}}\sum_{b<d,b\text{ and }d\text{ are children of }c}mass[b]\cdot mass[d].$$
  3. $b\neq d$ and one of $b$ and $d$ is above $c$. The number of paths of this form is equal to
    $$2\cdot \sum_{u:u\text{ is white}} mass[u]\cdot combinedCombinedMass[u],$$
    where $combinedCombinedMass[u]$ is the sum of the masses of all white nodes at distance 2 below $u$ (the white grandchildren of $u$).

We can compute the contributions of each case in $O(N)$ time given the rooted tree.
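The three cases can be evaluated in one pass over the rooted tree, as in the sketch below. The array names `whiteAbove` and `blackAbove` and the class `RootedTreeCount` are hypothetical, chosen only for this illustration. Case 2 uses the identity $2\sum_{b<d}mass[b]\cdot mass[d]=\left(\sum_b mass[b]\right)^2-\sum_b mass[b]^2$ to avoid iterating over pairs.

```java
class RootedTreeCount {
    // whiteAbove[x]: label of the white node directly above black node x
    //                (0 if x is the root or has already left). Hypothetical layout.
    // blackAbove[w]: black node directly above white node w
    //                (ignored for labels w that have no black children).
    static long countPaths(int n, int[] whiteAbove, int[] blackAbove) {
        long[] mass = new long[n + 1];
        long[] combinedMass = new long[n + 1];
        long[] combinedCombinedMass = new long[n + 1];
        long[] squares = new long[n + 1]; // sum of mass[b]^2 over white children b
        for (int x = 1; x <= n; x++) {
            if (whiteAbove[x] != 0) mass[whiteAbove[x]]++;
        }
        for (int w = 1; w <= n; w++) {
            if (mass[w] > 0) {
                combinedMass[blackAbove[w]] += mass[w];
                squares[blackAbove[w]] += mass[w] * mass[w];
            }
        }
        for (int x = 1; x <= n; x++) {
            if (whiteAbove[x] != 0) {
                combinedCombinedMass[whiteAbove[x]] += combinedMass[x];
            }
        }
        long total = 0;
        for (int v = 1; v <= n; v++) {
            // Case 1: v read as a white label.
            total += (mass[v] + 1) * mass[v] * (mass[v] - 1);
            // Case 2: v read as a black node; equals 2 * sum over pairs b < d.
            total += combinedMass[v] * combinedMass[v] - squares[v];
            // Case 3: v read as a white label.
            total += 2 * mass[v] * combinedCombinedMass[v];
        }
        return total;
    }

    public static void main(String[] args) {
        // Star 1-2, 1-3, 1-4 rooted at cow 4; white nodes are labeled 1, 2, 3.
        int[] whiteAbove = {0, 1, 2, 3, 0};
        int[] blackAbove = {0, 4, 1, 1, 0};
        System.out.println(countPaths(4, whiteAbove, blackAbove));
    }
}
```

Recomputing these sums from scratch after every departure is what gives the $O(N^2)$ bound for this subtask.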

Full Credit: $O(N\alpha(N))$

For full credit, we process each update in $O(\alpha(N)\cdot (\text{number of nodes removed from the rooted tree}))$ time (amortized). To achieve this, we additionally maintain $combinedMass[x]$ for each black node $x$, which we define to be the sum of $mass[y]$ over all children $y$ of $x$. When we remove a black node $x$ and merge all of its white neighbors into a single white node $w$, we need to:

  1. remove paths corresponding to nodes no longer in the tree
  2. compute $mass[w]$ and $combinedCombinedMass[w]$
  3. update $combinedMass[parent[find(x)]]$, where $parent[find(x)]$ is the black node immediately above $w$
  4. update $combinedCombinedMass[find(parent[find(x)])]$, where $find(parent[find(x)])$ is the white node immediately above the black node from the previous step

Note that the number of values we need to update is proportional to the number of nodes removed from the rooted tree plus a constant. Thus, the overall time complexity is $O(N\alpha(N))$. See the code below for details on the merging process.
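As a worked instance of this bookkeeping (a sketch read off from the `update` routine in the code below, not a separate derivation by the author), suppose $mass[u]$ changes by $\delta$ for a white node $u$. Let $p=parent[u]$ be the black node above $u$, let $a=find(p)$ be the white node above $p$, and let $f(m)=(m+1)\cdot m\cdot (m-1)$ be the case 1 contribution of a white node of mass $m$. With all values taken before the change, the answer changes by

$$\Delta = 2\delta\cdot (combinedMass[p]-mass[u]) + 2\delta\cdot mass[a] + 2\delta\cdot combinedCombinedMass[u] + f(mass[u]+\delta)-f(mass[u]).$$

The first term is the change in case 2 at $p$, the second is the change in case 3 with $a$ as the upper white node (it is dropped when $p$ is the root, since then $a$ does not exist), the third is the change in case 3 with $u$ as the upper white node, and the last is the change in case 1. Afterwards, $combinedMass[p]$ and $combinedCombinedMass[a]$ each change by $\delta$ as well.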

Note on Implementation: In the code below, we initially label each white node with the label of the cow in the corresponding pair that is lower in the rooted tree. When we merge two white nodes with labels $a$ and $b$, we use a disjoint set union (DSU) data structure to assign the resulting white node the label of the cow out of $a$ and $b$ that is higher in the rooted tree. Also, for ease of implementation, the code below uses only path compression in the DSU and not union-by-rank, so its actual time complexity is $O(N\log N)$, slightly worse than $O(N\alpha(N))$.
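One way to keep the labeling convention and still get the $O(N\alpha(N))$ bound is to separate the DSU representative from the reported label. The class below is a hypothetical sketch (the name `LabeledDsu` and its methods do not appear in the code that follows): it uses union by size with path halving, and stores the label to report at each representative.

```java
class LabeledDsu {
    private final int[] root;  // DSU parent pointers
    private final int[] size;  // component sizes, valid at representatives
    private final int[] label; // label[r] = label reported for representative r

    LabeledDsu(int n) {
        root = new int[n + 1];
        size = new int[n + 1];
        label = new int[n + 1];
        for (int i = 0; i <= n; i++) {
            root[i] = i;
            size[i] = 1;
            label[i] = i;
        }
    }

    private int findRoot(int u) {
        while (root[u] != u) {
            root[u] = root[root[u]]; // path halving
            u = root[u];
        }
        return u;
    }

    // Label of the merged white node that contains u.
    int find(int u) {
        return label[findRoot(u)];
    }

    // Merges the node containing `lower` with the node containing `higher`;
    // the result keeps the label of `higher` regardless of which root survives.
    void union(int lower, int higher) {
        int a = findRoot(lower);
        int b = findRoot(higher);
        if (a == b) return;
        int keep = label[b];
        if (size[a] > size[b]) { // attach the smaller component under the larger
            int t = a;
            a = b;
            b = t;
        }
        root[a] = b;
        size[b] += size[a];
        label[b] = keep;
    }

    public static void main(String[] args) {
        LabeledDsu dsu = new LabeledDsu(6);
        dsu.union(1, 5);
        dsu.union(2, 5);
        dsu.union(5, 4); // the larger component is absorbed but takes label 4
        System.out.println(dsu.find(1));
    }
}
```

With this variant, arrays such as `mass` would still be indexed by the value returned from `find`, so the rest of the solution would not need to change.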

Danny's code:

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;
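
public class TriplesOfCows {
    static int n;
    static List<Integer>[] adj;
    static int[] parent;                // parent[u]: black node directly above white node u (and parent cow of cow u)
    static int[] union;                 // DSU over white-node labels
    static long[] mass;                 // mass[u]: number of black children of white node u
    static long[] combinedMass;         // combinedMass[x]: sum of mass over white children of black node x
    static long[] combinedCombinedMass; // combinedCombinedMass[u]: sum of mass over white grandchildren of white node u

    // Roots the original tree: records parents and leaves only children in adj.
    static void dfs(int a) {
        for (int b : adj[a]) {
            parent[b] = a;
            adj[b].remove((Integer) a);
            dfs(b);
        }
    }

    // Label of the white node directly above cow u (path compression only).
    static int find(int u) {
        if (union[u] != union[union[u]]) {
            union[u] = find(union[u]);
        }
        return union[u];
    }

    static long answer = 0;

    // Changes mass[u] by delta, keeping answer and the derived sums consistent.
    static void update(int u, long delta) {
        int p = parent[u];
        // Case 2 at p: u pairs with every other white child of p.
        answer += 2L * (combinedMass[p] - mass[u]) * delta;
        combinedMass[p] += delta;
        if (p != n) {
            // Case 3 with the white node a above p as the upper white node.
            int a = find(p);
            answer += 2L * mass[a] * delta;
            combinedCombinedMass[a] += delta;
        }
        // Case 3 with u as the upper white node.
        answer += 2L * combinedCombinedMass[u] * delta;
        // Case 1 at u.
        answer -= (mass[u] + 1L) * mass[u] * (mass[u] - 1L);
        mass[u] += delta;
        answer += (mass[u] + 1L) * mass[u] * (mass[u] - 1L);
    }

    // Merges white node u into the white node a above its parent black node p.
    static void merge(int u) {
        int p = parent[u];
        int a = find(p);
        // Remove case 2 paths at p that use u.
        combinedMass[p] -= mass[u];
        answer -= 2L * combinedMass[p] * mass[u];
        // Remove case 3 paths that pass through a, p, and u.
        combinedCombinedMass[a] -= mass[u];
        answer -= 2L * mass[a] * mass[u];
        // The white grandchildren of u become white grandchildren of a.
        answer -= 2L * mass[u] * combinedCombinedMass[u];
        answer += 2L * mass[a] * combinedCombinedMass[u];
        combinedCombinedMass[a] += combinedCombinedMass[u];
        // Remove the case 1 contribution of u, then move its black children to a.
        answer -= (mass[u] + 1L) * mass[u] * (mass[u] - 1L);
        update(a, mass[u]);
        union[u] = a;
    }

    public static void main(String[] args) throws IOException {
        BufferedReader in = new BufferedReader(new InputStreamReader(System.in));
        n = Integer.parseInt(in.readLine());
        adj = new List[n + 1];
        for (int a = 1; a <= n; a++) {
            adj[a] = new ArrayList<>();
        }
        for (int m = n - 1; m > 0; m--) {
            StringTokenizer tokenizer = new StringTokenizer(in.readLine());
            int a = Integer.parseInt(tokenizer.nextToken());
            int b = Integer.parseInt(tokenizer.nextToken());
            adj[a].add(b);
            adj[b].add(a);
        }
        parent = new int[n];
        union = new int[n];
        mass = new long[n];
        combinedMass = new long[n + 1];
        combinedCombinedMass = new long[n];
        dfs(n);
        for (int a = 1; a < n; a++) {
            union[a] = a;
        }
        StringBuilder out = new StringBuilder();
        // Initially every white node a has exactly one black child (cow a).
        for (int a = 1; a < n; a++) {
            update(a, 1L);
        }
        for (int a = 1; a < n; a++) {
            out.append(answer).append('\n');
            // Cow a leaves: merge the white nodes below it into the one above it,
            // then remove cow a from that white node.
            for (int b : adj[a]) {
                merge(b);
            }
            update(find(a), -1L);
        }
        // Only cow N remains, so the final answer is 0.
        out.append(answer).append('\n');
        System.out.print(out);
    }
}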