Source code for arsenal.datastructures.unionfind

"""UnionFind.py

Union-find data structure. Based on Josiah Carlson's code,
http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/215912
with significant additional changes by D. Eppstein.
"""

from collections import defaultdict

class UnionFind:
    """Union-Find data structure.

    Each `UnionFind` instance `X` maintains a family of disjoint sets of
    hashable objects, supporting the following two methods:

    - `X[item]` returns a name for the set containing the given item.
      Each set is named by an arbitrarily-chosen one of its members; as
      long as the set remains unchanged it will keep the same name. If
      the item is not yet part of a set in `X`, a new singleton set is
      created for it.

    - `X.union(item1, item2, ...)` merges the sets containing each item
      into a single larger set. If any item is not yet part of a set in
      `X`, it is added to `X` as one of the members of the merged set.

    Attributes:
        parents: dict mapping every known item to its parent item; a set's
            name (root) is the item that maps to itself.
        weights: dict mapping an item to the number of items in the set it
            roots. Entries for items that are no longer roots are not
            updated and must not be relied upon.
    """

    def __init__(self, elements=None):
        """Create a union-find structure, optionally seeded with singletons.

        Args:
            elements: optional iterable of hashable items; each one is
                added as its own singleton set.
        """
        self.weights = {}
        self.parents = {}
        if elements is not None:
            for x in elements:
                self.add(x)

    @property
    def elems(self):
        """Return the parents dict, whose keys are all known items."""
        return self.parents

    def add(self, x):
        """Add element x as a singleton (no effect if x is already known)."""
        # A one-argument union only performs the lookup, which registers x.
        self.union(x)

    def connected(self, x, y):
        """Return True if x and y are in the same set.

        Unknown items are registered as singletons by the lookups.
        """
        return self[x] == self[y]

    def __getitem__(self, obj):
        """Find and return the name of the set containing the object.

        An object that has not been seen before is registered as a new
        singleton set and returned as its own name.
        """
        # Check for previously unknown object.
        if obj not in self.parents:
            self.parents[obj] = obj
            self.weights[obj] = 1
            return obj

        # Find the path of objects leading to the root.
        path = [obj]
        root = self.parents[obj]
        while root != path[-1]:
            path.append(root)
            root = self.parents[root]

        # Compress the path so later lookups reach the root directly.
        for ancestor in path:
            self.parents[ancestor] = root
        return root

    def __iter__(self):
        """Iterate through all items ever found or unioned by this structure."""
        return iter(self.parents)

    def union(self, *objects):
        """Find the sets containing the objects and merge them all.

        The root of the heaviest set involved becomes the name of the merged
        set. Calling with no objects is a no-op.
        """
        # Several arguments may already share a root. Deduplicate (keeping
        # argument order, so ties are broken as before) so that each set's
        # weight is added to the merged total exactly once.
        roots = list(dict.fromkeys(self[x] for x in objects))
        if not roots:
            return
        heaviest = max(roots, key=self.weights.__getitem__)
        for r in roots:
            if r != heaviest:
                self.weights[heaviest] += self.weights[r]
                self.parents[r] = heaviest

    def roots(self):
        """Yield the name (root item) of every set in the structure."""
        for x in self:
            if self[x] == x:
                yield x

    def classes(self):
        """Return a view of the disjoint sets, each as a `set` of its items."""
        classes = defaultdict(set)
        for x in self:
            root = self[x]  # does path compression as a side effect
            classes[root].add(x)
        return classes.values()

    def class_of(self, x):
        """Return a list of all items in the same set as x (including x).

        If x is unknown it is first registered as a singleton.
        """
        root = self[x]
        return [y for y in self if self[y] == root]