(Analysis by Nick Wu)

This problem asks us to count the 3-colorings of a tree (adjacent nodes must receive different colors) in which some nodes already have a fixed color.

To start, root the tree arbitrarily, so we now wish to count the number of 3-colorings of a rooted tree. Note that if we fix the color of the root node, all of the subtrees of the root node can be colored independently.
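
As a rough illustration of the rooting step, the sketch below turns an undirected adjacency list into a parent array with a breadth-first search. The class name `RootTreeSketch`, the helper `rootAt`, and the small example tree are hypothetical and not part of the original solution, which roots the tree implicitly during its recursion.

```java
import java.util.*;

class RootTreeSketch {
    // Returns parent[v] for every vertex when the tree is rooted at `root`;
    // parent[root] is -1. `adj` is an undirected adjacency list (assumed input shape).
    static int[] rootAt(List<List<Integer>> adj, int root) {
        int n = adj.size();
        int[] parent = new int[n];
        Arrays.fill(parent, -2);          // -2 marks "not visited yet"
        parent[root] = -1;
        Deque<Integer> queue = new ArrayDeque<>();
        queue.add(root);
        while (!queue.isEmpty()) {
            int v = queue.poll();
            for (int u : adj.get(v)) {
                if (parent[u] == -2) {    // first time we reach u, so v is its parent
                    parent[u] = v;
                    queue.add(u);
                }
            }
        }
        return parent;
    }

    public static void main(String[] args) {
        // Path 0 - 1 - 2 plus an extra leaf 3 attached to 1.
        int n = 4;
        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        int[][] edges = {{0, 1}, {1, 2}, {1, 3}};
        for (int[] e : edges) {
            adj.get(e[0]).add(e[1]);
            adj.get(e[1]).add(e[0]);
        }
        System.out.println(Arrays.toString(rootAt(adj, 0))); // prints [-1, 0, 1, 1]
    }
}
```

Any vertex works as the root; the choice only changes which neighbor of each vertex counts as its parent.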

This suggests a DP approach. Let $f(v, c)$ be the number of ways to color the subtree rooted at vertex $v$ when $v$ itself has color $c$. Then $f(v, c) = \displaystyle\prod_u \displaystyle\sum_{c' \neq c} f(u, c')$, where $u$ ranges over the children of $v$ and $c'$ ranges over the two colors other than $c$. A leaf has no children, so the product is empty and $f(v, c) = 1$. The answer is $\sum_c f(r, c)$ for the root $r$, reported modulo $10^9 + 7$.
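
As a small sanity check of the recurrence (an illustrative example, not taken from the problem's test data), consider a root $r$ with two leaf children and no pre-colored nodes. Each leaf has no children, so its product is empty and $f(\text{leaf}, c) = 1$ for every color $c$. Then

$$f(r, c) = \prod_u \sum_{c' \neq c} f(u, c') = (1 + 1)(1 + 1) = 4, \qquad \sum_c f(r, c) = 3 \cdot 4 = 12,$$

which agrees with a direct count: 3 choices for the root and 2 for each leaf.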

The only remaining thing to be careful of is handling nodes that are already colored: if $v$ is pre-colored with a color other than $c$, then $f(v, c) = 0$.
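
The sketch below shows one way to fold pre-colored nodes into the DP. It is not the recursive solution from this analysis but an iterative variant. It assumes the tree is given as a parent array with `parent[v] < v` for every non-root vertex (a hypothetical labeling chosen to keep the example short). The names `FixedColorDpSketch`, `count`, and `fixed` are likewise made up for illustration.

```java
class FixedColorDpSketch {
    static final long MOD = 1_000_000_007L;

    // Assumed input: parent[0] == -1 and parent[v] < v for all v >= 1.
    // fixed[v] is 0, 1, or 2 for a pre-colored node and -1 otherwise.
    static long count(int[] parent, int[] fixed) {
        int n = parent.length;
        long[][] f = new long[n][3];
        for (int v = 0; v < n; v++) {
            for (int c = 0; c < 3; c++) {
                // A pre-colored node contributes 0 ways for any other color.
                f[v][c] = (fixed[v] == -1 || fixed[v] == c) ? 1 : 0;
            }
        }
        // Children have larger labels than their parent, so by the time we
        // reach v in descending order, f[v] already includes all of v's children.
        for (int v = n - 1; v >= 1; v--) {
            int p = parent[v];
            long total = (f[v][0] + f[v][1] + f[v][2]) % MOD;
            for (int c = 0; c < 3; c++) {
                long others = (total - f[v][c] + MOD) % MOD; // sum over c' != c
                f[p][c] = f[p][c] * others % MOD;
            }
        }
        return (f[0][0] + f[0][1] + f[0][2]) % MOD;
    }

    public static void main(String[] args) {
        // Star: node 0 joined to 1, 2, 3; node 3 is pre-colored with color 2.
        int[] parent = {-1, 0, 0, 0};
        int[] fixed = {-1, -1, -1, 2};
        System.out.println(count(parent, fixed)); // prints 8
    }
}
```

In the star example, the pre-colored leaf rules out one color for the center, leaving 2 choices there and 2 for each free leaf, hence $2 \cdot 2 \cdot 2 = 8$.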

import java.io.*;
import java.util.*;
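public class barnpainting {
    static int[] color;                 // color[v] = fixed color of v (0..2), or -1 if v is free
    static LinkedList<Integer>[] edges; // adjacency lists of the tree

    @SuppressWarnings("unchecked")
    public static void main(String[] args) throws IOException {
        // USACO convention: input is read from barnpainting.in.
        BufferedReader br = new BufferedReader(new FileReader("barnpainting.in"));
        PrintWriter pw = new PrintWriter(new BufferedWriter(new FileWriter("barnpainting.out")));
        StringTokenizer st = new StringTokenizer(br.readLine());
        int n = Integer.parseInt(st.nextToken());
        int k = Integer.parseInt(st.nextToken());
        color = new int[n];
        Arrays.fill(color, -1);
        edges = new LinkedList[n];
        dp = new long[n][3];
        for(int i = 0; i < n; i++) {
            edges[i] = new LinkedList<Integer>();
            Arrays.fill(dp[i], -1); // -1 marks a state that has not been computed yet
        }
        // Read the n-1 undirected edges (vertices are converted to 0-indexed).
        for(int i = 1; i < n; i++) {
            st = new StringTokenizer(br.readLine());
            int a = Integer.parseInt(st.nextToken())-1;
            int b = Integer.parseInt(st.nextToken())-1;
            edges[a].add(b);
            edges[b].add(a);
        }
        // Read the k pre-colored vertices (colors are converted to 0..2).
        for(int i = 0; i < k; i++) {
            st = new StringTokenizer(br.readLine());
            int a = Integer.parseInt(st.nextToken())-1;
            int c = Integer.parseInt(st.nextToken())-1;
            color[a] = c;
        }
        // Root the tree at vertex 0 and sum over its three possible colors.
        long ret = solve(0, 0, -1, -1) + solve(0, 1, -1, -1) + solve(0, 2, -1, -1);
        pw.println(ret % MOD);
        pw.close();
    }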

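    // Number of ways to color the subtree of currV when currV has color currC
    // and its parent parV has color parC (parV = parC = -1 at the root).
    public static long solve(int currV, int currC, int parV, int parC) {
        // Invalid if the color clashes with the parent or with a fixed color.
        if(currC == parC || (color[currV] >= 0 && currC != color[currV])) return 0;
        if(dp[currV][currC] >= 0) {
            return dp[currV][currC];
        }
        dp[currV][currC] = 1;
        for(int out: edges[currV]) {
            if(out == parV) continue;
            // Sum over the colors the child may take; clashing colors return 0.
            long canColor = 0;
            for(int c = 0; c < 3; c++) {
                canColor += solve(out, c, currV, currC);
                canColor %= MOD;
            }
            // Children are independent, so their counts multiply.
            dp[currV][currC] *= canColor;
            dp[currV][currC] %= MOD;
        }
        return dp[currV][currC];
    }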

    static long[][] dp;
    static final int MOD = 1000000007;
}