sgarza62

Reputation: 6238

How to choose the attribute of a user-defined class that will be evaluated for built-in min() and max() functions?

I want to use Python's built-in min() and max() functions on a collection of my Point objects. However, I want the distance attribute to be the compared value. How can I specify this in the class definition?

class Point():
    def __init__(self, x, y, distance):
        self.x = x
        self.y = y
        self.distance = distance

a = Point(3, 5, 2)
b = Point(5, 4, 1)
c = Point(8, 4, 5)

hopefully_b = min(a, b, c)

Upvotes: 0

Views: 176

Answers (2)

poke

Reputation: 388023

When min and max are called without a key function, they use the ordering of the objects themselves to determine the result. So you will have to implement an ordering for your Point class, so that comparisons such as a > b work.

To do that, you need to implement some of the rich comparison special methods, which enable comparison between those objects. Since you likely want a proper order where a < b implies b > a etc., you can use functools.total_ordering to get a total order by defining just equality and a single less-than comparison:

import functools

@functools.total_ordering
class Point:
    def __init__ (self, x, y, distance):
        self.x = x
        self.y = y
        self.distance = distance

    def __lt__ (self, other):
        return (self.distance, self.x, self.y) < (other.distance, other.x, other.y)

    def __eq__ (self, other):
        return (self.distance, self.x, self.y) == (other.distance, other.x, other.y)

    def __repr__ (self):
        return 'Point({}, {}, {})'.format(self.x, self.y, self.distance)

>>> a, b, c = Point(3, 5, 2), Point(5, 4, 1), Point(8, 4, 5)
>>> min(a, b, c)
Point(5, 4, 1)
>>> max(a, b, c)
Point(8, 4, 5)
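
As a quick check of what the decorator fills in, here is a minimal sketch using the same ordering as above. The _key helper is a hypothetical name introduced only to avoid repeating the tuple; it is not part of the original answer. With only __lt__ and __eq__ defined, the remaining comparisons and sorted() should work as well:

```python
import functools

@functools.total_ordering
class Point:
    def __init__(self, x, y, distance):
        self.x = x
        self.y = y
        self.distance = distance

    def _key(self):
        # Hypothetical helper: distance first, then x and y as tie-breakers.
        return (self.distance, self.x, self.y)

    def __lt__(self, other):
        return self._key() < other._key()

    def __eq__(self, other):
        return self._key() == other._key()

    def __repr__(self):
        return 'Point({}, {}, {})'.format(self.x, self.y, self.distance)

a, b, c = Point(3, 5, 2), Point(5, 4, 1), Point(8, 4, 5)

print(a > b)              # True; __gt__ is derived by total_ordering
print(sorted([a, b, c]))  # [Point(5, 4, 1), Point(3, 5, 2), Point(8, 4, 5)]
```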

Upvotes: 2

Martijn Pieters

Reputation: 1123420

You can use a key function:

from operator import attrgetter

max(a, b, c, key=attrgetter('distance'))

The operator.attrgetter() function produces a callable that returns the named attribute for each object passed to it. You could achieve the same with a lambda callable:

max(a, b, c, key=lambda p: p.distance)
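
Put together with the class from the question, a minimal self-contained sketch looks like the following. The points list and the nearest/farthest names are only illustrative. Note that the class itself needs no changes for this approach:

```python
from operator import attrgetter

class Point:
    def __init__(self, x, y, distance):
        self.x = x
        self.y = y
        self.distance = distance

points = [Point(3, 5, 2), Point(5, 4, 1), Point(8, 4, 5)]

# Both forms of key function compare the points by their distance attribute.
nearest = min(points, key=attrgetter('distance'))
farthest = max(points, key=lambda p: p.distance)

print(nearest.x, nearest.y, nearest.distance)     # 5 4 1
print(farthest.x, farthest.y, farthest.distance)  # 8 4 5
```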

Alternatively, you can add special methods to the class to define how instances should be compared; the __eq__ method defines when two instances are equal, and methods like __lt__ and friends are used to determine how two instances order. max() will then make use of these methods to find a 'largest' item, without a key function.

With the @functools.total_ordering decorator you only have to implement two of those: __eq__ and one of the ordering methods (__lt__, __le__, __gt__ or __ge__):

from functools import total_ordering

@total_ordering
class Point():
    def __init__(self, x, y, distance):
        self.x = x
        self.y = y
        self.distance = distance

    def __eq__(self, other):
        if not isinstance(other, type(self)):
            return NotImplemented  # only the same type or subclasses
        return (self.x, self.y, self.distance) == (other.x, other.y, other.distance)

    def __lt__(self, other):
        if not isinstance(other, type(self)):
            return NotImplemented  # only the same type or subclasses
        if (self.x, self.y) == (other.x, other.y):
            return False  # same coordinates makes them equalish?
        return self.distance < other.distance

For Point objects this should probably be thought out more; what happens if distance is not equal but x and y are, for example?
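
To make that concern concrete, here is a small sketch that repeats the class above and compares two points with the same coordinates but different distances. The values are made up for illustration. Neither point is less than the other, yet they are not equal either, so the result of min() depends on argument order:

```python
from functools import total_ordering

@total_ordering
class Point():
    def __init__(self, x, y, distance):
        self.x = x
        self.y = y
        self.distance = distance

    def __eq__(self, other):
        if not isinstance(other, type(self)):
            return NotImplemented
        return (self.x, self.y, self.distance) == (other.x, other.y, other.distance)

    def __lt__(self, other):
        if not isinstance(other, type(self)):
            return NotImplemented
        if (self.x, self.y) == (other.x, other.y):
            return False
        return self.distance < other.distance

# Same coordinates, different distance (illustrative values).
a = Point(1, 1, 2)
b = Point(1, 1, 5)

print(a < b, b < a, a == b)  # False False False
print(min(a, b) is a)        # True
print(min(b, a) is b)        # True; the first argument wins either way
```

One way to avoid this is to compare the same tuple of attributes in both __eq__ and __lt__, so that the two methods can never disagree.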

Upvotes: 6
