# # Disjoint-set data structure - Library (Python) # # Copyright (c) 2021 Project Nayuki. (MIT License) # https://www.nayuki.io/page/disjoint-set-data-structure # # Permission is hereby granted, free of charge, to any person obtaining a copy of # this software and associated documentation files (the "Software"), to deal in # the Software without restriction, including without limitation the rights to # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of # the Software, and to permit persons to whom the Software is furnished to do so, # subject to the following conditions: # - The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. # - The Software is provided "as is", without warranty of any kind, express or # implied, including but not limited to the warranties of merchantability, # fitness for a particular purpose and noninfringement. In no event shall the # authors or copyright holders be liable for any claim, damages or other # liability, whether in an action of contract, tort or otherwise, arising from, # out of or in connection with the Software or the use or other dealings in the # Software. # import itertools from typing import List # Represents a set of disjoint sets. Also known as the union-find data structure. # Main operations are querying if two elements are in the same set, and merging two sets together. # Useful for testing graph connectivity, and is used in Kruskal's algorithm. class DisjointSet: num_sets: int parents: List[int] sizes: List[int] # Constructs a new set containing the given number of singleton sets. # For example, DisjointSet(3) --> {{0}, {1}, {2}}. def __init__(self, numelems: int): if numelems < 0: raise ValueError("Number of elements must be non-negative") # A global property self.num_sets = 0 # Per-node properties (two): # The index of the parent element. An element is a representative iff its parent is itself. self.parents = [] # Positive number if the element is a representative, otherwise zero. self.sizes = [] for _ in range(numelems): self.add_set() # Returns the number of elements among the set of disjoint sets. All the other methods # require the argument elemindex to satisfy 0 <= elemindex < get_num_elements(). def get_num_elements(self) -> int: return len(self.parents) # Returns the number of disjoint sets overall. 0 <= result <= get_num_elements(). def get_num_sets(self) -> int: return self.num_sets # (Private) Returns the representative element for the set containing the given element. This method is also # known as "find" in the literature. Also performs path compression, which alters the internal state to # improve the speed of future queries, but has no externally visible effect on the values returned. def _get_repr(self, elemindex: int) -> int: if not (0 <= elemindex < len(self.parents)): raise IndexError() # Follow parent pointers until we reach a representative parent: int = self.parents[elemindex] while True: grandparent: int = self.parents[parent] if grandparent == parent: return parent self.parents[elemindex] = grandparent # Partial path compression elemindex = parent parent = grandparent # Returns the size of the set that the given element is a member of. 1 <= result <= get_num_elements(). def get_size_of_set(self, elemindex: int) -> int: return self.sizes[self._get_repr(elemindex)] # Tests whether the given two elements are members of the same set. Note that the arguments are orderless. def are_in_same_set(self, elemindex0: int, elemindex1: int) -> bool: return self._get_repr(elemindex0) == self._get_repr(elemindex1) # Adds a new singleton set, incrementing get_num_elements() and get_num_sets(). # Returns the identity of the new element, which equals the old value of get_num_elements(). def add_set(self) -> int: elemindex = self.get_num_elements() self.parents.append(elemindex) self.sizes.append(1) self.num_sets += 1 return elemindex # Merges together the sets that the given two elements belong to. This method is also known as "union" in the literature. # If the two elements belong to different sets, then the two sets are merged and the method returns True. # Otherwise they belong in the same set, nothing is changed and the method returns False. Note that the arguments are orderless. def merge_sets(self, elemindex0: int, elemindex1: int) -> bool: # Get representatives repr0: int = self._get_repr(elemindex0) repr1: int = self._get_repr(elemindex1) if repr0 == repr1: return False # Compare sizes to choose parent node if self.sizes[repr0] < self.sizes[repr1]: repr0, repr1 = repr1, repr0 # Now repr0's size >= repr1's size # Graft repr1's subtree onto node repr0 self.parents[repr1] = repr0 self.sizes[repr0] += self.sizes[repr1] self.sizes[repr1] = 0 self.num_sets -= 1 return True # For unit tests. This detects many but not all invalid data structures, raising an AssertionError # if a structural invariant is known to be violated. This always returns silently on a valid object. def check_structure(self) -> None: numrepr: int = 0 sizesum: int = 0 for (i, parent, size) in zip( itertools.count(), self.parents, self.sizes): isrepr: bool = parent == i if isrepr: numrepr += 1 ok: bool = True ok &= 0 <= parent < len(self.parents) ok &= ((not isrepr) and size == 0) or (isrepr and 1 <= size <= len(self.parents)) if not ok: raise AssertionError() sizesum += size if not (0 <= self.num_sets == numrepr <= len(self.parents) == sizesum): raise AssertionError()