I am having trouble with the following problem from my data structures course. The grader's feedback is rather ambiguous, and I cannot tell where the bug lies.
NOTE: The error message only says "Wrong answer." and the test cases are not provided.
Input Format: The first line of the input contains two integers n and m: the number of tables in the database and the number of merge queries to perform, respectively. The second line of the input contains n integers r[i]: the number of rows in the i-th table. Then the following m lines describe the merge queries. Each of them contains two integers destination[i] and source[i]: the numbers of the tables to merge.
Output Format: For each query, print a line containing a single integer: the maximum of the sizes of all tables (in terms of the number of rows) after the corresponding operation.
Sample Input:
5 5
1 1 1 1 1
3 5
2 4
1 4
5 4
5 3
Sample Output:
2
2
3
5
5
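To pin down the semantics of the input/output format, here is a minimal brute-force sketch, assuming 1-based table numbers as in the sample. The helper name naive_merge_sizes and its owner list are hypothetical, not part of the course material. For every table it tracks which table currently holds its rows and rescans all tables on each merge. That makes it O(n) per query, too slow for large inputs but easy to trust, and it reproduces the sample output.

```python
def naive_merge_sizes(row_counts, queries):
    """Brute-force reference: returns the max table size after each query."""
    sizes = list(row_counts)              # copy so the caller's list is untouched
    owner = list(range(len(row_counts)))  # owner[t] = table currently holding t's rows
    results = []
    for dst, src in queries:
        d, s = owner[dst - 1], owner[src - 1]  # convert 1-based input to 0-based
        if d != s:
            sizes[d] += sizes[s]
            sizes[s] = 0
            # every table whose rows lived in s now lives in d
            owner = [d if o == s else o for o in owner]
        results.append(max(sizes))
    return results


if __name__ == "__main__":
    sample = [(3, 5), (2, 4), (1, 4), (5, 4), (5, 3)]
    print(naive_merge_sizes([1, 1, 1, 1, 1], sample))  # [2, 2, 3, 5, 5]
```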
This is my current code. It works for most cases, but there seem to be some edge cases that I have not accounted for.
```python
class DataBases:
    def __init__(self, row_counts):
        self.max_row_count = max(row_counts)
        self.row_counts = row_counts
        n_tables = len(row_counts)
        self.parent = list(range(n_tables))
        self.rank = [1] * n_tables

    def get_parent(self, table):
        update_root = []
        root = table
        while root != self.parent[root]:
            update_root.append(self.parent[root])
            root = self.parent[root]
        for i in update_root:
            self.parent[i] = root
        return root

    def merge_tables(self, dst, src):
        src_parent = self.get_parent(src)
        dst_parent = self.get_parent(dst)
        if src_parent == dst_parent: return
        if self.rank[src_parent] > self.rank[dst_parent]:
            self.parent[dst_parent] = src_parent
            self.update_row_counts(src_parent, dst_parent)
        else:
            self.parent[src_parent] = dst_parent
            self.update_row_counts(dst_parent, src_parent)
            if self.rank[src_parent] == self.rank[dst_parent]:
                self.rank[dst_parent] += 1

    def update_row_counts(self, root, child):
        self.row_counts[root] += self.row_counts[child]
        self.row_counts[child] = 0
        self.max_row_count = max(self.max_row_count, self.row_counts[root])


def main():
    n_tables, n_queries = map(int, input().split())
    counts = list(map(int, input().split()))
    assert(n_tables == len(counts))
    db = DataBases(counts)
    for i in range(n_queries):
        dst, src = map(int, input().split())
        db.merge_tables(dst - 1, src - 1)
        print(db.max_row_count)


if __name__ == "__main__":
    main()
```
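Since the grader hides its test cases, one way to hunt for a missing edge case is a randomized stress test: run a fast union-find solver and a brute-force simulation on many tiny random inputs and stop at the first mismatch. The sketch below is self-contained. union_find is a hypothetical compact stand-in for the DataBases class (union by rank plus recursive path compression), and naive and stress are likewise illustrative names, not course-provided functions.

```python
import random


def naive(row_counts, queries):
    """Brute force: owner[t] is the table currently holding t's rows."""
    sizes = list(row_counts)
    owner = list(range(len(row_counts)))
    out = []
    for dst, src in queries:
        d, s = owner[dst - 1], owner[src - 1]
        if d != s:
            sizes[d] += sizes[s]
            sizes[s] = 0
            owner = [d if o == s else o for o in owner]
        out.append(max(sizes))
    return out


def union_find(row_counts, queries):
    """Union by rank with recursive path compression (stand-in for DataBases)."""
    parent = list(range(len(row_counts)))
    rank = [0] * len(row_counts)
    sizes = list(row_counts)
    best = max(sizes)
    out = []

    def find(t):
        if parent[t] != t:
            parent[t] = find(parent[t])  # re-point t directly at the root
        return parent[t]

    for dst, src in queries:
        d, s = find(dst - 1), find(src - 1)
        if d != s:
            if rank[d] < rank[s]:
                d, s = s, d  # attach the lower-rank root under the higher-rank one
            parent[s] = d
            sizes[d] += sizes[s]
            if rank[d] == rank[s]:
                rank[d] += 1
            best = max(best, sizes[d])
        out.append(best)
    return out


def stress(trials=2000, seed=0):
    """Return the first (rows, queries, expected, got) mismatch, or None."""
    rng = random.Random(seed)
    for _ in range(trials):
        n = rng.randint(1, 6)                           # tiny inputs are easy to debug by hand
        rows = [rng.randint(0, 5) for _ in range(n)]
        queries = [(rng.randint(1, n), rng.randint(1, n)) for _ in range(rng.randint(1, 8))]
        expected = naive(rows, queries)
        got = union_find(rows, queries)
        if got != expected:
            return rows, queries, expected, got
    return None


if __name__ == "__main__":
    failure = stress()
    print("all trials agree" if failure is None else failure)
```

To test the class from the question instead, replace union_find with a wrapper that builds DataBases(list(rows)), calls merge_tables(dst - 1, src - 1) for each query, and records max_row_count after each call. Passing a copy matters because DataBases stores and mutates the list it receives, so a reported failing input would otherwise show the merged counts rather than the original ones.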
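The issue was in the get_parent (path compression) implementation. The iterative loop only re-points the ancestors it collects in update_root, never the queried table itself, so that node keeps its old parent. The recursive version re-points every node on the path, including table.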
Correct Solution:
```python
def get_parent(self, table):
    if table != self.parent[table]:
        self.parent[table] = self.get_parent(self.parent[table])
    return self.parent[table]
```
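For comparison, here is a sketch of both behaviours on a plain parent list; the function names find_partial and find_full are hypothetical. find_partial reproduces the loop from the question. find_full is an iterative two-pass variant that, like the recursive fix, re-points every node on the path including the starting table, but without recursion (Python's default recursion limit is 1000, although union by rank keeps the trees shallow). Both return the same root; they differ only in which nodes end up pointing directly at it.

```python
def find_partial(parent, table):
    # Same loop as the question's get_parent: it records the ancestors
    # of `table`, but never `table` itself.
    update_root = []
    root = table
    while root != parent[root]:
        update_root.append(parent[root])
        root = parent[root]
    for i in update_root:
        parent[i] = root
    return root


def find_full(parent, table):
    # Pass 1: walk up to the root.
    root = table
    while root != parent[root]:
        root = parent[root]
    # Pass 2: re-point every node on the path, starting with `table`.
    while table != root:
        next_node = parent[table]
        parent[table] = root
        table = next_node
    return root


if __name__ == "__main__":
    a = [1, 2, 3, 3]  # chain 0 -> 1 -> 2 -> 3, where 3 is the root
    b = [1, 2, 3, 3]
    print(find_partial(a, 0), a)  # 3 [1, 3, 3, 3]  (node 0 still points at 1)
    print(find_full(b, 0), b)     # 3 [3, 3, 3, 3]
```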