Union Find
(union-find/union-find.hpp)
- View this file on GitHub
- Last update: 2025-10-23 01:57:19+09:00
- Include:
#include "union-find/union-find.hpp"
$n$ 頂点 $0$ 辺のグラフに対する以下のクエリを処理する.$\alpha$ はアッカーマン関数の逆関数.
-
find(v):頂点 $v$ を含む連結成分の代表頂点を返す.$O(\alpha(n))$ 時間. -
size(v):頂点 $v$ を含む連結成分の頂点数を返す.$O(\alpha(n))$ 時間. -
same(u, v):頂点 $u,v$ が同じ連結成分に含まれるかを返す.$O(\alpha(n))$ 時間. -
unite(u, v):辺 $uv$ を追加する.$O(\alpha(n))$ 時間.-
unite(u, v, f)で連結成分マージ時に $f$ を呼び出す.
-
-
groups():連結成分を列挙する.$O(n)$ 時間.
アルゴリズム
find および unite が実装できれば他も難しくない.
根付き木として考える. 各頂点の親を示す長さ $n$ の列 $(p_0,\dots,p_{n-1})$ を管理する. $p_i=i$ ならば頂点 $i$ は根であり,はじめ全ての $i$ に対し $p_i=i$ である.
find, unite は naive には以下のように実装できる.
def find(v):
if p[v] == v: return v
return find(p[v])
def unite(u, v):
u = find(u)
v = find(v)
p[v] = u
計算量
https://37zigen.com/union-find-complexity-1/
毎回素直に根まで辿り,unite では一方の根をもう一方の根に繋ぐ. これでは worst $\Theta(n)$ 時間である.
計算量を改善する手法がある.
path compression
path compression では find を以下のようにする.
def find(v):
if p[v] == v: return v
return p[v] = find(p[v])
一度再帰で通った部分を圧縮している.
このとき各操作がならし $O(\log n)$ 時間になる.
union by size (rank)
union by size では unite を以下のようにする.
def unite(u, v):
u = find(u)
v = find(v)
if size(u) < size(v): swap(u, v)
p[v] = u
部分木のサイズが小さい方を大きい方の子とする(いわゆるマージテク).
このとき各操作が $O(\log n)$ 時間になる.
各頂点の rank を (path compression しない状態での) その頂点を根とする部分木の高さとする. union by rank は union by size で size を用いている部分を rank に置き換えたもの.
組み合わせ
path compression と union by size (rank) を組み合わせるとさらにオーダーが改善し,ならし $O(\alpha(n))$ 時間になることが知られている.
ここで $\alpha$ はアッカーマン関数の逆関数で,現実的な $n$ の範囲で $\alpha(n)\leq 5$.
参考
Verified with
Code
#pragma once
struct UnionFind {
private:
vector<int> a;
public:
UnionFind(int n) : a(n, -1) {}
int find(int x) { return a[x] < 0 ? x : a[x] = find(a[x]); }
int size(int x) { return -a[find(x)]; }
bool same(int x, int y) { return find(x) == find(y); }
bool unite(int x, int y) {
x = find(x), y = find(y);
if (x == y) return false;
if (a[x] > a[y]) swap(x, y);
a[x] += a[y];
a[y] = x;
return true;
}
template <class F>
bool unite(int x, int y, F f) {
x = find(x), y = find(y);
if (x == y) return false;
if (a[x] > a[y]) swap(x, y);
a[x] += a[y];
a[y] = x;
f(x, y);
return true;
}
vector<vector<int>> groups() {
vector<int> root(a.size()), gsize(a.size());
for (int i = 0; i < a.size(); i++) gsize[root[i] = find(i)]++;
vector<vector<int>> res(a.size());
for (int i = 0; i < res.size(); i++) res[i].reserve(gsize[i]);
for (int i = 0; i < root.size(); i++) res[root[i]].push_back(i);
res.erase(remove_if(res.begin(), res.end(), [&](const vector<int>& v) { return v.empty(); }), res.end());
return res;
}
};
/**
* @brief Union Find
* @docs docs/union-find/union-find.md
*/#line 2 "union-find/union-find.hpp"
struct UnionFind {
private:
vector<int> a;
public:
UnionFind(int n) : a(n, -1) {}
int find(int x) { return a[x] < 0 ? x : a[x] = find(a[x]); }
int size(int x) { return -a[find(x)]; }
bool same(int x, int y) { return find(x) == find(y); }
bool unite(int x, int y) {
x = find(x), y = find(y);
if (x == y) return false;
if (a[x] > a[y]) swap(x, y);
a[x] += a[y];
a[y] = x;
return true;
}
template <class F>
bool unite(int x, int y, F f) {
x = find(x), y = find(y);
if (x == y) return false;
if (a[x] > a[y]) swap(x, y);
a[x] += a[y];
a[y] = x;
f(x, y);
return true;
}
vector<vector<int>> groups() {
vector<int> root(a.size()), gsize(a.size());
for (int i = 0; i < a.size(); i++) gsize[root[i] = find(i)]++;
vector<vector<int>> res(a.size());
for (int i = 0; i < res.size(); i++) res[i].reserve(gsize[i]);
for (int i = 0; i < root.size(); i++) res[root[i]].push_back(i);
res.erase(remove_if(res.begin(), res.end(), [&](const vector<int>& v) { return v.empty(); }), res.end());
return res;
}
};
/**
* @brief Union Find
* @docs docs/union-find/union-find.md
*/