max

Reputation: 52323

Check if all elements in a list are equal

I need a function which takes in a list and outputs True if all elements in the input list evaluate as equal to each other using the standard equality operator and False otherwise.

I feel it would be best to iterate through the list comparing adjacent elements and then AND all the resulting Boolean values. But I'm not sure what's the most Pythonic way to do that.

Upvotes: 583

Views: 663494

Answers (30)

Lajos

Reputation: 2827

I ended up with this one-liner (note that itertools.pairwise requires Python 3.10 or later):

from operator import eq
from itertools import starmap, pairwise
all(starmap(eq, pairwise(x)))
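itertools.pairwise only exists in Python 3.10+. On older versions, a minimal sketch can build the same sliding window from itertools.tee. The helper name `_pairwise` is hypothetical and only stands in for the real pairwise:

```python
from itertools import starmap, tee
from operator import eq

def _pairwise(iterable):
    # Hypothetical stand-in for itertools.pairwise: yields (a, b), (b, c), ...
    a, b = tee(iterable)
    next(b, None)  # advance the second copy by one element
    return zip(a, b)

def all_equal(iterable):
    # eq(a, b) for every adjacent pair; all() stops at the first False
    return all(starmap(eq, _pairwise(iterable)))

print(all_equal([3, 3, 3]))  # True
print(all_equal([3, 3, 4]))  # False
print(all_equal([]))         # True (no pairs to compare)
```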

Upvotes: 1

ninjagecko

Reputation: 91132

edit: Everyone seems hung up on speed/performance for some insane reason, when the original poster clearly asked about the cleanest solution.

This answer provides both

  1. a clean solution, as well as
  2. the fastest solution (faster than the currently top-voted itertools.groupby answer; see addendum at end).

Without rewriting the program, the most asymptotically performant and most readable way is as follows:

all(x==myList[0] for x in myList)

(Yes, this even works with the empty list! The generator never evaluates myList[0] when there are no elements, which is one of the few cases where Python has lazy semantics.)

This will fail at the earliest possible time, so it is asymptotically optimal (expected time is approximately O(#uniques) rather than O(N), but worst-case time still O(N)). This is assuming you have not seen the data before...
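To see the early exit, one can count how many comparisons actually run. This sketch uses a hypothetical `Noisy` wrapper class whose `__eq__` records each call; it is only an instrumentation device, not part of the answer:

```python
class Noisy:
    """Hypothetical wrapper that counts how often __eq__ is called."""
    calls = 0

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        Noisy.calls += 1
        return self.value == other.value

myList = [Noisy(v) for v in [1, 2, 1, 1, 1, 1]]
result = all(x == myList[0] for x in myList)
print(result, Noisy.calls)  # False 2: stops at the first mismatch

empty = []
print(all(x == empty[0] for x in empty))  # True: empty[0] is never evaluated
```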

(If you care about performance, but not that much, you can first do the usual standard optimizations, like hoisting the myList[0] constant out of the loop and adding clunky logic for the edge case. However, this is something the Python compiler might eventually learn to do, so one should not do it unless absolutely necessary, as it destroys readability for minimal gain.)

If you care slightly more about performance, this is twice as fast as above but a bit more verbose:

def allEqual(iterable):
    iterator = iter(iterable)
    
    try:
        firstItem = next(iterator)
    except StopIteration:
        return True
        
    for x in iterator:
        if x!=firstItem:
            return False
    return True

If you care even more about performance (but not enough to rewrite your program), use the currently top-voted itertools.groupby answer, which is twice as fast as allEqual, probably because it is optimized C code. (According to the docs, like this answer, it should not have any memory overhead, because the lazy generator is never evaluated into a list. One might be worried about that, but the pseudocode shows that the grouped 'lists' are actually lazy iterators.)

If you care even more about performance read on...


sidenotes regarding performance, because the other answers are talking about it for some unknown reason:

Heuristic tricks:

  • (If it's a sequence supporting __getitem__) If you suspect the first or last element may differ, you can check those heuristically first, and if they're the same, proceed with the above algorithm as originally intended
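The heuristic could look like this for sequences that support indexing. The function name is hypothetical; it simply checks the last element against the first before falling back to the full scan:

```python
def all_equal_ends_first(seq):
    # Hypothetical helper: a cheap O(1) check on the ends before the O(N) scan.
    if not seq:
        return True  # vacuously true for an empty sequence
    first = seq[0]
    if seq[-1] != first:
        return False  # a differing last element is caught immediately
    return all(x == first for x in seq)

print(all_equal_ends_first([7, 7, 7, 8]))  # False, found via the last element
print(all_equal_ends_first([7, 7, 7]))     # True
```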

Fastest performance:

... if you have seen the data before and are likely using a collection data structure of some sort, you can get .isAllEqual() for free O(1) by augmenting your structure with a Counter that is updated with every insert/delete/etc. operation. Then .isAllEqual() is merely seeing if the Counter is basically == to {something:someCount} i.e. len(counter)==1 (though we also need to check if it's the empty list [], for which the statement "all elements are equal" is vacuously true, so we should return len(counter)<=1). This presumes your list elements are hashable.

See https://docs.python.org/3/reference/datamodel.html#emulating-container-types

from collections import Counter

class CounterList(list):  # UNTESTED, may be missing methods (e.g. clear, +=, *=)
    def __init__(self, iterable=()):
        super().__init__(iterable)
        self.counts = Counter(self)  # count self, so a one-shot iterator isn't read twice

    def _discount(self, item):
        # decrement, deleting the key at zero so len(self.counts) stays accurate
        self.counts[item] -= 1
        if self.counts[item]==0:
            del self.counts[item]

    def __setitem__(self, target, itemOrItems):  # works with slices!
        if isinstance(target, slice):
            toRemove, toAdd = self[target], list(itemOrItems)  # any iterable, not just lists
            super().__setitem__(target, toAdd)
        else:
            toRemove, toAdd = [self[target]], [itemOrItems]
            super().__setitem__(target, itemOrItems)
        for x in toRemove:
            self._discount(x)
        for y in toAdd:
            self.counts[y] += 1

    def __delitem__(self, target):
        toRemove = self[target] if isinstance(target, slice) else [self[target]]
        super().__delitem__(target)
        for x in toRemove:
            self._discount(x)

    def pop(self, index=-1):
        R = super().pop(index)
        self._discount(R)
        return R

    def remove(self, item):
        super().remove(item)  # throws ValueError if missing
        self._discount(item)

    def insert(self, index, item):
        super().insert(index, item)
        self.counts[item] += 1

    def append(self, item):
        super().append(item)
        self.counts[item] += 1

    def extend(self, otherList):
        otherList = list(otherList)  # a generator could otherwise only be read once
        super().extend(otherList)
        for item in otherList:
            self.counts[item] += 1

    def isAllEqual(self):
        return len(self.counts)<=1

Demo:

demo = CounterList([5,6])
demo[0] = 6
demo.isAllEqual()  # outputs True

Alternatively you can keep a Counter on the side in a separate variable, and update it whenever you would modify myList. (Specifically, manually copy the code above in CounterList whenever you do an operation.)

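    from collections import Counter

    myList = [.....]
    myListCounter = Counter(myList)  # start from the existing contents
    
    
    myList.append(5)
    myListCounter[5] += 1  # extra bookkeeping
    
    myListCounter[myList[2]] -= 1  # extra bookkeeping
    if myListCounter[myList[2]] == 0:
        del myListCounter[myList[2]]  # keep len(myListCounter) accurate
    myList[2] = True
    myListCounter[True] += 1  # extra bookkeeping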

Both of these methods are equivalent, and provably asymptotically better than anything else. tl;dr You can always avoid going through the whole list to check all(x==myList[0] for x in myList); to do this, whenever you would modify myList, you also keep track with a Counter of how many of each element there are.

Of course, there's something to be said for readability.

Upvotes: 245

dlitz

Reputation: 6169

You can also use the all_equal function from the handy more_itertools package.

pip install more_itertools
>>> from more_itertools import all_equal
>>> all_equal((2, 2, 2, 1 + 1))
True
>>> all_equal((2, 2, 2, 0))
False

Upvotes: 0

Daniel Aviv

Reputation: 45

l1 = [1, 2, 3, 4, 5]
l2 = [1] * 5
l3 = []
all_equal = lambda l: len(l) > 0 and all(l[0] == e for e in l)
print(f"{all_equal(l1) = }")  # False
print(f"{all_equal(l2) = }")  # True
print(f"{all_equal(l3) = }")  # False

Upvotes: -1

user2138149

Reputation: 17230

tl;dr:

all((x==yourList[0] for x in yourList))

Which is a more Pythonic version of:

from functools import reduce
from operator import and_

reduce(and_, (x==yourList[0] for x in yourList), True)

The clever thing about this is that it uses a lazily evaluated generator expression rather than a list comprehension.

You can see this by experimenting with the following code:

print(type((x == yourList[0] for x in yourList)))
# generator

print(type([x == yourList[0] for x in yourList]))
# list

It is fairly annoying that python makes you import the operators like operator.and_. As of python3, you will need to also import functools.reduce.

(You should not use the method that uses reduce and the and_ operator, because it will not stop when it finds non-equal values but will continue examining the entire list. On the other hand, all() stops as soon as a False value is obtained.)
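The difference in how far each version reads can be observed by feeding both a generator that records which indices it evaluates. The `comparisons` generator and the `seen_*` lists are illustrative only:

```python
from functools import reduce
from operator import and_

yourList = [1, 2, 1, 1, 1]

def comparisons(seen):
    # Generator of equality results that records each index it evaluates
    for i, x in enumerate(yourList):
        seen.append(i)
        yield x == yourList[0]

seen_all, seen_reduce = [], []
print(all(comparisons(seen_all)), seen_all)  # False [0, 1]
print(reduce(and_, comparisons(seen_reduce), True), seen_reduce)  # False [0, 1, 2, 3, 4]
```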

For those reasons, the more Pythonic version is preferred:

all((x == yourList[0] for x in yourList))

Upvotes: -2

ninjagecko

Reputation: 91132

You can do:

from functools import reduce
from operator import and_

reduce(and_, (x==yourList[0] for x in yourList), True)

It is fairly annoying that python makes you import the operators like operator.and_. As of python3, you will need to also import functools.reduce.

(You should not use this method because it will not break if it finds non-equal values, but will continue examining the entire list. It is just included here as an answer for completeness.)

Upvotes: 1

Stefan Pochmann

Reputation: 28636

They're all equal if there aren't two groups.

from itertools import groupby, pairwise

def all_equal(iterable):
    return not any(pairwise(groupby(iterable)))


Upvotes: 0

user2138149

Reputation: 17230

A functional solution, which results in a clean code:

Advantages:

  • Lazy evaluated
  • No side effects; does not depend on side effects from the enclosing scope (e.g. it does not pull values like input[0] out of the enclosing scope, which is fragile for obvious reasons)
  • Only requires operator.eq to be defined for input types. Does not require __hash__
  • Functional style (no for loops!)
  • It works for edge cases, for example empty list or list with a single element
input_list = [1, 1, 1, 1, 1, 1, 1]

import operator
from itertools import pairwise
from itertools import starmap

def all_items_in_list_are_equal(l: list) -> bool:

    all_items_are_equal = \
        all(
            starmap(
                operator.eq,
                pairwise(l)
            )
        )

    return all_items_are_equal

It works by producing a pairwise iterator pairwise(l) which returns pairs of adjacent elements in the list l as a tuple.

The tuple is expanded into pairs of arguments by starmap.

  • starmap is to map what function(*args) is to function(a, b, c, ...)

operator.eq expects two arguments, so a single tuple argument will not do. We use starmap to apply operator.eq to every pair of values yielded by the iterator, expanding each tuple into two individual arguments.

Finally, all checks that every comparison result is True, stopping at the first False.

Note: This answer came from a separate question I asked about achieving the same objective as in the original question here, but in a functional style. I asked this question after reading through the entire history of answers here.

Most of the other answers have problems, for example:

  • edge cases
  • poor performance due to eager evaluation
  • irregular dependencies, in other words:
  • fragile code which is likely to break due to different lines of code depending on each other in weird ways
  • the use of for loops, which of course are evil

There is another solution which avoids the overhead of starmap and pairwise, at the cost of building two slices (copies) of the list.

The problem with this is it involves index manipulation, which is an easy way to produce unmaintainable code and edge-case bugs. (e.g. will this work with a list of length 0?)

It also isn't especially readable because it forces you to think about index operations.

input_list = [1, 1, 1, 1, 1, 1, 1]

import operator

def all_items_in_list_are_equal(l: list) -> bool:

    all_items_are_equal = \
        all(
            map(
                operator.eq,
                l[1:],
                l[:-1]
            )
        )

    return all_items_are_equal

If you were wondering if building a custom iterator provides a good solution, the answer is unfortunately "no" because the resulting code is pretty slow. (Even taking into account an optimization which removes the need for a branch statement in the __next__ function.)

import operator

class EqualPairwiseIterator():

    def __init__(self, iterable):
        self.iterable = iterable
        self.iterator = iter(iterable)
        self.previous_value = None
        if len(iterable) < 1:
            raise ValueError('EqualPairwiseIterator does not work with lists of length 0')

    def __iter__(self):
        try:
            self.previous_value = next(self.iterator) # will raise StopIteration when out of range
        except StopIteration as stop_exception:
            raise ValueError('EqualPairwiseIterator does not work with lists of length 0') from stop_exception
        return self

    def __next__(self):
        
        next_value = next(self.iterator) # StopIteration here ends the iteration
        op_value = operator.eq(next_value, self.previous_value)
        self.previous_value = next_value
        return op_value 

# usage:
iterable = [...]
all(EqualPairwiseIterator(iterable))
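For comparison, a generator function expresses the same idea without a class or a manual __next__. This is a sketch; `equal_pairs` is a hypothetical name, and unlike the class above it accepts empty input (there are no pairs, so all() returns True):

```python
import operator

def equal_pairs(iterable):
    """Yield operator.eq(current, previous) for each adjacent pair."""
    iterator = iter(iterable)
    try:
        previous = next(iterator)
    except StopIteration:
        return  # empty input: nothing to yield
    for current in iterator:
        yield operator.eq(current, previous)
        previous = current

print(all(equal_pairs([1, 1, 1])))  # True
print(all(equal_pairs([1, 2, 1])))  # False
print(all(equal_pairs([])))         # True
```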

Upvotes: -2

6502

Reputation: 114579

This is another option, faster than len(set(x))==1 for long lists (uses short circuit)

def constantList(x):
    return x[:1]*len(x) == x
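Because it relies only on slicing, repetition and ==, the same one-liner also works on tuples and strings, and an empty input compares an empty sequence with itself, i.e. True:

```python
def constantList(x):
    # x[:1] is empty for an empty sequence, so the product is also empty
    return x[:1] * len(x) == x

print(constantList([4, 4, 4]))  # True
print(constantList((4, 5)))     # False
print(constantList("zzz"))      # True
print(constantList([]))         # True
```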

Upvotes: 12

Robert Rossney

Reputation: 96850

I'd do:

all((x[i] == x[i+1] for i in range(0, len(x)-1)))

as all stops searching the iterable as soon as it finds a False condition.
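An equivalent that avoids manual index arithmetic pairs each element with its successor via zip; note that x[1:] makes one shallow copy of the list. The function name is hypothetical:

```python
def all_adjacent_equal(x):
    # zip stops at the shorter input, giving (x[0], x[1]), (x[1], x[2]), ...
    return all(a == b for a, b in zip(x, x[1:]))

print(all_adjacent_equal([5, 5, 5]))  # True
print(all_adjacent_equal([5, 6, 5]))  # False
print(all_adjacent_equal([]))         # True
```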

Upvotes: 2

Luca Di Liello

Reputation: 1643

I suggest a simple pythonic solution:

from typing import Iterable

def all_equal_in_iterable(iterable: Iterable) -> bool:
    iterable = list(iterable)
    if not iterable:
        return True
    return all(item == iterable[0] for item in iterable)

Upvotes: 0

Stefan Pochmann

Reputation: 28636

More versions using itertools.groupby that I find clearer than the original (more about that below):

def all_equal(iterable):
    g = groupby(iterable)
    return not any(g) or not any(g)

def all_equal(iterable):
    g = groupby(iterable)
    next(g, None)
    return not next(g, False)

def all_equal(iterable):
    g = groupby(iterable)
    return not next(g, False) or not next(g, False)

Here's the original from the Itertools Recipes again:

def all_equal(iterable):
    g = groupby(iterable)
    return next(g, True) and not next(g, False)

Note that the next(g, True) is always true (it's either a non-empty tuple or True). That means its value doesn't matter. It's executed purely for advancing the groupby iterator. But including it in the return expression leads the reader into thinking that its value gets used there. Since it doesn't, I find that misleading and unnecessarily complicated. My second version above treats the next(g, True) as what it's actually used for, as a statement whose value we don't care about.

My third version goes a different direction and does use the value of the first next(g, False). If there isn't even a first group at all (i.e., if the given iterable is "empty"), then that solution returns the result right away and doesn't even check whether there's a second group.

My first solution is basically the same as my third, just using any. Both solutions read as "All elements are equal iff ... there is no first group or there is no second group."
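To make the "first group / second group" reading concrete, here is what each next() call returns for a few inputs (illustrative only; groups are shown by their key):

```python
from itertools import groupby

def all_equal(iterable):
    g = groupby(iterable)
    return not next(g, False) or not next(g, False)

g = groupby([7, 7, 8])
first = next(g, False)
second = next(g, False)
print(first[0], second[0])  # 7 8: two groups, so not all equal

print(next(groupby([]), False))  # False: no first group at all
print(all_equal([7, 7, 8]), all_equal([7, 7]), all_equal([]))  # False True True
```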

Benchmark results (although speed is really not my point here, clarity is, and in practice if there are many equal values, most of the time might be spent by the groupby itself, reducing the impact of these differences here):

Python 3.10.4 on my Windows laptop:

iterable = ()
 914 ns   914 ns   916 ns  use_first_any
 917 ns   925 ns   925 ns  use_first_next
1074 ns  1075 ns  1075 ns  next_as_statement
1081 ns  1083 ns  1084 ns  original

iterable = (1,)
1290 ns  1290 ns  1291 ns  next_as_statement
1303 ns  1307 ns  1307 ns  use_first_next
1306 ns  1307 ns  1309 ns  use_first_any
1318 ns  1319 ns  1320 ns  original

iterable = (1, 2)
1463 ns  1464 ns  1467 ns  use_first_any
1463 ns  1463 ns  1467 ns  next_as_statement
1477 ns  1479 ns  1481 ns  use_first_next
1487 ns  1489 ns  1492 ns  original

Python 3.10.4 on a Debian Google Compute Engine instance:

iterable = ()
 234 ns   234 ns   234 ns  use_first_any
 234 ns   235 ns   235 ns  use_first_next
 264 ns   264 ns   264 ns  next_as_statement
 265 ns   265 ns   265 ns  original

iterable = (1,)
 308 ns   308 ns   308 ns  next_as_statement
 315 ns   315 ns   315 ns  original
 316 ns   316 ns   317 ns  use_first_any
 317 ns   317 ns   317 ns  use_first_next

iterable = (1, 2)
 361 ns   361 ns   361 ns  next_as_statement
 367 ns   367 ns   367 ns  original
 384 ns   385 ns   385 ns  use_first_next
 386 ns   387 ns   387 ns  use_first_any

Benchmark code:

from timeit import timeit
from random import shuffle
from bisect import insort
from itertools import groupby

def original(iterable):
    g = groupby(iterable)
    return next(g, True) and not next(g, False)

def use_first_any(iterable):
    g = groupby(iterable)
    return not any(g) or not any(g)

def next_as_statement(iterable):
    g = groupby(iterable)
    next(g, None)
    return not next(g, False)

def use_first_next(iterable):
    g = groupby(iterable)
    return not next(g, False) or not next(g, False)

funcs = [original, use_first_any, next_as_statement, use_first_next]

for iterable in (), (1,), (1, 2):
    print(f'{iterable = }')
    times = {func: [] for func in funcs}
    for _ in range(1000):
        shuffle(funcs)
        for func in funcs:
            number = 1000
            t = timeit(lambda: func(iterable), number=number) / number
            insort(times[func], t)
    for func in sorted(funcs, key=times.get):
        print(*('%4d ns ' % round(t * 1e9) for t in times[func][:3]), func.__name__)
    print()

Upvotes: 1

Raymond Hettinger

Reputation: 226574

Best Answer

There was a nice Twitter thread on the various ways to implement an all_equal() function.

Given a list input, the best submission was:

 t.count(t[0]) == len(t)  

Other Approaches

Here are the results from the thread:

  1. Have groupby() compare adjacent entries. This has an early-out for a mismatch, does not use extra memory, and it runs at C speed.

    g = itertools.groupby(s)
    next(g, True) and not next(g, False)
    
  2. Compare two slices offset from one another by one position. This uses extra memory but runs at C speed.

    s[1:] == s[:-1]
    
  3. Iterator version of slice comparison. It runs at C speed and does not use extra memory; however, the eq calls are expensive.

    all(map(operator.eq, s, itertools.islice(s, 1, None)))
    
  4. Compare the lowest and highest values. This runs at C speed, doesn't use extra memory, but does cost two inequality tests per datum.

    min(s) == max(s)  # s must be non-empty
    
  5. Build a set. This runs at C speed and uses little extra memory but requires hashability and does not have an early-out.

    len(set(t))==1
    
  6. At great cost, this handles NaNs and other objects with exotic equality relations.

    all(itertools.starmap(eq, itertools.product(s, repeat=2)))
    
  7. Pull out the first element and compare all the others to it, stopping at the first mismatch. Only disadvantage is that this doesn't run at C speed.

     it = iter(s)
     a = next(it, None)
     return all(a == b for b in it)
    
  8. Just count the first element. This is fast, simple, elegant. It runs at C speed, requires no additional memory, uses only equality tests, and makes only a single pass over the data.

      t.count(t[0]) == len(t)
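Several of the snippets above are fragments (approach 7 has a bare return, and approach 8 raises IndexError on an empty sequence). A small harness, with hypothetical function names, wraps a few of them and checks that they agree on some sample inputs:

```python
import itertools
import operator

def via_groupby(s):
    g = itertools.groupby(s)
    return bool(next(g, True) and not next(g, False))

def via_slices(s):
    return s[1:] == s[:-1]

def via_islice(s):
    return all(map(operator.eq, s, itertools.islice(s, 1, None)))

def via_first(s):
    it = iter(s)
    a = next(it, None)
    return all(a == b for b in it)

def via_count(s):
    return not s or s.count(s[0]) == len(s)  # guard added for the empty case

samples = [[], [1], [1, 1, 1], [1, 2], [2, 1, 1], ["a", "a"]]
for s in samples:
    answers = {f(s) for f in (via_groupby, via_slices, via_islice, via_first, via_count)}
    print(s, answers)  # each set should contain a single value
```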
    

Upvotes: 4

mykhal

Reputation: 19924

Here is code with a good amount of Pythonicity and, I think, a balance of simplicity and obviousness. It should also work in pretty old Python versions.

def all_eq(lst):
    for idx, itm in enumerate(lst):
        if not idx:   # == 0
            prev = itm
        if itm != prev:
            return False
        prev = itm
    return True

Upvotes: -2

Kermit

Reputation: 6022

Maybe I'm underestimating the problem? Check the length of unique values in the list.

lzt = [1,1,1,1,1,2]

if (len(set(lzt)) > 1):
    uniform = False
elif (len(set(lzt)) == 1):
    uniform = True
elif (not lzt):
    raise ValueError("List empty, get wrecked")

Upvotes: -1

Ivo van der Wijk

Reputation: 16785

A solution faster than using set() that works on sequences (not arbitrary iterables) is to simply count the first element. This assumes the list is non-empty (but that's trivial to check; decide for yourself what the outcome should be for an empty list).

x.count(x[0]) == len(x)

some simple benchmarks:

>>> timeit.timeit('len(set(s1))<=1', 's1=[1]*5000', number=10000)
1.4383411407470703
>>> timeit.timeit('len(set(s1))<=1', 's1=[1]*4999+[2]', number=10000)
1.4765670299530029
>>> timeit.timeit('s1.count(s1[0])==len(s1)', 's1=[1]*5000', number=10000)
0.26274609565734863
>>> timeit.timeit('s1.count(s1[0])==len(s1)', 's1=[1]*4999+[2]', number=10000)
0.25654196739196777

Upvotes: 398

cbalawat

Reputation: 1131

Convert your input into a set:

len(set(the_list)) <= 1

Using set removes all duplicate elements. <= 1 is so that it correctly returns True when the input is empty.

This requires that all the elements in your input are hashable. You'll get a TypeError if you pass in a list of lists for example.
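If the input might contain unhashable items, one option (a sketch, not part of the answer above; the helper name is hypothetical) is to try the set first and fall back to a comparison scan on TypeError:

```python
def all_equal_set_or_scan(items):
    """Hypothetical helper: set-based check with a fallback for unhashable items."""
    items = list(items)  # materialize so both strategies can read the data
    try:
        return len(set(items)) <= 1
    except TypeError:  # e.g. a list of lists
        return all(x == items[0] for x in items)

print(all_equal_set_or_scan([3, 3]))      # True
print(all_equal_set_or_scan([[1], [1]]))  # True, via the fallback
print(all_equal_set_or_scan([[1], [2]]))  # False, via the fallback
print(all_equal_set_or_scan([]))          # True
```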

Upvotes: 75

kennytm

Reputation: 523584

Use itertools.groupby (see the itertools recipes):

from itertools import groupby

def all_equal(iterable):
    g = groupby(iterable)
    return next(g, True) and not next(g, False)

or without groupby:

def all_equal(iterator):
    iterator = iter(iterator)
    try:
        first = next(iterator)
    except StopIteration:
        return True
    return all(first == x for x in iterator)

There are a number of alternative one-liners you might consider:

  1. Converting the input to a set and checking that it only has one or zero (in case the input is empty) items

    def all_equal2(iterator):
        return len(set(iterator)) <= 1
    
  2. Comparing against the input list without the first item

    def all_equal3(lst):
        return lst[:-1] == lst[1:]
    
  3. Counting how many times the first item appears in the list

    def all_equal_ivo(lst):
        return not lst or lst.count(lst[0]) == len(lst)
    
  4. Comparing against a list of the first element repeated

    def all_equal_6502(lst):
        return not lst or [lst[0]]*len(lst) == lst
    

But they have some downsides, namely:

  1. all_equal and all_equal2 can use any iterators, but the others must take a sequence input, typically concrete containers like a list or tuple.
  2. all_equal and all_equal3 stop as soon as a difference is found (what is called "short circuit"), whereas all the alternatives require iterating over the entire list, even if you can tell that the answer is False just by looking at the first two elements.
  3. In all_equal2 the content must be hashable. A list of lists will raise a TypeError for example.
  4. all_equal2 (in the worst case) and all_equal_6502 create a copy of the list, meaning you need to use double the memory.
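The short-circuit point can be demonstrated with an infinite iterator: the groupby version returns after seeing two distinct values, whereas set() on the same iterator would never finish. itertools.count is used here purely as an endless input:

```python
from itertools import count, groupby, repeat

def all_equal(iterable):
    g = groupby(iterable)
    return next(g, True) and not next(g, False)

print(all_equal(count()))         # False: 0 and 1 already form two groups
print(all_equal(repeat(5, 1000)))  # True: one group, read to the end
```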

On Python 3.9, using perfplot, we get these timings (lower Runtime [s] is better):

[Figure: perfplot timings. For a list with a difference in the first two elements, groupby is fastest; for a list with no differences, count(l[0]) is fastest.]

Upvotes: 633

codaddict

Reputation: 455342

You can convert the list to a set. A set cannot have duplicates. So if all the elements in the original list are identical, the set will have just one element.

if len(set(input_list)) == 1:
    # input_list has all identical elements.
    pass

Upvotes: 50

Foad S. Farimani

Reputation: 14016

There is also a pure Python recursive option:

def checkEqual(lst):
    if len(lst) < 2:  # empty and single-element lists are trivially all equal
        return True
    return lst[0]==lst[1] and checkEqual(lst[1:])

However, for some reason it is in some cases two orders of magnitude slower than the other options. Coming from a C mindset, I expected this to be faster, but it is not!

The other disadvantage is that Python has a recursion limit, which needs to be raised (e.g. with sys.setrecursionlimit) for long lists.
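A large part of the slowdown likely comes from lst[1:] copying the rest of the list on every call, which makes the total work O(n²). Passing an index instead avoids the copies. This is a sketch with a hypothetical name; the recursion-limit issue remains for long lists:

```python
def check_equal(lst, i=0):
    # Compare lst[i] with lst[i + 1] without slicing; empty and 1-element lists are True
    if i + 1 >= len(lst):
        return True
    return lst[i] == lst[i + 1] and check_equal(lst, i + 1)

print(check_equal([2, 2, 2]))  # True
print(check_equal([2, 3, 2]))  # False
print(check_equal([]))         # True
```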

Upvotes: -1

Marcus Lind

Reputation: 11460

Regarding using reduce() with a lambda, here is working code that I personally think is way nicer than some of the other answers.

from functools import reduce

lst = [2, 2, 2]  # must be non-empty
reduce(lambda x, y: (x[0] and x[1]==y, y), lst[1:], (True, lst[0]))

Returns a tuple whose first value is a boolean indicating whether all items are the same.

Upvotes: 1

Saeed

Reputation: 2099

You can use .nunique() to find the number of unique items in a list.

import pandas as pd

def identical_elements(lst):
    series = pd.Series(lst)
    return series.nunique() == 1



identical_elements(['a', 'a'])
Out[427]: True

identical_elements(['a', 'b'])
Out[428]: False

Upvotes: -1

Luis B

Reputation: 1722

Or use numpy's unique function:

import numpy as np
def allthesame(l):
    return np.unique(l).shape[0]<=1

And to call:

print(allthesame([1,1,1]))

Output:

True

Upvotes: 0

U13-Forward

Reputation: 71610

Or use diff method of numpy:

import numpy as np
def allthesame(l):
    return np.all(np.diff(l)==0)

And to call:

print(allthesame([1,1,1]))

Output:

True

Upvotes: 1

SuperNova

Reputation: 27486

You can use map and a lambda:

lst = [1,1,1,1,1,1,1,1,1]

print(all(map(lambda x: x == lst[0], lst[1:])))

Upvotes: 0

Gusev Slava

Reputation: 2204

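Check if all elements are equal to the first, within floating-point tolerance (np.allclose compares approximately, not exactly):

import numpy as np

np.allclose(array, array[0])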

Upvotes: 5

mgilson

Reputation: 310099

For what it's worth, this came up on the python-ideas mailing list recently. It turns out that there is an itertools recipe for doing this already:[1]

from itertools import groupby

def all_equal(iterable):
    "Returns True if all the elements are equal to each other"
    g = groupby(iterable)
    return next(g, True) and not next(g, False)

Supposedly it performs very nicely and has a few nice properties.

  1. Short-circuits: It will stop consuming items from the iterable as soon as it finds the first non-equal item.
  2. Doesn't require items to be hashable.
  3. It is lazy and only requires O(1) additional memory to do the check.

[1] In other words, I can't take the credit for coming up with the solution -- nor can I take credit for even finding it.
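Property 1 can be checked directly by passing an iterator and looking at what is left afterwards. groupby has to read one element of the second group to know that group exists, so for this input the check consumes 1 and 2 and leaves 3 and 4 untouched:

```python
from itertools import groupby

def all_equal(iterable):
    "Returns True if all the elements are equal to each other"
    g = groupby(iterable)
    return next(g, True) and not next(g, False)

it = iter([1, 2, 3, 4])
print(all_equal(it))  # False
print(list(it))       # [3, 4]: the remaining items were never consumed
```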

Upvotes: 25

user3015260

Reputation: 9

from functools import reduce

lambda lst: reduce(lambda a,b:(b,b==a[0] and a[1]), lst, (lst[0], True))[1]

The next one will short-circuit:

all(map(lambda i:yourlist[i]==yourlist[i+1], range(len(yourlist)-1)))

Upvotes: 0

Joshua Burns

Reputation: 8582

If you're interested in something a little more readable (but of course not as efficient), you could try:

def compare_lists(list1, list2):
    if len(list1) != len(list2): # Weed out unequal length lists.
        return False
    for item in list1:
        if item not in list2:
            return False
    return True

a_list_1 = ['apple', 'orange', 'grape', 'pear']
a_list_2 = ['pear', 'orange', 'grape', 'apple']

b_list_1 = ['apple', 'orange', 'grape', 'pear']
b_list_2 = ['apple', 'orange', 'banana', 'pear']

c_list_1 = ['apple', 'orange', 'grape']
c_list_2 = ['grape', 'orange']

print(compare_lists(a_list_1, a_list_2)) # Returns True
print(compare_lists(b_list_1, b_list_2)) # Returns False
print(compare_lists(c_list_1, c_list_2)) # Returns False

Upvotes: -1

itertool

Reputation: 29

import itertools

def allTheSame(i):
    j = itertools.groupby(i)
    for k in j: break
    for k in j: return False
    return True

Works in Python 2.4, which doesn't have "all".

Upvotes: 2
