(Analysis by Benjamin Qi)

We will use “node” interchangeably with “snowball.” Let’s start by representing the tree as an array. First, we can run a preorder traversal in O(N) time. Let st[x] denote the index (starting from one) of node x in the traversal and let en[x] denote the maximum index of any node in the subtree of x. Then the subtree of x corresponds exactly to the nodes with indices in the range [st[x],en[x]].
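As a minimal sketch of this flattening step, the snippet below computes st and en with a recursive preorder traversal. The struct name EulerTour and the helper inSubtree are hypothetical names chosen for illustration; only st, en, and the traversal itself come from the analysis.

```cpp
#include <vector>

// Hypothetical helper (not part of the original solution): flattens a tree
// rooted at some node into preorder indices st[] and subtree ends en[].
struct EulerTour {
    std::vector<std::vector<int>> adj;
    std::vector<int> st, en;
    int timer = 0;

    explicit EulerTour(int n) : adj(n + 1), st(n + 1, 0), en(n + 1, 0) {}

    void addEdge(int a, int b) {
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    void dfs(int x, int parent) {
        st[x] = ++timer;              // 1-based preorder index of x
        for (int t : adj[x])
            if (t != parent) dfs(t, x);
        en[x] = timer;                // largest index inside x's subtree
    }

    // y lies in the subtree of x exactly when st[x] <= st[y] <= en[x].
    bool inSubtree(int x, int y) const {
        return st[x] <= st[y] && st[y] <= en[x];
    }
};
```

For a tree with edges 1-2, 1-3, 2-4, 2-5 rooted at node 1, this assigns node 2 the range [2,4], which covers exactly nodes 2, 4, and 5.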

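For a fixed color c, call a node “special” if it is colored c and its parent is not colored c. For any node x, let sub[x] denote the number of nodes in the subtree of x. Then the number of nodes in its subtree that are colored c is sub[x] if some special node lies above or equal to x (that is, an ancestor of x or x itself), and otherwise it is the sum of sub[y] over all special nodes y strictly below x.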

We can rewrite the answer to a query on the subtree of x as the sum of two separate parts:

sub[x] * (# of special nodes above or equal to x) + (sum of subtree sizes of special nodes below x),

where both parts count special nodes of all colors. Whenever a node that does not already have color c is colored c, we add it to the set of special nodes for c and remove any special nodes of that set that lie in its subtree.
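One way to maintain the set of special nodes for a single color is an ordered map keyed by preorder index, in the spirit of the author's code. The sketch below is an illustration under that assumption; the function name addSpecial and the removed output parameter are hypothetical, and st and en are assumed to be precomputed as described above.

```cpp
#include <iterator>
#include <map>
#include <vector>

// Hypothetical sketch: `special` maps st[y] -> y for the special nodes of one
// color. Returns false if x already lies inside a special subtree. Otherwise
// evicts the special nodes inside x's subtree (recording them in `removed`),
// inserts x, and returns true.
bool addSpecial(std::map<int, int>& special,
                const std::vector<int>& st, const std::vector<int>& en,
                int x, std::vector<int>& removed) {
    auto it = special.upper_bound(st[x]);  // first special node with st > st[x]
    // The predecessor has st <= st[x]; it covers x iff its range reaches en[x].
    if (it != special.begin() && en[std::prev(it)->second] >= en[x])
        return false;
    // Special nodes after st[x] whose range ends by en[x] are inside x's subtree.
    while (it != special.end() && en[it->second] <= en[x]) {
        removed.push_back(it->second);
        it = special.erase(it);
    }
    special[st[x]] = x;
    return true;
}
```

Each node enters the map at most once per coloring operation and can be evicted only after it was inserted, so the total number of insertions and evictions across all operations is O(Q).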

Part 1: getting (# of special nodes above or equal to x)

Whenever we add a special node y, use a binary indexed tree (BIT) to add 1 to all indices in the range [st[y],en[y]]; when a special node is removed, subtract 1 from the same range. Then evaluating this quantity for x is equivalent to making a point query at st[x].
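A range update with a point query can be done on a BIT by storing a difference array: add v at l, subtract v at r+1, and read a point as a prefix sum. The sketch below assumes that standard trick; the struct name RangeAddBIT and its method names are hypothetical.

```cpp
#include <vector>

// Hypothetical sketch: Fenwick tree over a difference array, supporting
// "add v to every index in [l, r]" and "read the value at one index".
struct RangeAddBIT {
    int n;
    std::vector<long long> bit;  // indices 1..n+1 are used

    explicit RangeAddBIT(int n) : n(n), bit(n + 2, 0) {}

    void add(int pos, long long v) {
        for (; pos <= n + 1; pos += pos & -pos) bit[pos] += v;
    }

    // Adding v on [l, r] is +v at l and -v at r+1 in the difference array.
    void rangeAdd(int l, int r, long long v) {
        add(l, v);
        add(r + 1, -v);
    }

    // The value at pos is the prefix sum of the difference array up to pos.
    long long pointQuery(int pos) const {
        long long res = 0;
        for (; pos > 0; pos -= pos & -pos) res += bit[pos];
        return res;
    }
};
```

With this structure, adding a special node y would be rangeAdd(st[y], en[y], 1), and the count of special nodes above or equal to x would be pointQuery(st[x]).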

Part 2: getting (subtree sizes of special nodes below x)

Whenever we add a special node y, use a BIT to add sub[y] to the index st[y]. Then we simply need to query the sum of all values in the BIT in the range [st[x]+1,en[x]].
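The second part needs only the textbook point-update, range-sum BIT. The following is a minimal sketch with hypothetical names (PointAddBIT, rangeSum); it treats an empty range as zero, which is the case for a leaf x where st[x]+1 exceeds en[x].

```cpp
#include <vector>

// Hypothetical sketch: Fenwick tree supporting "add v at one index" and
// "sum of the values on [l, r]".
struct PointAddBIT {
    int n;
    std::vector<long long> bit;  // indices 1..n are used

    explicit PointAddBIT(int n) : n(n), bit(n + 1, 0) {}

    void add(int pos, long long v) {
        for (; pos <= n; pos += pos & -pos) bit[pos] += v;
    }

    long long prefix(int r) const {
        long long res = 0;
        for (; r > 0; r -= r & -r) res += bit[r];
        return res;
    }

    // An empty range (l > r) sums to zero.
    long long rangeSum(int l, int r) const {
        return l > r ? 0 : prefix(r) - prefix(l - 1);
    }
};
```

Adding a special node y would then be add(st[y], sub[y]), removing it would be add(st[y], -sub[y]), and the second part of the answer for x would be rangeSum(st[x] + 1, en[x]).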

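Each coloring operation adds at most one special node, and a node can be removed only after it has been added, so we make O(Q) updates in total to the sets and the two BITs. Our solution therefore runs in O(N + Q log N). My code follows.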

#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
typedef vector<int> vi; 
 
#define FOR(i,a,b) for (int i = (a); i < (b); ++i)
#define F0R(i,a) FOR(i,0,a)
#define ROF(i,a,b) for (int i = (b)-1; i >= (a); --i)
#define R0F(i,a) ROF(i,0,a)
#define trav(a,x) for (auto& a: x)
 
#define pb push_back
#define ub upper_bound
#define s second
 
void setIO(string name) {
	ios_base::sync_with_stdio(0); cin.tie(0);
	freopen((name+".in").c_str(),"r",stdin);
	freopen((name+".out").c_str(),"w",stdout);
}

const int MX = 100005;

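// 1-indexed binary indexed tree: point update, prefix/range sum.
template<class T, int SZ> struct BIT {
	T bit[SZ+1];
	void upd(int pos, T x) { // add x at index pos
		for (; pos <= SZ; pos += (pos&-pos)) 
			bit[pos] += x;
	}
	T sum(int r) { // sum over indices [1,r]
		T res = 0; for (; r; r -= (r&-r)) 
			res += bit[r];
		return res;
	}
	T query(int l, int r) { // sum over indices [l,r]
		return sum(r)-sum(l-1); 
	}	
};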
 
BIT<ll,MX> A,B;
map<int,int> col[MX];
int st[MX], en[MX],sub[MX];
int N,Q;
vi adj[MX];
int co;
 
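// preorder traversal: x is the current node, y is its parent
void dfs(int x, int y) {
	st[x] = ++co;
	trav(t,adj[x]) if (t != y) dfs(t,x);
	en[x] = co;
	sub[x] = en[x]-st[x]+1;
}
 
// add (y = 1) or remove (y = -1) special node x
void upd(int x, int y) {
	// A stores a difference array, so A.sum(p) is a point query at p
	A.upd(st[x],y); A.upd(en[x]+1,-y);
	// B stores sub[x] at index st[x]
	B.upd(st[x],y*sub[x]);
}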
 
int main() {
	setIO("snowcow");
	cin >> N >> Q;
	F0R(i,N-1) {
		int a,b; cin >> a >> b;
		adj[a].pb(b), adj[b].pb(a);
	}
	dfs(1,0);
	F0R(i,Q) {
		int t; cin >> t;
		if (t == 1) {
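			int x,c; cin >> x >> c;
			// col[c] maps st[y] -> y for each special node y of color c
			auto it = col[c].ub(st[x]);
			// x is already covered by a special node above or equal to it
			if (it != begin(col[c]) && en[prev(it)->s] >= en[x]) continue;
			// remove special nodes in the subtree of x
			while (it != end(col[c]) && en[it->s] <= en[x]) {
				upd(it->s,-1);
				col[c].erase(it++);
			}
			col[c][st[x]] = x; upd(x,1);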
		} else {
			int x; cin >> x;
			cout << sub[x]*A.sum(st[x])+B.query(st[x]+1,en[x]) << "\n";
		}
	}
}