Union Find | Disjoint Set Union Data Structure Introduction

Prereq: Depth First Search Review

Now that we have a strong grasp of recursion and Depth First Search, we can introduce Disjoint Set Union (DSU).

This data structure is motivated by the following problem. Suppose we have several sets of elements and are asked which set a given element belongs to. In addition, we want the data structure to support updates by merging two sets into one.

Our data structure must support the following operations:

  1. query for the Set ID of a given element (find operation)
  2. merge two disjoint sets into one set (union operation)

Note that the set ID is a unique identifier that lets us tell the disjoint sets apart.

Our goal is to construct a data structure that handles both the merge and query operations efficiently, ideally in close to O(1) time. A very straightforward implementation is to use a list of hashsets to store the disjoint sets; however, merging two hashsets takes O(n) time. To improve the runtime of the merge operation, we need to consider data structures that merge in O(1) time. We can merge two trees in O(1) time if there is no restriction on how many children a node has (we simply attach one root under the other), so we could employ a tree-like structure.
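To see why the hashset approach is slow to merge, here is a minimal sketch of such a baseline. The class name `NaiveDisjointSets`, its methods, and the choice of a dictionary from each element to its set object (rather than a literal list of hashsets) are assumptions made for illustration only.

```python
class NaiveDisjointSets:
    """Baseline sketch: one Python set per group, plus a dict from element to its set."""

    def __init__(self, elements):
        # every element starts in its own singleton set
        self.set_of = {e: {e} for e in elements}

    def same_set(self, a, b):
        # O(1): two elements share a group when they point at the same set object
        return self.set_of[a] is self.set_of[b]

    def merge(self, a, b):
        sa, sb = self.set_of[a], self.set_of[b]
        if sa is sb:
            return
        # O(n) in the worst case: every element of sb is copied and re-pointed
        for e in sb:
            sa.add(e)
            self.set_of[e] = sa
```

The query is cheap, but each merge touches every element of one of the two sets, which is the cost the tree representation avoids.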

We can imagine the disjoint set data structure as a series of trees, where each tree denotes a set, such that an element in a tree (set) belongs solely to that tree (set) and no other trees (sets). The following graphic illustrates this idea.

Figure: Disjoint Set Union basics

Implementation

How do we construct the tree structure to maintain disjoint sets?

We nominate a particular node to be the root of the tree, and this node acts as an identifier for all nodes within the tree. This identifier is the assigned set ID. If two nodes share the same root, they must belong to the same set; if they don't, they belong to different (disjoint) sets. We can assign a parent to a node by using a hashmap, where the key is the node and the value is its parent. Initially, every node is its own parent, as every node is in a set by itself. We can then merge two sets by setting the parent of one set's root to the root of the other set, which joins the two trees into one. To find the set ID of a node, we recursively move up the tree until we reach the root (the node whose parent is itself). The code below implements a find operation with a best case of O(1), an average case of O(log(n)) when merges arrive in random order (randomly built trees have an average depth of O(log(n))), and a worst case of O(n) for a tree of maximum depth.

Let's see an example of how merging is performed. Given 6 disjoint nodes (0 through 5), perform the operations: merge 3 1, merge 1 0, merge 5 4, merge 2 0. In our program, each merge is a call to the union function.

import java.util.HashMap;

public static class UnionFind<T> {
    // initialize the data structure that maps the node to its set ID
    private HashMap<T, T> id = new HashMap<>();

    // find the set ID of Node x
    public T find(T x) {
        // get the value associated with key x, if it's not in the map return x
        T y = id.getOrDefault(x, x);
        // check if the current node is a Set ID node
        // (use equals, not !=, so that equal keys held in different objects still match)
        if (!y.equals(x)) {
            // set the value to Set ID node of node y
            y = find(y);
        }
        return y;
    }

    // union two different sets by setting one set's root parent to the other set's root
    public void union(T x, T y) {
        id.put(find(x), find(y));
    }
}

class UnionFind:
    # initialize the data structure that maps the node to its set ID
    def __init__(self):
        self.id = {}

    # find the Set ID of Node x
    def find(self, x):
        # get the value associated with key x, if it's not in the map return x
        y = self.id.get(x, x)
        # check if the current node is a Set ID node
        if y != x:
            # set the value to Set ID node of node y
            y = self.find(y)
        return y

    # union two different sets by setting one set's root parent to the other set's root
    def union(self, x, y):
        self.id[self.find(x)] = self.find(y)

class UnionFind {
    // initialize the data structure that maps the node to its set ID
    constructor() {
        this.id = new Map();
    }

    // find the Set ID of Node x
    find(x) {
        // get the value associated with key x, if it's not in the map use x
        let y = this.id.has(x) ? this.id.get(x) : x;
        // check if the current node is a Set ID node
        if (y !== x) {
            y = this.find(y);
        }
        return y;
    }

    // union two different sets by setting one set's root parent to the other set's root
    union(x, y) {
        // Map entries are written with set(); assigning to the result of get() is a syntax error
        this.id.set(this.find(x), this.find(y));
    }
}

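#include <unordered_map>

template <class T> class UnionFind {
private:
    // maps each node to its parent; nodes missing from the map are their own parent
    std::unordered_map<T, T> id;

public:
    // find the Set ID of Node x
    T find(T x) {
        // get the value associated with key x, if it's not in the map use x
        T y = id.count(x) ? id[x] : x;
        // check if the current node is a Set ID node
        if (y != x) {
            y = find(y);
            // note: this line already applies the tree compression optimization from the next section
            id[x] = y;
        }
        return y;
    }

    // add trailing _ to avoid name collision with C++ keyword "union"
    void union_(T x, T y) {
        // look up both roots first so the result does not depend on evaluation order
        T rootX = find(x);
        T rootY = find(y);
        id[rootX] = rootY;
    }
};
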
using System.Collections.Generic;

public class UnionFind<T> {
    // maps each node to its set ID
    private Dictionary<T, T> id = new Dictionary<T, T>();

    // find the Set ID of Node x
    public T Find(T x) {
        // get the value associated with key x, if it's not in the map return x
        T y = id.GetValueOrDefault(x, x);
        // check if the current node is a Set ID node
        if (!y.Equals(x)) {
            y = Find(y);
        }
        return y;
    }

    // union two different sets by setting one set's root parent to the other set's root
    public void Union(T x, T y) {
        id[Find(x)] = Find(y);
    }
}

Disjoint Set Union is also called Union Find because of its two operations - union and find.
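As a quick check of the merge example given earlier (merge 3 1, merge 1 0, merge 5 4, merge 2 0), the sketch below replays those four operations with the basic Python implementation and records each node's set ID. The variable name `roots` is just an illustrative choice.

```python
class UnionFind:
    def __init__(self):
        self.id = {}

    def find(self, x):
        # follow parent links until we reach a node that is its own parent
        y = self.id.get(x, x)
        if y != x:
            y = self.find(y)
        return y

    def union(self, x, y):
        # attach the root of x's tree under the root of y's tree
        self.id[self.find(x)] = self.find(y)


uf = UnionFind()
for x, y in [(3, 1), (1, 0), (5, 4), (2, 0)]:
    uf.union(x, y)

# set ID of every node 0..5 after the four merges
roots = {node: uf.find(node) for node in range(6)}
print(roots)  # {0: 0, 1: 0, 2: 0, 3: 0, 4: 4, 5: 4}
```

Nodes 0 through 3 end up sharing the set ID 0, while nodes 4 and 5 share the set ID 4, so two disjoint sets remain.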

Tree Compression Optimization

Now that we have a general idea of the data structure and how it is implemented, let's introduce an optimization. Imagine a scenario where our tree is not particularly balanced. In this case, the find operation may be quite slow because the recursion is deep. If we could shorten each node's path to the root, the runtime of the next find operation would be cut down drastically. We can do this by setting the parent of each node on the path directly to the root. We retrieve the Set ID when we reach the root; then, as we return through the previous recursive calls, we set the parent of each node to the root node (the same value as the Set ID). After this restructuring, every node on the queried path points directly at the root. The graphic below demonstrates the idea and gives a good visual indication of why this technique is referred to as tree compression (more commonly called path compression): after querying every node, we eventually reach a tree with only two levels, the root and its direct children.

Assume we have a chain of 4 nodes like below. We will go through the details of the call find(3).

While we find the root of node 3, we set the parent of every node along the path to the root 0, which means later queries on these nodes take only O(1) time. This technique is called tree compression and allows us to achieve amortized O(log(n)) time complexity. As a reminder, amortized time complexity refers to the average cost per operation over a long sequence of operations.
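The walkthrough above can be traced in a few lines of Python. This is only a sketch: the chain shape 3 -> 2 -> 1 -> 0 is an assumption about the figure, and the parent map is filled in by hand to build that chain.

```python
class UnionFind:
    def __init__(self):
        self.id = {}

    def find(self, x):
        y = self.id.get(x, x)
        if y != x:
            # re-point x directly at the root on the way back up the recursion
            self.id[x] = y = self.find(y)
        return y


uf = UnionFind()
# assumed chain: 3's parent is 2, 2's parent is 1, 1's parent is the root 0
uf.id = {3: 2, 2: 1, 1: 0}

before = dict(uf.id)
root = uf.find(3)
after = dict(uf.id)

print(before)  # {3: 2, 2: 1, 1: 0}
print(root)    # 0
print(after)   # {3: 0, 2: 0, 1: 0}
```

After the single call, every node on the path points straight at 0, so a second find on any of them needs only one hop.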

import java.util.HashMap;

public static class UnionFind<T> {
    // initialize the data structure that maps the node to its set ID
    private HashMap<T, T> id = new HashMap<>();

    // find the Set ID of Node x
    public T find(T x) {
        // get the value associated with key x, if it's not in the map return x
        T y = id.getOrDefault(x, x);
        // check if the current node is a Set ID node
        // (use equals, not !=, so that equal keys held in different objects still match)
        if (!y.equals(x)) {
            // set the value to Set ID node of node y
            y = find(y);
            // change the hash value of node x to Set ID value of node y
            id.put(x, y);
        }
        return y;
    }

    // union two different sets by setting one set's root parent to the other set's root
    public void union(T x, T y) {
        id.put(find(x), find(y));
    }
}

class UnionFind:
    # initialize the data structure that maps the node to its set ID
    def __init__(self):
        self.id = {}

    # find the Set ID of Node x
    def find(self, x):
        # get the value associated with key x, if it's not in the map return x
        y = self.id.get(x, x)
        # check if the current node is a Set ID node
        if y != x:
            # set the value to Set ID node of node y, and change the hash value of node x to Set ID value of node y
            self.id[x] = y = self.find(y)
        return y

    # union two different sets by setting one set's root parent to the other set's root
    def union(self, x, y):
        self.id[self.find(x)] = self.find(y)

class UnionFind {
    // initialize the data structure that maps the node to its set ID
    constructor() {
        this.id = new Map();
    }

    // find the Set ID of Node x
    find(x) {
        // get the value associated with key x, if it's not in the map use x
        let y = this.id.has(x) ? this.id.get(x) : x;
        // check if the current node is a Set ID node
        if (y !== x) {
            // set the value to Set ID node of node y
            y = this.find(y);
            // change the hash value of node x to Set ID value of node y
            this.id.set(x, y);
        }
        return y;
    }

    // union two different sets by setting one set's root parent to the other set's root
    union(x, y) {
        this.id.set(this.find(x), this.find(y));
    }
}

using System.Collections.Generic;

public class UnionFind<T> {
    // maps each node to its set ID
    private Dictionary<T, T> id = new Dictionary<T, T>();

    // find the Set ID of Node x
    public T Find(T x) {
        // get the value associated with key x, if it's not in the map return x
        T y = id.GetValueOrDefault(x, x);
        // check if the current node is a Set ID node
        if (!y.Equals(x)) {
            y = Find(y);
            // change the hash value of node x to Set ID value of node y
            id[x] = y;
        }
        return y;
    }

    // union two different sets by setting one set's root parent to the other set's root
    public void Union(T x, T y) {
        id[Find(x)] = Find(y);
    }
}

Union by Rank (Advanced, Optional)

Can we improve this even more though?

We have already discussed tree compression to optimize our future queries (amortized runtime), but there is a way to improve the time complexity further. This optimization uses a technique called union by rank, where we assign ranks to the roots of our trees. Here, the ranks represent the relative depths of our trees. Each time we merge two sets, we attach the root with the smaller rank under the root with the larger rank, and we increase the rank of the new root by one when both ranks were equal. This technique improves our amortized O(log(n)) algorithm to amortized O(alpha(n)), where alpha(n) is the inverse Ackermann function, which grows very slowly relative to n. The proof of this time complexity is a bit complicated, so we will not touch upon it here. Note that the inverse Ackermann function grows so slowly that its value never exceeds 4 for any practical input size, making the runtime effectively "constant". Precisely speaking, though, the time complexity is O(alpha(n)).

As a final note, in most cases union by rank is likely not necessary, as O(log(n)) is likely to be efficient enough, but it is still a good trick to know for cases where it is required. Here is our final implementation, taking union by rank into account:
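To isolate what the rank rule contributes, the sketch below compares tree depth with and without it, deliberately leaving out tree compression. The helper names (`root`, `depth`, `union_naive`, `union_by_rank`) and the merge sequence are assumptions chosen for illustration, not part of the implementations in this article.

```python
def root(parent, x):
    """Follow parent links until reaching a node that is its own parent."""
    while parent.get(x, x) != x:
        x = parent[x]
    return x


def depth(parent, x):
    """Number of parent links between x and its root."""
    steps = 0
    while parent.get(x, x) != x:
        x = parent[x]
        steps += 1
    return steps


def union_naive(parent, x, y):
    # always attach x's root under y's root, regardless of tree shape
    parent[root(parent, x)] = root(parent, y)


def union_by_rank(parent, rank, x, y):
    rx, ry = root(parent, x), root(parent, y)
    if rx == ry:
        return
    if rank.get(rx, 0) < rank.get(ry, 0):
        rx, ry = ry, rx  # make rx the root with the larger rank
    parent[ry] = rx
    if rank.get(rx, 0) == rank.get(ry, 0):
        rank[rx] = rank.get(rx, 0) + 1


n = 64
naive, ranked, rank = {}, {}, {}
for i in range(n - 1):
    union_naive(naive, i, i + 1)
    union_by_rank(ranked, rank, i, i + 1)

naive_depth = max(depth(naive, i) for i in range(n))
ranked_depth = max(depth(ranked, i) for i in range(n))
print(naive_depth, ranked_depth)  # 63 1
```

On this particular sequence the naive rule builds one long chain, while the rank rule keeps attaching single nodes under the same root.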

import java.util.HashMap;

public static class UnionFind<T> {
    private HashMap<T, T> id = new HashMap<>();
    private HashMap<T, Integer> rank = new HashMap<>();

    // find the Set ID of Node x
    public T find(T x) {
        // get the value associated with key x, if it's not in the map return x
        T y = id.getOrDefault(x, x);
        // check if the current node is a Set ID node
        if (!y.equals(x)) {
            // set the value to Set ID node of node y
            y = find(y);
            // change the hash value of node x to Set ID value of node y
            id.put(x, y);
        }
        return y;
    }

    // union two different sets by attaching the lower-rank root under the higher-rank root
    public void union(T x, T y) {
        // look up both roots once, before any parent pointer changes
        T rootX = find(x);
        T rootY = find(y);
        // already in the same set: nothing to merge
        if (rootX.equals(rootY)) return;
        // roots missing from the rank map have rank 0
        int rankX = rank.getOrDefault(rootX, 0);
        int rankY = rank.getOrDefault(rootY, 0);
        if (rankX < rankY) {
            id.put(rootX, rootY);
        } else {
            id.put(rootY, rootX);
            // equal ranks: rootX gained a subtree as deep as itself, so its rank grows by 1
            if (rankX == rankY) {
                rank.put(rootX, rankX + 1);
            }
        }
    }
}

class UnionFind:
    def __init__(self):
        self.id = {}
        self.rank = {}

    # find the Set ID of Node x
    def find(self, x):
        # get the value associated with key x, if it's not in the map return x
        y = self.id.get(x, x)
        # check if the current node is a Set ID node
        if y != x:
            # change the hash value of node x to Set ID value of node y
            self.id[x] = y = self.find(y)
        return y

    # union two different sets by attaching the lower-rank root under the higher-rank root
    def union(self, x, y):
        # look up both roots once, before any parent pointer changes
        root_x, root_y = self.find(x), self.find(y)
        # already in the same set: nothing to merge
        if root_x == root_y:
            return
        # roots missing from the rank map have rank 0
        rank_x = self.rank.get(root_x, 0)
        rank_y = self.rank.get(root_y, 0)
        if rank_x < rank_y:
            self.id[root_x] = root_y
        else:
            self.id[root_y] = root_x
            # equal ranks: root_x gained a subtree as deep as itself, so its rank grows by 1
            if rank_x == rank_y:
                self.rank[root_x] = rank_x + 1

class UnionFind {
    constructor() {
        this.id = new Map();
        this.rank = new Map();
    }

    // find the Set ID of Node x
    find(x) {
        // get the value associated with key x, if it's not in the map use x
        let y = this.id.has(x) ? this.id.get(x) : x;
        // check if the current node is a Set ID node
        if (y !== x) {
            // change the hash value of node x to Set ID value of node y
            y = this.find(y);
            this.id.set(x, y);
        }
        return y;
    }

    // union two different sets by attaching the lower-rank root under the higher-rank root
    union(x, y) {
        // look up both roots once, before any parent pointer changes
        const rootX = this.find(x);
        const rootY = this.find(y);
        // already in the same set: nothing to merge
        if (rootX === rootY) return;
        // roots missing from the rank map have rank 0
        const rankX = this.rank.has(rootX) ? this.rank.get(rootX) : 0;
        const rankY = this.rank.has(rootY) ? this.rank.get(rootY) : 0;
        if (rankX < rankY) {
            this.id.set(rootX, rootY);
        } else {
            this.id.set(rootY, rootX);
            // equal ranks: rootX gained a subtree as deep as itself, so its rank grows by 1
            if (rankX === rankY) {
                this.rank.set(rootX, rankX + 1);
            }
        }
    }
}

When is Union Find useful?

Union Find is especially useful in problems about connected components. It is often used to implement Kruskal's algorithm for finding the minimum spanning tree (MST) of a graph, where it checks whether an edge would connect two components that are still separate.
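As an illustration of the MST use case, here is a minimal sketch of Kruskal's algorithm built on the tree compression version of Union Find. The function name `kruskal`, the `(weight, u, v)` edge format, and the sample graph are assumptions made for this example.

```python
class UnionFind:
    def __init__(self):
        self.id = {}

    def find(self, x):
        y = self.id.get(x, x)
        if y != x:
            self.id[x] = y = self.find(y)
        return y

    def union(self, x, y):
        self.id[self.find(x)] = self.find(y)


def kruskal(edges):
    """edges: iterable of (weight, u, v). Returns (total_weight, chosen_edges)."""
    uf = UnionFind()
    total, chosen = 0, []
    # consider edges from lightest to heaviest
    for weight, u, v in sorted(edges):
        # keep an edge only if it joins two components that are still separate
        if uf.find(u) != uf.find(v):
            uf.union(u, v)
            total += weight
            chosen.append((u, v, weight))
    return total, chosen


# hypothetical sample graph with four nodes
edges = [(1, "a", "b"), (4, "a", "c"), (2, "b", "c"), (5, "c", "d"), (3, "b", "d")]
total, chosen = kruskal(edges)
print(total)   # 6
print(chosen)  # [('a', 'b', 1), ('b', 'c', 2), ('b', 'd', 3)]
```

Each edge costs two find calls and at most one union, so the sort dominates the running time.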

Do I Need to Know or Implement Union Find for the Interview?

Union Find is a slightly more advanced topic. For implementation, knowing the version with the tree compression optimization is enough. Interviewers are unlikely to expect you to write union by rank within the short interview time.

