teepee

Reputation: 2714

How to identify cases where both elements of a pair are greater than others' respective elements in the set?

I have a list of pairs, each with two numerical values. I want to find the subset of these pairs that are not exceeded in both elements by another pair (let's call such a pair "eclipsed" by the other).

For example, the pair (1,2) is eclipsed by (4,5) because both elements are less than the respective elements in the other pair.

Also, (1,2) is considered eclipsed by (1,3) because its first element is equal to the other's and its second element is less than the other's.

However, the pair (2, 10) is not eclipsed by (9, 9) because only one of its elements is exceeded by the other's.
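Stated as code, the rule above amounts to the following predicate. This is a minimal sketch; `is_eclipsed` is a hypothetical name, not part of any library. Identical pairs are treated as not eclipsing each other, since duplicates are removed separately.

```python
Pair = tuple[int, int]

def is_eclipsed(p: Pair, q: Pair) -> bool:
    """Return True if pair p is eclipsed by pair q.

    q eclipses p when q is at least as large in both elements and the two
    pairs are not identical (identical pairs are handled as duplicates).
    """
    return p != q and p[0] <= q[0] and p[1] <= q[1]

print(is_eclipsed((1, 2), (4, 5)))   # True: both elements exceeded
print(is_eclipsed((1, 2), (1, 3)))   # True: first equal, second exceeded
print(is_eclipsed((2, 10), (9, 9)))  # False: only the first is exceeded
print(is_eclipsed((1, 2), (1, 2)))   # False: identical pairs
```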

Cases where the pairs are identical should be reduced to just one (duplicates removed).

Ultimately, I am looking to reduce the list of pairs to a subset where only pairs that were not eclipsed by any others remain.

For example, take the following list:

(1,2)
(1,5)
(2,2)
(1,2)
(2,2)
(9,1)
(1,1)

This should be reduced to the following:

(1,5)
(2,2)
(9,1)
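For reference, a direct (quadratic) pure-Python reading of the requirement reproduces the reduction above. It is a sketch for checking results on small inputs, not a scalable solution; `reduce_pairs_bruteforce` is a hypothetical name.

```python
def reduce_pairs_bruteforce(pairs: list[tuple[int, int]]) -> list[tuple[int, int]]:
    # Remove duplicates first so identical pairs collapse to one.
    unique = set(pairs)
    # Keep a pair only if no *other* pair is >= in both elements.
    # Every pair is compared with every other one: O(n^2) comparisons.
    return sorted(
        p for p in unique
        if not any(q != p and p[0] <= q[0] and p[1] <= q[1] for q in unique)
    )

pairs_list = [(1, 2), (1, 5), (2, 2), (1, 2), (2, 2), (9, 1), (1, 1)]
print(reduce_pairs_bruteforce(pairs_list))  # [(1, 5), (2, 2), (9, 1)]
```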

My initial implementation of this in Python, using Polars, was the following:

import polars as pl

pairs_list = [
    (1,2),
    (1,5),
    (2,2),
    (1,2),
    (2,2),
    (9,1),
    (1,1),
]

# tabulate pair elements as 'a' and 'b'
pairs = pl.DataFrame(
    data=pairs_list,
    schema={'a': pl.UInt32, 'b': pl.UInt32},
    orient='row',
)

# eliminate any duplicate pairs
unique_pairs = pairs.unique()

# self join so every pair can be compared (except against itself)
comparison_inputs = (
    unique_pairs
    .join(
        unique_pairs,
        how='cross',
        suffix='_comp',
    )
    .filter(
        pl.any_horizontal(
            pl.col('a') != pl.col('a_comp'),
            pl.col('b') != pl.col('b_comp'),
        )
    )
)

# flag pairs that were eclipsed by others
comparison_results = (
    comparison_inputs
    .with_columns(
        pl.all_horizontal(
            pl.col('a') <= pl.col('a_comp'),
            pl.col('b') <= pl.col('b_comp'),
        )
        .alias('is_eclipsed')
    )
)

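# remove pairs that were eclipsed by at least one other
# (an anti-join against the eclipsed pairs also keeps a pair that had no
# comparison partner at all, e.g. when only one unique pair exists)
eclipsed_pairs = (
    comparison_results
    .filter(pl.col('is_eclipsed'))
    .select('a', 'b')
    .unique()
)
principal_pairs = unique_pairs.join(eclipsed_pairs, on=['a', 'b'], how='anti')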

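While this does appear to work, it is computationally infeasible for large datasets because of the sheer size of the self-joined table: with n unique pairs, the cross join produces n² rows (n * (n - 1) after dropping self-comparisons).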

I have considered shrinking the comparison_inputs table by removing redundant reversed comparisons: pair X vs. pair Y and pair Y vs. pair X don't both need to be in the table, as they currently are. However, that requires an additional condition in each comparison to report which pair was eclipsed, and it only cuts the table in half, which isn't a significant reduction.
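A rough pure-Python sketch of that halving idea uses `itertools.combinations`, so each unordered pair of pairs is compared once. Because only one ordering is visited, each comparison has to check both directions and record which side (if either) was eclipsed. The function name is hypothetical, and it still performs O(n²) comparisons, only n * (n - 1) / 2 of them instead of n * (n - 1).

```python
from itertools import combinations

def reduce_pairs_half_comparisons(pairs):
    unique = sorted(set(pairs))
    eclipsed = set()
    # Each unordered combination (x, y) is visited once instead of twice.
    for x, y in combinations(unique, 2):
        # The extra condition: check both directions and record the loser.
        # x != y here, so at most one direction can hold.
        if x[0] <= y[0] and x[1] <= y[1]:
            eclipsed.add(x)
        elif y[0] <= x[0] and y[1] <= x[1]:
            eclipsed.add(y)
    return [p for p in unique if p not in eclipsed]

print(reduce_pairs_half_comparisons([(1, 2), (1, 5), (2, 2), (1, 2), (2, 2), (9, 1), (1, 1)]))
# [(1, 5), (2, 2), (9, 1)]
```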

I have found I can reduce the number of comparisons substantially by applying window-function filters before the self-join step, keeping only the maximum a for each b and the maximum b for each a. In other words:

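unique_pairs = (
    pairs
    .unique()
    # for each value of b, keep only the pair with the largest a
    .filter(pl.col('a') == pl.col('a').max().over('b'))
    # then, for each value of a, keep only the pair with the largest b
    .filter(pl.col('b') == pl.col('b').max().over('a'))
)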

But this only goes so far, and how much it helps depends on the cardinality of a and b. I still need to self-join and compare afterwards to get a result.
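The same pre-filter can be sketched in plain Python with two dictionaries, which also shows its limit: pairs that differ in both elements are never compared, so (1, 1) survives alongside (2, 2). `prefilter_pairs` is a hypothetical helper, not part of Polars.

```python
def prefilter_pairs(pairs):
    unique = set(pairs)
    # For each value of b, keep only the largest a.
    max_a_for_b = {}
    for a, b in unique:
        max_a_for_b[b] = max(a, max_a_for_b.get(b, a))
    step1 = {(a, b) for b, a in max_a_for_b.items()}
    # Then, for each value of a among the survivors, keep only the largest b.
    max_b_for_a = {}
    for a, b in step1:
        max_b_for_a[a] = max(b, max_b_for_a.get(a, b))
    return sorted(max_b_for_a.items())

print(prefilter_pairs([(1, 2), (1, 5), (2, 2), (1, 2), (2, 2), (9, 1), (1, 1)]))
# [(1, 5), (2, 2), (9, 1)]
print(prefilter_pairs([(1, 1), (2, 2)]))
# [(1, 1), (2, 2)]: (1, 1) is still eclipsed, so a full comparison is still needed
```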

I am curious whether there is already an established algorithm for calculating this, and whether anyone has ideas for a more efficient method. I'm interested to learn more either way. Thanks in advance.

Upvotes: 4

Views: 82

Answers (2)

teepee

Reputation: 2714

Thanks to @Bhargav, who gave the key idea in his answer, which I adapted into the following (the main difference is that I use a single variable to track the maximum second value rather than appending to and popping from a list):

Point = tuple[float, float]

def select_dominant_points(points: list[Point]) -> set[Point]:
    """
    Get the set of dominant points in a list of 2-D points,
    i.e., the "maxima of a point set"
    (see: https://en.wikipedia.org/wiki/Maxima_of_a_point_set)

    Dominant 2-D points can be found by sorting all points in descending order
    by the first element followed by the second and keeping only points whose
    second element exceeds the values of all other second elements before it.

    For example, given a set of points where each is denoted (a, b), when
    sorted in descending order, 'b_max' denotes the maximum second element of
    all points up to and including the current point. Each point associated with
    an increase in b_max (including the first point) is a member of the maxima.

    point    b_max   dominant?
    -------- ------- ---------
     (9, 4)   4       YES
     (9, 2)   4       
     (8, 2)   4       
     (7, 3)   4       
     (6, 5)   5       YES
     (6, 3)   5       
     (4, 4)   5       
     (3, 8)   8       YES
     (3, 2)   8       
    """

    # reduce list of point pairs to a unique set
    unique_points = set(points)

    # sort points in reverse order (by first, then second element)
    points_in_descending_order = sorted(unique_points, reverse=True)

    # get set of dominant points (maxima) from the full set
    dominant_points = []
    b_max = float('-inf')  # start below any value so zero/negative b are handled
    for (a, b) in points_in_descending_order:
        if b > b_max:
            dominant_points.append((a, b))
            b_max = b

    return set(dominant_points)
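One way to gain confidence in the sweep is to compare it against the quadratic definition on random inputs. This is a minimal sketch; both helper names are hypothetical, and the sweep here starts from `float("-inf")` so that zero or negative second elements are handled.

```python
import random

def sweep(points):
    # Descending sort, then keep points that raise the running max of b.
    result, b_max = set(), float("-inf")
    for a, b in sorted(set(points), reverse=True):
        if b > b_max:
            result.add((a, b))
            b_max = b
    return result

def brute_force(points):
    # Keep a point only if no other point is >= in both coordinates.
    unique = set(points)
    return {
        p for p in unique
        if not any(q != p and p[0] <= q[0] and p[1] <= q[1] for q in unique)
    }

rng = random.Random(0)  # fixed seed so the check is reproducible
for _ in range(500):
    pts = [(rng.randint(-5, 5), rng.randint(-5, 5)) for _ in range(rng.randint(0, 30))]
    assert sweep(pts) == brute_force(pts), pts
print("sweep matches brute force on 500 random inputs")
```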

Upvotes: 1

Bhargav

Reputation: 4251

Here is one approach. First, remove duplicates and sort the pairs in descending order by the first element, breaking ties by the second element (also descending):

unique_pairs = sorted(set(pairs), reverse=True)

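Then walk through the sorted pairs. If a pair's second element b is greater than the maximum second element seen so far, the pair cannot be eclipsed: every earlier pair has a larger first element (or the same first element and a larger second element), but none has a second element as large as b. Otherwise, some earlier pair is at least as large in both elements, so the pair is eclipsed.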
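To make the condition concrete, here is a small trace over the example list from the question (a sketch; `trace_sweep` is a hypothetical name). Pairs are visited in descending order and kept only when they raise the running maximum of the second element.

```python
def trace_sweep(pairs):
    running_max = float("-inf")
    kept = []
    for a, b in sorted(set(pairs), reverse=True):
        is_kept = b > running_max
        if is_kept:
            kept.append((a, b))
            running_max = b
        print(f"({a}, {b})  max so far: {running_max}  kept: {is_kept}")
    return kept

trace_sweep([(1, 2), (1, 5), (2, 2), (1, 2), (2, 2), (9, 1), (1, 1)])
# (9, 1)  max so far: 1  kept: True
# (2, 2)  max so far: 2  kept: True
# (1, 5)  max so far: 5  kept: True
# (1, 2)  max so far: 5  kept: False
# (1, 1)  max so far: 5  kept: False
```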

from typing import List, Tuple

def find_non_eclipsed_pairs(pairs: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
    if not pairs:
        return []
    
    unique_pairs = sorted(set(pairs), reverse=True)
    
    result = []
    max_second_elements = []
    
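    for pair in unique_pairs:
        # a pair survives only if its second element beats every earlier one
        if not max_second_elements or pair[1] > max_second_elements[-1]:
            result.append(pair)
            # pop the maxima the new value replaces; because a push only happens
            # after the old top is popped, the stack holds at most one value
            while max_second_elements and max_second_elements[-1] <= pair[1]:
                max_second_elements.pop()
            max_second_elements.append(pair[1])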
            
    return sorted(result)

Testing

def test_pareto_pairs():
    test_cases = [
        (
            [(1,2), (1,5), (2,2), (1,2), (2,2), (9,1), (1,1)],
            [(1,5), (2,2), (9,1)]
        ),
        (
            [],
            []
        ),
        (
            [(1,1)],
            [(1,1)]
        ),
        (
            [(1,1), (2,2), (3,3), (4,4)],
            [(4,4)]
        ),
        (
            [(1,5), (5,1)],
            [(1,5), (5,1)]
        ),
        (
            [(1,1), (1,2), (2,1), (2,2), (3,1), (1,3)],
            [(1,3), (2,2), (3,1)]
        )
    ]
    
    for i, (input_pairs, expected) in enumerate(test_cases, 1):
        result = find_non_eclipsed_pairs(input_pairs)
        assert result == sorted(expected), f"Test case {i} failed: expected {expected}, got {result}"
        print(f"Test case {i} passed")

if __name__ == "__main__":
    test_pareto_pairs()
    
    pairs_list = [
        (1,2),
        (1,5),
        (2,2),
        (1,2),
        (2,2),
        (9,1),
        (1,1),
    ]
    
    result = find_non_eclipsed_pairs(pairs_list)
    print("\nOriginal pairs:", pairs_list)
    print("Non-eclipsed pairs:", result)

Which outputs:

Test case 1 passed
Test case 2 passed
Test case 3 passed
Test case 4 passed
Test case 5 passed
Test case 6 passed

Original pairs: [(1, 2), (1, 5), (2, 2), (1, 2), (2, 2), (9, 1), (1, 1)]
Non-eclipsed pairs: [(1, 5), (2, 2), (9, 1)]

Time complexity: O(n log n), dominated by the sort. Space complexity: O(n).

Edit: Thanks to @no comment for suggesting sorting with reverse=True:

unique_pairs = sorted(set(pairs), reverse=True)

Upvotes: 4
