DSA Crash Course
Advanced Data Structures

4.2 Union-Find / Disjoint Set


Deep Dive: Path Compression Visualized

Beginner tip: Union-Find answers one question fast: "Are A and B in the same group?" Imagine students forming study groups. union(Alice, Bob) merges their groups; find(Alice) == find(Bob) checks whether they are in the same group. Two optimizations (path compression + union by rank) bring each operation down to nearly O(1) amortized.

Before and After Path Compression

Before find(5):              After find(5):

    1                            1
    |                         / | \ \
    2                        2  3  4  5
    |
    3
    |
    4
    |
    5

 find(5): 5->4->3->2->1
 After: 5->1, 4->1, 3->1 (all point directly to root)

Path compression makes every node on the search path point directly to the root during find(). Nodes off that path are untouched. Future lookups for 3, 4, or 5 then take a single hop, i.e. O(1).
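
The diagram can be reproduced with a small sketch. It assumes a plain dict named parent (a hypothetical stand-in for the structure's internal array) that holds the chain 5 -> 4 -> 3 -> 2 -> 1, and a recursive find that rewires each visited node on the way back up.

```python
# Chain from the diagram: each node maps to its parent; the root maps to itself.
# The dict and its name are illustrative assumptions, not part of the template.
parent = {1: 1, 2: 1, 3: 2, 4: 3, 5: 4}

def find(x):
    """Return the root of x, pointing every node on the path at that root."""
    if parent[x] != x:
        parent[x] = find(parent[x])  # rewire x once the root is known
    return parent[x]

root = find(5)
print(root)    # 1
print(parent)  # {1: 1, 2: 1, 3: 1, 4: 1, 5: 1}
```

After the single find(5) call, nodes 3, 4, and 5 all sit one hop from the root, matching the "After" picture.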

Redundant Connection -- Detect Cycle

Edges: [[1,2], [1,3], [2,3]]. Which edge creates a cycle?

Process [1,2]: find(1)=1, find(2)=2. Different -> union. parent: 2->1
Process [1,3]: find(1)=1, find(3)=3. Different -> union. parent: 3->1
Process [2,3]: find(2)=1, find(3)=1. SAME ROOT! -> [2,3] is the redundant edge.

Key insight: If both endpoints of an edge already share the same root, adding this edge creates a cycle.

Why Union-Find over DFS/BFS?

Both can answer "are A and B connected?" -- but at very different costs when the queries keep coming.

Same problem, two solutions: edges of an undirected graph stream in, and after every new edge a connected(a, b) query asks whether two nodes are in the same component.

1. The DFS approach -- rebuild every query

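def connected_dfs(adj, a, b):
    # adj maps each node to a list of its neighbours.
    seen = {a}
    stack = [a]  # explicit stack: deep recursion would hit Python's recursion limit
    while stack:
        u = stack.pop()
        if u == b:
            return True
        for v in adj[u]:
            if v not in seen:
                seen.add(v)
                stack.append(v)
    return False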

Each connected call walks the graph from scratch: O(V + E) per query. With Q queries, total cost is O(Q * (V + E)). For V + E ≈ 10^5 and Q = 10^5 queries, that's ~10^10 operations -- TLE (time limit exceeded) territory.

2. The Union-Find approach -- build once, query forever

# UnionFind is the class from the Union-Find Template section.
uf = UnionFind(n)
for a, b in edges:
    uf.union(a, b)

def connected_uf(a, b):
    return uf.find(a) == uf.find(b)

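Each union and find is amortized O(α(n)) ≈ O(1) with path compression + union by rank, where α is the inverse Ackermann function (at most 4 for any n that fits in memory). Total cost: O((E + Q) * α(n)) -- effectively linear. With E = Q = 10^5, the same workload that took ~10^10 ops finishes in roughly 2 * 10^5.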

The takeaway. When the connectivity question is asked once, DFS/BFS is fine and arguably simpler. When queries arrive online (in a stream, interleaved with new edges), Union-Find wins by orders of magnitude because it never re-traverses the graph -- it just does pointer hops up a very flat tree. A problem phrased as "after each edge / merge / friend request, tell me X about the groups" is a strong Union-Find candidate, as long as edges are only ever added: the standard structure does not support removing an edge.
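
To make the online case concrete, here is a minimal sketch of edges and queries arriving interleaved. The function name answer_stream and the ("union", a, b) / ("query", a, b) tuple format are assumptions made for illustration. It also uses union by size instead of union by rank; both keep the trees shallow.

```python
def answer_stream(n, operations):
    """Process ("union", a, b) and ("query", a, b) tuples over nodes 0..n-1."""
    parent = list(range(n))
    size = [1] * n

    def find(x):
        root = x
        while parent[root] != root:   # first pass: locate the root
            root = parent[root]
        while parent[x] != root:      # second pass: compress the path
            nxt = parent[x]
            parent[x] = root
            x = nxt
        return root

    def union(a, b):
        ra, rb = find(a), find(b)
        if ra == rb:
            return
        if size[ra] < size[rb]:
            ra, rb = rb, ra
        parent[rb] = ra               # smaller tree goes under the larger one
        size[ra] += size[rb]

    answers = []
    for kind, a, b in operations:
        if kind == "union":
            union(a, b)
        else:
            answers.append(find(a) == find(b))
    return answers

ops = [("query", 0, 2), ("union", 0, 1), ("union", 1, 2),
       ("query", 0, 2), ("query", 0, 3)]
print(answer_stream(4, ops))  # [False, True, False]
```

No query triggers a graph traversal; each one is two short walks to a root.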


Before You Start -- How to Think About Union-Find

Why It Exists

Need: Dynamically track which elements are in the same group (connected component).
DFS/BFS approach: Rebuild the traversal each time = expensive for repeated queries.
Union-Find: find(x) and union(x, y) both in ~O(1) amortized.

The Mental Model

Think of it as a forest of trees.
Each element points to a parent. The ROOT of the tree = the group ID.

find(x): Follow parent pointers to the root.
union(x, y): Make one root point to the other.

Optimizations:
  Path compression: During find, make every node on the path point directly to the root.
  Union by rank: Attach the shorter tree under the taller one.
  Combined: nearly O(1) amortized per operation.

When to Use Union-Find

1. "Number of connected components" -> Union-Find or DFS
2. "Does adding this edge create a cycle?" -> Union-Find (if find(a) == find(b), it's a cycle)
3. "Merge accounts / groups dynamically" -> Union-Find
4. "Is the graph a valid tree?" -> Union-Find: exactly n-1 edges + no cycles
5. Any problem with INCREMENTAL connectivity -> Union-Find > BFS/DFS
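
As an illustration of pattern 4, the check "exactly n-1 edges and no cycles" might look like the sketch below. It assumes nodes are labelled 0..n-1; the function name is hypothetical, and union by rank is left out to keep it short.

```python
def is_valid_tree(n, edges):
    """Check whether n nodes (0..n-1) and the given undirected edges form a tree."""
    if len(edges) != n - 1:
        return False  # a tree on n nodes has exactly n - 1 edges
    parent = list(range(n))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for a, b in edges:
        ra, rb = find(a), find(b)
        if ra == rb:
            return False  # edge inside one component: it would close a cycle
        parent[rb] = ra
    return True

print(is_valid_tree(5, [[0, 1], [0, 2], [0, 3], [1, 4]]))  # True
print(is_valid_tree(5, [[0, 1], [1, 2], [2, 0], [3, 4]]))  # False
```

With exactly n-1 edges and no cycle, the graph is necessarily connected, so no separate connectivity check is needed.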

Union-Find Template

class UnionFind:
    def __init__(self, n):
        self.parent = list(range(n))  # every element starts as its own root
        self.rank = [0] * n           # upper bound on each root's tree height

    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])  # path compression
        return self.parent[x]

    def union(self, x, y):
        px, py = self.find(x), self.find(y)
        if px == py:
            return False  # already connected
        if self.rank[px] < self.rank[py]:
            px, py = py, px  # ensure px is the root with the higher rank
        self.parent[py] = px
        if self.rank[px] == self.rank[py]:
            self.rank[px] += 1
        return True
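
One common extension, sketched here as an assumption rather than as part of the template, is a running component count. The class name CountingUnionFind and the count attribute are hypothetical; the rest mirrors the template.

```python
class CountingUnionFind:
    """Union-Find with a running component count (illustrative variant)."""

    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n
        self.count = n  # every element starts as its own component

    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])  # path compression
        return self.parent[x]

    def union(self, x, y):
        px, py = self.find(x), self.find(y)
        if px == py:
            return False  # already connected: count unchanged
        if self.rank[px] < self.rank[py]:
            px, py = py, px
        self.parent[py] = px
        if self.rank[px] == self.rank[py]:
            self.rank[px] += 1
        self.count -= 1  # a successful merge removes one component
        return True

uf = CountingUnionFind(5)
for a, b in [(0, 1), (1, 2), (3, 4), (0, 2)]:
    uf.union(a, b)
print(uf.count)  # 2
```

The last edge (0, 2) joins two nodes that already share a root, so it leaves the count at 2.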

Problems

#   Problem                          Status
1   Number of Connected Components   [ ]
2   Redundant Connection             [ ]
3   Accounts Merge                   [ ]
4   Longest Consecutive Sequence     [ ]
5   Graph Valid Tree                 [ ]
6   Number of Provinces              [ ]

Notes