We will use "node" interchangeably with "snowball." Let's start by representing the tree as an array. First, run a preorder traversal in O(N) time. Let st[x] denote the index (starting from one) of node x in the traversal, and let en[x] denote the maximum index of any node in the subtree of x. Then the subtree of x corresponds exactly to the nodes whose indices lie in the range [st[x],en[x]].
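A minimal sketch of the traversal, using a hypothetical EulerTour helper (not part of the solution code) on a tree given as a 1-indexed adjacency list rooted at node 1. It uses a recursive std::function, so the recursion depth equals the tree height; very deep trees may need an iterative version.

```cpp
#include <cassert>
#include <functional>
#include <vector>
using namespace std;

// Hypothetical helper: preorder indices st[] and en[] (1-indexed) for a tree
// rooted at node 1, given an adjacency list indexed 1..n (index 0 unused).
struct EulerTour {
    vector<int> st, en;
    explicit EulerTour(const vector<vector<int>>& adj) : st(adj.size()), en(adj.size()) {
        int timer = 0;
        function<void(int, int)> dfs = [&](int x, int parent) {
            st[x] = ++timer;                       // position of x in the preorder
            for (int t : adj[x]) if (t != parent) dfs(t, x);
            en[x] = timer;                         // largest position inside x's subtree
        };
        dfs(1, 0);
    }
    // y is in the subtree of x iff st[y] lies in [st[x], en[x]]
    bool inSubtree(int x, int y) const { return st[x] <= st[y] && st[y] <= en[x]; }
};
```

For the tree with edges 1-2, 1-3, 2-4, 2-5 (children visited in input order), the preorder is 1, 2, 4, 5, 3, so node 2 gets the range [2,4], which covers exactly nodes 2, 4, and 5.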
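For a fixed color c, call a node "special" if it is colored c and its parent is not colored c (the root is special whenever it is colored c). For any node x, let sub[x] denote the number of nodes in the subtree of x. Since every update colors an entire subtree, the number of nodes in the subtree of x that are colored c is given by one of the following:
- sub[x], if x or one of its ancestors is special;
- otherwise, the sum of sub[y] over all special nodes y in the subtree of x.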
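The two-case rule can be checked by brute force on a small tree. The sketch below uses a hypothetical FixedColor struct that handles a single color; it assumes the root is node 1, par[1] = 0, and par[v] < v for every other node, and that the nodes colored c form a union of whole subtrees, which is what makes the rule valid.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Hypothetical sketch for ONE fixed color c. Nodes are 1..n, root is 1,
// par[1] = 0, and par[v] < v for v >= 2 (children are numbered after parents).
// colored[v] says whether v has color c; this set is assumed closed under
// taking descendants, as it is when every update paints a whole subtree.
struct FixedColor {
    vector<int> par, sub;
    vector<bool> colored;

    FixedColor(const vector<int>& p, const vector<bool>& col)
        : par(p), sub(p.size(), 1), colored(col) {
        // Descending order visits every child before its parent.
        for (int v = (int)p.size() - 1; v >= 2; --v) sub[par[v]] += sub[v];
    }
    bool special(int v) const {
        return colored[v] && (par[v] == 0 || !colored[par[v]]);
    }
    bool inSubtree(int x, int y) const {       // is y in the subtree of x?
        for (; y != 0; y = par[y]) if (y == x) return true;
        return false;
    }
    int brute(int x) const {                   // direct count of colored nodes
        int cnt = 0;
        for (int y = 1; y < (int)par.size(); ++y) cnt += inSubtree(x, y) && colored[y];
        return cnt;
    }
    int byRule(int x) const {                  // the two-case rule
        for (int a = x; a != 0; a = par[a])
            if (special(a)) return sub[x];     // x or an ancestor is special
        int total = 0;
        for (int y = 1; y < (int)par.size(); ++y)
            if (special(y) && inSubtree(x, y)) total += sub[y];
        return total;
    }
};
```

For instance, with nodes 2, 3, 4, 5 colored in the tree where 2 and 3 are children of 1 and 4, 5 are children of 2, the special nodes are 2 and 3, and both methods give 4 for x = 1.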
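Summing over all colors c, we can rewrite the answer for a query on the subtree of x as the sum of two separate parts: sub[x] times the number of special nodes that are ancestors of x or x itself, plus the sum of sub[y] over all special nodes y strictly inside the subtree of x. A node that is special for several colors is counted once per color.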
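Written out, with S_c denoting the set of special nodes of color c (notation introduced here for the derivation), the decomposition is:

$$
\text{ans}(x) = \sum_{c}\Bigg(\text{sub}[x]\cdot\Big|\{\,y\in S_c : st[y]\le st[x]\le en[y]\,\}\Big| \;+\; \sum_{\substack{y\in S_c\\ st[x]<st[y]\le en[x]}} \text{sub}[y]\Bigg)
$$

To see why this is correct, fix a color c. At most one node of S_c contains x, because the subtree of a special node is entirely colored c, so no node strictly inside it can have a parent that is not colored c. If one exists, the first term contributes sub[x] for c and, by the same argument, no node of S_c lies strictly below x, so the second term contributes nothing for c. If none exists, x itself is not special, and the second term sums sub[y] over exactly the special nodes of color c in the subtree of x. Either way, the contribution of c equals the number of nodes in the subtree of x colored c.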
Part 1: counting the special nodes that are ancestors of x or x itself
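Whenever a node y becomes special, use a binary indexed tree (BIT) to add 1 to every index in the range [st[y],en[y]], and add -1 to the same range when y stops being special. Then this count is a point query at st[x], since it equals the number of special nodes y with st[y] ≤ st[x] ≤ en[y], that is, the special nodes whose subtree contains x.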
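A range-add, point-query BIT stores a difference array: adding v at st[y] and -v at en[y]+1 makes every prefix sum over [st[y],en[y]] grow by v. The sketch below, with a hypothetical RangeAddBIT type, mirrors what the upd function in the solution does to A, but uses a std::vector instead of a fixed-size array.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Hypothetical Fenwick tree over a difference array, 1-based indices 1..n.
// rangeAdd(l, r, v) adds v to every position in [l, r];
// pointQuery(i) returns the current value at position i.
struct RangeAddBIT {
    vector<long long> bit;
    explicit RangeAddBIT(int n) : bit(n + 2, 0) {}   // room for index n + 1
    void add(int pos, long long v) {
        for (; pos < (int)bit.size(); pos += pos & -pos) bit[pos] += v;
    }
    void rangeAdd(int l, int r, long long v) { add(l, v); add(r + 1, -v); }
    long long pointQuery(int i) const {              // prefix sum of the differences
        long long res = 0;
        for (; i > 0; i -= i & -i) res += bit[i];
        return res;
    }
};
```

In the solution's terms, making y special is rangeAdd(st[y], en[y], 1), and Part 1 contributes sub[x] * pointQuery(st[x]).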
Part 2: summing the subtree sizes of special nodes strictly below x
Whenever a node y becomes special, use a second BIT to add sub[y] at index st[y] (and subtract it when y stops being special). Then we simply need to query the sum of all values in this BIT over the range [st[x]+1,en[x]].
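Part 2 uses the opposite BIT pattern: point updates and range-sum queries. A sketch with a hypothetical PointAddBIT type, using the same 1-based indexing:

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Hypothetical Fenwick tree with point updates and range-sum queries,
// 1-based indices 1..n.
struct PointAddBIT {
    vector<long long> bit;
    explicit PointAddBIT(int n) : bit(n + 1, 0) {}
    void add(int pos, long long v) {                 // value[pos] += v
        for (; pos < (int)bit.size(); pos += pos & -pos) bit[pos] += v;
    }
    long long prefix(int r) const {                  // value[1] + ... + value[r]
        long long s = 0;
        for (; r > 0; r -= r & -r) s += bit[r];
        return s;
    }
    long long rangeSum(int l, int r) const {         // value[l] + ... + value[r]
        return l > r ? 0 : prefix(r) - prefix(l - 1); // empty range gives 0
    }
};
```

When x is a leaf, the range [st[x]+1,en[x]] is empty, so Part 2 contributes 0 and the answer reduces to sub[x] times the Part 1 count.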
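For each color c, we keep the special nodes of color c in a map keyed by st[·]. When the subtree of x is painted c, nothing changes if some special node of color c already contains x. Otherwise, every special node of color c inside the subtree of x stops being special and is removed, and x is inserted. Each paint operation inserts at most one node, and a node can only be removed after it was inserted, so we make O(Q) updates in total to the sets and the two BITs, each taking O(log N) time. Together with the O(N) traversal, our solution runs in O(N + Q log N). My code follows.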
```cpp
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
typedef vector<int> vi;

#define FOR(i,a,b) for (int i = (a); i < (b); ++i)
#define F0R(i,a) FOR(i,0,a)
#define trav(a,x) for (auto& a: x)
#define pb push_back
#define ub upper_bound

void setIO(string name) {
    ios_base::sync_with_stdio(0); cin.tie(0);
    freopen((name+".in").c_str(),"r",stdin);
    freopen((name+".out").c_str(),"w",stdout);
}

const int MX = 100005;

// Fenwick tree (BIT) over indices 1..SZ
template<class T, int SZ> struct BIT {
    T bit[SZ+1];
    void upd(int pos, T x) { for (; pos <= SZ; pos += (pos&-pos)) bit[pos] += x; }
    T sum(int r) { T res = 0; for (; r; r -= (r&-r)) res += bit[r]; return res; }
    T query(int l, int r) { return sum(r)-sum(l-1); }
};

BIT<ll,MX> A,B;        // A: range add / point query (Part 1); B: point add / range sum (Part 2)
map<int,int> col[MX];  // col[c]: special nodes of color c, keyed by st[node]
int st[MX], en[MX], sub[MX];
int N,Q;
vi adj[MX];
int co;                // preorder counter

void dfs(int x, int y) {
    st[x] = ++co;
    trav(t,adj[x]) if (t != y) dfs(t,x);
    en[x] = co;
    sub[x] = en[x]-st[x]+1;
}

// y = 1: x becomes special; y = -1: x stops being special
void upd(int x, int y) {
    A.upd(st[x],y); A.upd(en[x]+1,-y);
    B.upd(st[x],y*sub[x]);
}

int main() {
    setIO("snowcow");
    cin >> N >> Q;
    F0R(i,N-1) {
        int a,b; cin >> a >> b;
        adj[a].pb(b), adj[b].pb(a);
    }
    dfs(1,0);
    F0R(i,Q) {
        int t; cin >> t;
        if (t == 1) {
            int x,c; cin >> x >> c;
            auto it = col[c].ub(st[x]);
            // a special node of color c already contains x: nothing changes
            if (it != begin(col[c]) && en[prev(it)->second] >= en[x]) continue;
            // special nodes of color c inside x's subtree stop being special
            while (it != end(col[c]) && en[it->second] <= en[x]) {
                upd(it->second,-1);
                col[c].erase(it++);
            }
            col[c][st[x]] = x; upd(x,1);
        } else {
            int x; cin >> x;
            cout << sub[x]*A.sum(st[x])+B.query(st[x]+1,en[x]) << "\n";
        }
    }
}
```
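The per-color bookkeeping in the type-1 branch can be isolated and tested on its own. The sketch below uses hypothetical ColorSet and PaintResult types; instead of updating the BITs directly, paint reports which nodes stopped being special and whether x became special, so a caller would apply upd(y,-1) and upd(x,1) for those. It assumes st and en come from the preorder traversal, with index 0 unused.

```cpp
#include <cassert>
#include <iterator>
#include <map>
#include <utility>
#include <vector>
using namespace std;

struct PaintResult {
    bool added;            // did x become special?
    vector<int> removed;   // nodes that stopped being special, in preorder
};

// Hypothetical per-color set of special nodes, keyed by preorder index.
struct ColorSet {
    vector<int> st, en;    // preorder intervals of nodes 1..n
    map<int, int> nodes;   // st[y] -> y for each special node y of this color

    ColorSet(vector<int> s, vector<int> e) : st(move(s)), en(move(e)) {}

    PaintResult paint(int x) {
        PaintResult res{false, {}};
        auto it = nodes.upper_bound(st[x]);
        // Special subtrees are disjoint, so only the last one starting at or
        // before st[x] can contain x.
        if (it != nodes.begin() && en[prev(it)->second] >= en[x]) return res;
        // Special nodes starting after st[x] and ending by en[x] lie in x's subtree.
        while (it != nodes.end() && en[it->second] <= en[x]) {
            res.removed.push_back(it->second);
            it = nodes.erase(it);
        }
        nodes[st[x]] = x;
        res.added = true;
        return res;
    }
};
```

Every successful paint inserts one key, and each removal deletes a key inserted earlier, which is why the total number of map and BIT operations over all queries is O(Q).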