// CSES - Shared code. Link to this code:
// https://cses.fi/paste/01ea1fa1bcc55dd43c9d4f/
#include <algorithm>
#include <fstream>
#include <iostream>
#include <set>
#include <utility>
#include <vector>
using namespace std;
// Writes every element of v to os, each followed by a single space (so a
// non-empty vector leaves a trailing space), and returns os for chaining.
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v) {
    for (typename std::vector<T>::size_type i = 0; i < v.size(); ++i) {
        os << v[i] << " ";
    }
    return os;
}
int n;                                // number of vertices read from input
std::vector<std::vector<int>> graph;  // adjacency lists over 0-indexed vertices
std::vector<int> color;               // color[v]: color of vertex v
std::vector<int> distinct;            // distinct[v]: # of distinct colors in v's subtree
std::vector<std::set<int>> subtree;   // scratch color sets used by dfs()

// Fills distinct[v] for every vertex v in the subtree rooted at u, where the
// tree is entered from `parent` (pass -1 for the root). distinct[v] is the
// number of different values of color[] among v and all of its descendants.
//
// Small-to-large merging: each vertex takes over its largest child's set with
// an O(1) swap and inserts the smaller children's colors into it, so each color
// is copied O(log n) times in total.
//
// Written iteratively because a recursive DFS needs one stack frame per level
// and overflows the call stack on deep (path-shaped) trees.
void dfs(int u, int parent) {
    // Pre-order walk recording (vertex, parent) pairs. A child is always
    // recorded after its parent, so scanning the list backwards finishes every
    // child before the vertex that owns it.
    std::vector<std::pair<int, int>> order;
    std::vector<std::pair<int, int>> pending(1, std::make_pair(u, parent));
    while (!pending.empty()) {
        const std::pair<int, int> cur = pending.back();
        pending.pop_back();
        order.push_back(cur);
        for (int w : graph[cur.first]) {
            if (w != cur.second)
                pending.push_back(std::make_pair(w, cur.first));
        }
    }

    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int v = it->first;
        const int p = it->second;

        // Child with the biggest color set (first one wins ties).
        int largest = -1;
        for (int w : graph[v]) {
            if (w != p && (largest == -1 || subtree[largest].size() < subtree[w].size()))
                largest = w;
        }
        if (largest != -1)
            subtree[v].swap(subtree[largest]);
        subtree[v].insert(color[v]);
        for (int w : graph[v]) {
            if (w == p || w == largest)
                continue;
            subtree[v].insert(subtree[w].begin(), subtree[w].end());
            subtree[w].clear();  // merged; free its nodes instead of keeping them alive
        }
        distinct[v] = static_cast<int>(subtree[v].size());
    }
}
// Reads a tree from stdin -- n, then the n vertex colors, then n-1 edges given
// as 1-indexed vertex pairs -- roots it at vertex 1, and prints, for every
// vertex in index order, the number of distinct colors in its subtree.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
#ifndef ONLINE_JUDGE
    // Local debugging only: when input13.txt can be opened, read from it and
    // write to out13.txt instead of stdin/stdout.
    ifstream in("input13.txt");
    ofstream out;
    // Puts the original cin/cout buffers back on every exit path. Declared
    // after `in`/`out`, so it is destroyed before them. Without it, cout would
    // still point at out's destroyed filebuf when the runtime flushes cout at
    // program exit.
    struct RestoreStreams {
        streambuf* const cin_buf = cin.rdbuf();
        streambuf* const cout_buf = cout.rdbuf();
        ~RestoreStreams() {
            cin.rdbuf(cin_buf);
            cout.rdbuf(cout_buf);
        }
    } restore_streams;
    if (in) {
        out.open("out13.txt");
        cin.rdbuf(in.rdbuf());
        cout.rdbuf(out.rdbuf());
    }
#endif
    // Empty or malformed input: there is no vertex 0 to root the DFS at.
    if (!(cin >> n) || n <= 0)
        return 0;
    graph.resize(n);
    color.resize(n);
    distinct.resize(n);
    subtree.resize(n);
    for (int i = 0; i < n; i++)
        cin >> color[i];
    for (int i = 0; i < n - 1; i++) {
        int u = 0, v = 0;
        // Reject unreadable or out-of-range endpoints instead of indexing
        // graph out of bounds.
        if (!(cin >> u >> v) || u < 1 || u > n || v < 1 || v > n)
            return 1;
        u--; v--;  // convert to 0-indexed
        graph[u].push_back(v);
        graph[v].push_back(u);
    }
    dfs(0, -1);
    cout << distinct << '\n';
    return 0;
}