Skip to content

Commit 2545e0e

Browse files
authored
Refactor UnionFind (#729)
* ref: refactor UnionFind
* chore: replace `if else` by `match`
1 parent 46a5055 commit 2545e0e

File tree

1 file changed

+143
-134
lines changed

1 file changed

+143
-134
lines changed

src/data_structures/union_find.rs

+143 -134
Original file line number | Diff line number | Diff line change
@@ -1,27 +1,25 @@
1+
//! A Union-Find (Disjoint Set) data structure implementation in Rust.
2+
//!
3+
//! The Union-Find data structure keeps track of elements partitioned into
4+
//! disjoint (non-overlapping) sets.
5+
//! It provides near-constant-time operations to add new sets, to find the
6+
//! representative of a set, and to merge sets.
7+
8+
use std::cmp::Ordering;
19
use std::collections::HashMap;
210
use std::fmt::Debug;
311
use std::hash::Hash;
412

5-
/// UnionFind data structure
6-
/// It acts by holding an array of pointers to parents, together with the size of each subset
713
#[derive(Debug)]
814
pub struct UnionFind<T: Debug + Eq + Hash> {
9-
payloads: HashMap<T, usize>, // we are going to manipulate indices to parent, thus `usize`. We need a map to associate a value to its index in the parent links array
10-
parent_links: Vec<usize>, // holds the relationship between an item and its parent. The root of a set is denoted by parent_links[i] == i
11-
sizes: Vec<usize>, // holds the size
12-
count: usize,
15+
payloads: HashMap<T, usize>, // Maps values to their indices in the parent_links array.
16+
parent_links: Vec<usize>, // Holds the parent pointers; root elements are their own parents.
17+
sizes: Vec<usize>, // Holds the sizes of the sets.
18+
count: usize, // Number of disjoint sets.
1319
}
1420

1521
impl<T: Debug + Eq + Hash> UnionFind<T> {
16-
/// Creates an empty Union Find structure with capacity n
17-
///
18-
/// # Examples
19-
///
20-
/// ```
21-
/// use the_algorithms_rust::data_structures::UnionFind;
22-
/// let uf = UnionFind::<&str>::with_capacity(5);
23-
/// assert_eq!(0, uf.count())
24-
/// ```
22+
/// Creates an empty Union-Find structure with a specified capacity.
2523
pub fn with_capacity(capacity: usize) -> Self {
2624
Self {
2725
parent_links: Vec::with_capacity(capacity),
@@ -31,7 +29,7 @@ impl<T: Debug + Eq + Hash> UnionFind<T> {
3129
}
3230
}
3331

34-
/// Inserts a new item (disjoint) in the data structure
32+
/// Inserts a new item (disjoint set) into the data structure.
3533
pub fn insert(&mut self, item: T) {
3634
let key = self.payloads.len();
3735
self.parent_links.push(key);
@@ -40,107 +38,63 @@ impl<T: Debug + Eq + Hash> UnionFind<T> {
4038
self.count += 1;
4139
}
4240

43-
pub fn id(&self, value: &T) -> Option<usize> {
44-
self.payloads.get(value).copied()
41+
/// Returns the root index of the set containing the given value, or `None` if it doesn't exist.
42+
pub fn find(&mut self, value: &T) -> Option<usize> {
43+
self.payloads
44+
.get(value)
45+
.copied()
46+
.map(|key| self.find_by_key(key))
4547
}
4648

47-
/// Returns the key of an item stored in the data structure or None if it doesn't exist
48-
fn find(&self, value: &T) -> Option<usize> {
49-
self.id(value).map(|id| self.find_by_key(id))
50-
}
51-
52-
/// Creates a link between value_1 and value_2
53-
/// returns None if either value_1 or value_2 hasn't been inserted in the data structure first
54-
/// returns Some(true) if two disjoint sets have been merged
55-
/// returns Some(false) if both elements already were belonging to the same set
56-
///
57-
/// #_Examples:
58-
///
59-
/// ```
60-
/// use the_algorithms_rust::data_structures::UnionFind;
61-
/// let mut uf = UnionFind::with_capacity(2);
62-
/// uf.insert("A");
63-
/// uf.insert("B");
64-
///
65-
/// assert_eq!(None, uf.union(&"A", &"C"));
66-
///
67-
/// assert_eq!(2, uf.count());
68-
/// assert_eq!(Some(true), uf.union(&"A", &"B"));
69-
/// assert_eq!(1, uf.count());
70-
///
71-
/// assert_eq!(Some(false), uf.union(&"A", &"B"));
72-
/// ```
73-
pub fn union(&mut self, item1: &T, item2: &T) -> Option<bool> {
74-
match (self.find(item1), self.find(item2)) {
75-
(Some(k1), Some(k2)) => Some(self.union_by_key(k1, k2)),
49+
/// Unites the sets containing the two given values. Returns:
50+
/// - `None` if either value hasn't been inserted,
51+
/// - `Some(true)` if two disjoint sets have been merged,
52+
/// - `Some(false)` if both elements were already in the same set.
53+
pub fn union(&mut self, first_item: &T, sec_item: &T) -> Option<bool> {
54+
let (first_root, sec_root) = (self.find(first_item), self.find(sec_item));
55+
match (first_root, sec_root) {
56+
(Some(first_root), Some(sec_root)) => Some(self.union_by_key(first_root, sec_root)),
7657
_ => None,
7758
}
7859
}
7960

80-
/// Returns the parent of the element given its id
81-
fn find_by_key(&self, key: usize) -> usize {
82-
let mut id = key;
83-
while id != self.parent_links[id] {
84-
id = self.parent_links[id];
61+
/// Finds the root of the set containing the element with the given index.
62+
fn find_by_key(&mut self, key: usize) -> usize {
63+
if self.parent_links[key] != key {
64+
self.parent_links[key] = self.find_by_key(self.parent_links[key]);
8565
}
86-
id
66+
self.parent_links[key]
8767
}
8868

89-
/// Unions the sets containing id1 and id2
90-
fn union_by_key(&mut self, key1: usize, key2: usize) -> bool {
91-
let root1 = self.find_by_key(key1);
92-
let root2 = self.find_by_key(key2);
93-
if root1 == root2 {
94-
return false; // they belong to the same set already, no-op
69+
/// Unites the sets containing the two elements identified by their indices.
70+
fn union_by_key(&mut self, first_key: usize, sec_key: usize) -> bool {
71+
let (first_root, sec_root) = (self.find_by_key(first_key), self.find_by_key(sec_key));
72+
73+
if first_root == sec_root {
74+
return false;
9575
}
96-
// Attach the smaller set to the larger one
97-
if self.sizes[root1] < self.sizes[root2] {
98-
self.parent_links[root1] = root2;
99-
self.sizes[root2] += self.sizes[root1];
100-
} else {
101-
self.parent_links[root2] = root1;
102-
self.sizes[root1] += self.sizes[root2];
76+
77+
match self.sizes[first_root].cmp(&self.sizes[sec_root]) {
78+
Ordering::Less => {
79+
self.parent_links[first_root] = sec_root;
80+
self.sizes[sec_root] += self.sizes[first_root];
81+
}
82+
_ => {
83+
self.parent_links[sec_root] = first_root;
84+
self.sizes[first_root] += self.sizes[sec_root];
85+
}
10386
}
104-
self.count -= 1; // we had 2 disjoint sets, now merged as one
87+
88+
self.count -= 1;
10589
true
10690
}
10791

108-
/// Checks if two items belong to the same set
109-
///
110-
/// #_Examples:
111-
///
112-
/// ```
113-
/// use the_algorithms_rust::data_structures::UnionFind;
114-
/// let mut uf = UnionFind::from_iter(["A", "B"]);
115-
/// assert!(!uf.is_same_set(&"A", &"B"));
116-
///
117-
/// uf.union(&"A", &"B");
118-
/// assert!(uf.is_same_set(&"A", &"B"));
119-
///
120-
/// assert!(!uf.is_same_set(&"A", &"C"));
121-
/// ```
122-
pub fn is_same_set(&self, item1: &T, item2: &T) -> bool {
123-
matches!((self.find(item1), self.find(item2)), (Some(root1), Some(root2)) if root1 == root2)
92+
/// Checks if two items belong to the same set.
93+
pub fn is_same_set(&mut self, first_item: &T, sec_item: &T) -> bool {
94+
matches!((self.find(first_item), self.find(sec_item)), (Some(first_root), Some(sec_root)) if first_root == sec_root)
12495
}
12596

126-
/// Returns the number of disjoint sets
127-
///
128-
/// # Examples
129-
///
130-
/// ```
131-
/// use the_algorithms_rust::data_structures::UnionFind;
132-
/// let mut uf = UnionFind::with_capacity(5);
133-
/// assert_eq!(0, uf.count());
134-
///
135-
/// uf.insert("A");
136-
/// assert_eq!(1, uf.count());
137-
///
138-
/// uf.insert("B");
139-
/// assert_eq!(2, uf.count());
140-
///
141-
/// uf.union(&"A", &"B");
142-
/// assert_eq!(1, uf.count())
143-
/// ```
97+
/// Returns the number of disjoint sets.
14498
pub fn count(&self) -> usize {
14599
self.count
146100
}
@@ -158,11 +112,11 @@ impl<T: Debug + Eq + Hash> Default for UnionFind<T> {
158112
}
159113

160114
impl<T: Debug + Eq + Hash> FromIterator<T> for UnionFind<T> {
161-
/// Creates a new UnionFind data structure from an iterable of disjoint elements
115+
/// Creates a new UnionFind data structure from an iterable of disjoint elements.
162116
fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
163117
let mut uf = UnionFind::default();
164-
for i in iter {
165-
uf.insert(i);
118+
for item in iter {
119+
uf.insert(item);
166120
}
167121
uf
168122
}
@@ -175,45 +129,100 @@ mod tests {
175129
#[test]
176130
fn test_union_find() {
177131
let mut uf = UnionFind::from_iter(0..10);
178-
assert_eq!(uf.find_by_key(0), 0);
179-
assert_eq!(uf.find_by_key(1), 1);
180-
assert_eq!(uf.find_by_key(2), 2);
181-
assert_eq!(uf.find_by_key(3), 3);
182-
assert_eq!(uf.find_by_key(4), 4);
183-
assert_eq!(uf.find_by_key(5), 5);
184-
assert_eq!(uf.find_by_key(6), 6);
185-
assert_eq!(uf.find_by_key(7), 7);
186-
assert_eq!(uf.find_by_key(8), 8);
187-
assert_eq!(uf.find_by_key(9), 9);
188-
189-
assert_eq!(Some(true), uf.union(&0, &1));
190-
assert_eq!(Some(true), uf.union(&1, &2));
191-
assert_eq!(Some(true), uf.union(&2, &3));
132+
assert_eq!(uf.find(&0), Some(0));
133+
assert_eq!(uf.find(&1), Some(1));
134+
assert_eq!(uf.find(&2), Some(2));
135+
assert_eq!(uf.find(&3), Some(3));
136+
assert_eq!(uf.find(&4), Some(4));
137+
assert_eq!(uf.find(&5), Some(5));
138+
assert_eq!(uf.find(&6), Some(6));
139+
assert_eq!(uf.find(&7), Some(7));
140+
assert_eq!(uf.find(&8), Some(8));
141+
assert_eq!(uf.find(&9), Some(9));
142+
143+
assert!(!uf.is_same_set(&0, &1));
144+
assert!(!uf.is_same_set(&2, &9));
145+
assert_eq!(uf.count(), 10);
146+
147+
assert_eq!(uf.union(&0, &1), Some(true));
148+
assert_eq!(uf.union(&1, &2), Some(true));
149+
assert_eq!(uf.union(&2, &3), Some(true));
150+
assert_eq!(uf.union(&0, &2), Some(false));
151+
assert_eq!(uf.union(&4, &5), Some(true));
152+
assert_eq!(uf.union(&5, &6), Some(true));
153+
assert_eq!(uf.union(&6, &7), Some(true));
154+
assert_eq!(uf.union(&7, &8), Some(true));
155+
assert_eq!(uf.union(&8, &9), Some(true));
156+
assert_eq!(uf.union(&7, &9), Some(false));
157+
158+
assert_ne!(uf.find(&0), uf.find(&9));
159+
assert_eq!(uf.find(&0), uf.find(&3));
160+
assert_eq!(uf.find(&4), uf.find(&9));
161+
assert!(uf.is_same_set(&0, &3));
162+
assert!(uf.is_same_set(&4, &9));
163+
assert!(!uf.is_same_set(&0, &9));
164+
assert_eq!(uf.count(), 2);
165+
192166
assert_eq!(Some(true), uf.union(&3, &4));
193-
assert_eq!(Some(true), uf.union(&4, &5));
194-
assert_eq!(Some(true), uf.union(&5, &6));
195-
assert_eq!(Some(true), uf.union(&6, &7));
196-
assert_eq!(Some(true), uf.union(&7, &8));
197-
assert_eq!(Some(true), uf.union(&8, &9));
198-
assert_eq!(Some(false), uf.union(&9, &0));
199-
200-
assert_eq!(1, uf.count());
167+
assert_eq!(uf.find(&0), uf.find(&9));
168+
assert_eq!(uf.count(), 1);
169+
assert!(uf.is_same_set(&0, &9));
170+
171+
assert_eq!(None, uf.union(&0, &11));
201172
}
202173

203174
#[test]
204175
fn test_spanning_tree() {
205-
// Let's imagine the following topology:
206-
// A <-> B
207-
// B <-> C
208-
// A <-> D
209-
// E
210-
// F <-> G
211-
// We have 3 disjoint sets: {A, B, C, D}, {E}, {F, G}
212176
let mut uf = UnionFind::from_iter(["A", "B", "C", "D", "E", "F", "G"]);
213177
uf.union(&"A", &"B");
214178
uf.union(&"B", &"C");
215179
uf.union(&"A", &"D");
216180
uf.union(&"F", &"G");
217-
assert_eq!(3, uf.count());
181+
182+
assert_eq!(None, uf.union(&"A", &"W"));
183+
184+
assert_eq!(uf.find(&"A"), uf.find(&"B"));
185+
assert_eq!(uf.find(&"A"), uf.find(&"C"));
186+
assert_eq!(uf.find(&"B"), uf.find(&"D"));
187+
assert_ne!(uf.find(&"A"), uf.find(&"E"));
188+
assert_ne!(uf.find(&"A"), uf.find(&"F"));
189+
assert_eq!(uf.find(&"G"), uf.find(&"F"));
190+
assert_ne!(uf.find(&"G"), uf.find(&"E"));
191+
192+
assert!(uf.is_same_set(&"A", &"B"));
193+
assert!(uf.is_same_set(&"A", &"C"));
194+
assert!(uf.is_same_set(&"B", &"D"));
195+
assert!(!uf.is_same_set(&"B", &"F"));
196+
assert!(!uf.is_same_set(&"E", &"A"));
197+
assert!(!uf.is_same_set(&"E", &"G"));
198+
assert_eq!(uf.count(), 3);
199+
}
200+
201+
#[test]
202+
fn test_with_capacity() {
203+
let mut uf: UnionFind<i32> = UnionFind::with_capacity(5);
204+
uf.insert(0);
205+
uf.insert(1);
206+
uf.insert(2);
207+
uf.insert(3);
208+
uf.insert(4);
209+
210+
assert_eq!(uf.count(), 5);
211+
212+
assert_eq!(uf.union(&0, &1), Some(true));
213+
assert!(uf.is_same_set(&0, &1));
214+
assert_eq!(uf.count(), 4);
215+
216+
assert_eq!(uf.union(&2, &3), Some(true));
217+
assert!(uf.is_same_set(&2, &3));
218+
assert_eq!(uf.count(), 3);
219+
220+
assert_eq!(uf.union(&0, &2), Some(true));
221+
assert!(uf.is_same_set(&0, &1));
222+
assert!(uf.is_same_set(&2, &3));
223+
assert!(uf.is_same_set(&0, &3));
224+
assert_eq!(uf.count(), 2);
225+
226+
assert_eq!(None, uf.union(&0, &10));
218227
}
219228
}

0 commit comments

Comments
 (0)