snakile

Reputation: 54521

assertAlmostEqual in Python unit-test for collections of floats

The assertAlmostEqual(x, y) method in Python's unit testing framework tests whether x and y are approximately equal assuming they are floats.

The problem with assertAlmostEqual() is that it only works on floats. I'm looking for a method like assertAlmostEqual() which works on lists of floats, sets of floats, dictionaries of floats, tuples of floats, lists of tuples of floats, sets of lists of floats, etc.

For instance, let x = 0.1234567890, y = 0.1234567891. x and y are almost equal because they agree on each and every digit except for the last one. Therefore self.assertAlmostEqual(x, y) passes, because assertAlmostEqual() works for floats.
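
A minimal sketch of why the scalar call passes while a list does not (the class and test names are illustrative). By default assertAlmostEqual() rounds the difference to 7 decimal places and compares it to zero. For two lists that are not exactly equal it tries to subtract them, which raises TypeError.

```python
import unittest

x, y = 0.1234567890, 0.1234567891

class ScalarVsListDemo(unittest.TestCase):
    def test_scalars_pass(self):
        # round(abs(x - y), 7) == 0, so the default 7-place check passes
        self.assertAlmostEqual(x, y)

    def test_lists_are_not_supported(self):
        # The lists are not ==, so unittest computes abs([x, x] - [y, y]),
        # and list subtraction raises TypeError
        with self.assertRaises(TypeError):
            self.assertAlmostEqual([x, x], [y, y])
```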

I'm looking for a more generic version of assertAlmostEqual() under which the following calls also pass:

Is there such a method or do I have to implement it myself?

Clarifications:

Upvotes: 116

Views: 112318

Answers (12)

Neil

Reputation: 49

Looking at this myself, I used the addTypeEqualityFunc method of unittest.TestCase in combination with math.isclose.

Sample setup:

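import math
from unittest import TestCase

class SomeFixtures(TestCase):
    @classmethod
    def float_comparer(cls, a, b, msg=None):
        # Called by assertEqual whenever both arguments are lists
        if len(a) != len(b):
            raise cls.failureException(msg or "lists have different lengths")
        if not all(math.isclose(p, q) for p, q in zip(a, b)):
            raise cls.failureException(msg or f"{a!r} != {b!r} within math.isclose tolerance")

    def test_float_lists(self):
        # unittest only collects methods whose names start with "test"
        self.addTypeEqualityFunc(list, self.float_comparer)
        self.assertEqual([1.0, 2.0, 3.0], [1.0, 2.0, 3.0000000001])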

Upvotes: 3

Ewoud

Reputation: 131

You can also add a method to your TestCase subclass that recursively calls the existing assertAlmostEqual() and keeps track of which element is being compared.

E.g. for lists of lists and lists of tuples of floats:

def assertListAlmostEqual(self, first, second, delta=None, context=None):
    """Asserts lists of lists or tuples to check if they compare and 
       shows which element is wrong when comparing two lists
    """
    if context is None:
        # Top-level call: remember the full structures and start an empty index path
        context = (first, second, [])
    top_first, top_second, path = context
    self.assertEqual(len(first), len(second), msg="Lists have different lengths")
    for i, (a, b) in enumerate(zip(first, second)):
        if isinstance(a, (list, tuple)):
            # Recurse into nested lists/tuples, extending a copy of the index path
            self.assertListAlmostEqual(a, b, delta, context=(top_first, top_second, path + [i]))
        else:
            msg = "Difference in \n{} and \n{}\nFaulty element index={}".format(
                top_first, top_second, path + [i])
            # Pass delta by keyword: the third positional parameter is `places`
            self.assertAlmostEqual(a, b, delta=delta, msg=msg)

Outputs something like:

line 23, in assertListAlmostEqual
    self.assertAlmostEqual(first[i], second[i], delta, msg=msg)
AssertionError: 5.0 != 6.0 within 7 places (1.0 difference) : Difference in 
[(0.0, 5.0), (8.0, 2.0), (10.0, 1.999999), (11.0, 1.9999989090909092)] and 
[(0.0, 6.0), (8.0, 2.0), (10.0, 1.999999), (11.0, 1.9999989)]
Faulty element index=[0, 1]

Upvotes: 0

gnoodle

Reputation: 178

Use Pandas

Another way is to convert each of the two objects (dicts etc.) into pandas DataFrames and then use pd.testing.assert_frame_equal() to compare them. I have used this successfully to compare lists of dicts.

Previous answers often don't work on structures involving dictionaries, but this one should. I haven't exhaustively tested this on highly nested structures, but I imagine pandas would handle them correctly.

Example 1: compare two dicts

To illustrate this I will use your example data of a dict, since the other methods don't work with dicts. Your dict was:

x, y = 0.1234567890, 0.1234567891
{1: x, 2: x, 3: x}, {1: y, 2: y, 3: y}

Then we can do:

import pandas as pd

pd.testing.assert_frame_equal(
    pd.DataFrame.from_dict({1: x, 2: x, 3: x}, orient='index'),
    pd.DataFrame.from_dict({1: y, 2: y, 3: y}, orient='index'))

This doesn't raise an error, meaning that they are equal within the default tolerance (rtol=1e-5, atol=1e-8).

However if we were to do

pd.testing.assert_frame_equal(
    pd.DataFrame.from_dict({1: x, 2: x, 3: x}, orient='index'),
    pd.DataFrame.from_dict({1: y, 2: y, 3: y + 1}, orient='index'))  # add 1 to last value

then we are rewarded with the following informative message:

AssertionError: DataFrame.iloc[:, 0] (column name="0") are different

DataFrame.iloc[:, 0] (column name="0") values are different (33.33333 %)
[index]: [1, 2, 3]
[left]:  [0.123456789, 0.123456789, 0.123456789]
[right]: [0.1234567891, 0.1234567891, 1.1234567891]

For further details see the pd.testing.assert_frame_equal documentation, particularly the parameters check_exact, rtol and atol, which specify the required precision as either a relative or an absolute tolerance.

Example 2: Nested dict of dicts

a = {i*10 : {1:1.1,2:2.1} for i in range(4)}
b = {i*10 : {1:1.1000001,2:2.100001} for i in range(4)}
# a = {0: {1: 1.1, 2: 2.1}, 10: {1: 1.1, 2: 2.1}, 20: {1: 1.1, 2: 2.1}, 30: {1: 1.1, 2: 2.1}}
# b = {0: {1: 1.1000001, 2: 2.100001}, 10: {1: 1.1000001, 2: 2.100001}, 20: {1: 1.1000001, 2: 2.100001}, 30: {1: 1.1000001, 2: 2.100001}}

and then do

pd.testing.assert_frame_equal(pd.DataFrame(a), pd.DataFrame(b))

This doesn't raise an error: all values are within tolerance. However, if we change a value, e.g.

b[30][2] += 1
#  b = {0: {1: 1.1000001, 2: 2.100001}, 10: {1: 1.1000001, 2: 2.100001}, 20: {1: 1.1000001, 2: 2.100001}, 30: {1: 1.1000001, 2: 3.100001}}

and then run the same test, we get the following clear error message:

AssertionError: DataFrame.iloc[:, 3] (column name="30") are different

DataFrame.iloc[:, 3] (column name="30") values are different (50.0 %)
[index]: [1, 2]
[left]:  [1.1, 2.1]
[right]: [1.1000001, 3.100001]

Upvotes: 3

Barney Szabolcs

Reputation: 12514

I would still use self.assertEqual(), because it gives the most informative failure message when things go wrong. You can do that by rounding, e.g.

self.assertEqual(round_tuple((13.949999999999999, 1.121212), 2), (13.95, 1.12))

where round_tuple is

def round_tuple(t: tuple, ndigits: int) -> tuple:
    return tuple(round(e, ndigits=ndigits) for e in t)

def round_list(l: list, ndigits: int) -> list:
    return [round(e, ndigits=ndigits) for e in l]

According to the Python docs (see https://stackoverflow.com/a/41407651/1031191) you can get away with representation issues like 13.949999999999999, because 13.949999999999999 == 13.95 is True: both literals denote the same double.
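
A sketch generalizing round_tuple/round_list to nested containers (round_nested is a hypothetical name, and only dict values are rounded, not keys). It also shows a caveat of the rounding approach: two values can be very close yet fall on opposite sides of a rounding boundary, so the rounded results differ.

```python
def round_nested(obj, ndigits):
    """Hypothetical helper: round every float inside nested lists/tuples/sets/dicts."""
    if isinstance(obj, float):
        return round(obj, ndigits)
    if type(obj) in (list, tuple, set):
        # Rebuild the same plain container type with rounded elements
        return type(obj)(round_nested(e, ndigits) for e in obj)
    if isinstance(obj, dict):
        # Keys are left untouched; only values are rounded
        return {k: round_nested(v, ndigits) for k, v in obj.items()}
    return obj

print(round_nested((13.949999999999999, [1.121212, {'a': 2.5551}]), 2))
# (13.95, [1.12, {'a': 2.56}])

# Caveat: these two values are only 2e-6 apart but round differently
print(round(1.004999, 2), round(1.005001, 2))  # 1.0 1.01
```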

Upvotes: 0

redsk

Reputation: 271

None of these answers work for me. The following code should work for python collections, classes, dataclasses, and namedtuples. I might have forgotten something, but so far this works for me.

import unittest
from collections import namedtuple, OrderedDict
from dataclasses import dataclass
from typing import Any


def are_almost_equal(o1: Any, o2: Any, max_abs_ratio_diff: float, max_abs_diff: float) -> bool:
    """
    Compares two objects by recursively walking through them. Equality is as usual except for floats.
    Floats are compared according to the two measures defined below.

    :param o1: The first object.
    :param o2: The second object.
    :param max_abs_ratio_diff: The maximum allowed absolute value of the ratio difference
    `abs(1 - (o1 / o2))` (or `abs(1 - (o2 / o1))` if o2 == 0.0). Ignored if <= 0.
    :param max_abs_diff: The maximum allowed absolute difference `abs(o1 - o2)`. Ignored if <= 0.
    :return: Whether the two objects are almost equal.
    """
    if type(o1) != type(o2):
        return False

    composite_type_passed = False

    if hasattr(o1, '__slots__'):
        if len(o1.__slots__) != len(o2.__slots__):
            return False
        if any(not are_almost_equal(getattr(o1, s1), getattr(o2, s2),
                                    max_abs_ratio_diff, max_abs_diff)
            for s1, s2 in zip(sorted(o1.__slots__), sorted(o2.__slots__))):
            return False
        else:
            composite_type_passed = True

    if hasattr(o1, '__dict__'):
        if len(o1.__dict__) != len(o2.__dict__):
            return False
        if any(not are_almost_equal(k1, k2, max_abs_ratio_diff, max_abs_diff)
            or not are_almost_equal(v1, v2, max_abs_ratio_diff, max_abs_diff)
            for ((k1, v1), (k2, v2))
            in zip(sorted(o1.__dict__.items()), sorted(o2.__dict__.items()))
            if not k1.startswith('__')):  # avoid infinite loops
            return False
        else:
            composite_type_passed = True

    if isinstance(o1, dict):
        if len(o1) != len(o2):
            return False
        if any(not are_almost_equal(k1, k2, max_abs_ratio_diff, max_abs_diff)
            or not are_almost_equal(v1, v2, max_abs_ratio_diff, max_abs_diff)
            for ((k1, v1), (k2, v2)) in zip(sorted(o1.items()), sorted(o2.items()))):
            return False

    elif any(issubclass(o1.__class__, c) for c in (list, tuple, set)):
        if len(o1) != len(o2):
            return False
        if any(not are_almost_equal(v1, v2, max_abs_ratio_diff, max_abs_diff)
            for v1, v2 in zip(o1, o2)):
            return False

    elif isinstance(o1, float):
        if o1 == o2:
            return True
        else:
            if max_abs_ratio_diff > 0:  # if max_abs_ratio_diff < 0, max_abs_ratio_diff is ignored
                if o2 != 0:
                    if abs(1.0 - (o1 / o2)) > max_abs_ratio_diff:
                        return False
                else:  # if both == 0, we already returned True
                    if abs(1.0 - (o2 / o1)) > max_abs_ratio_diff:
                        return False
            if 0 < max_abs_diff < abs(o1 - o2):  # if max_abs_diff < 0, max_abs_diff is ignored
                return False
            return True

    else:
        if not composite_type_passed:
            return o1 == o2

    return True


class EqualityTest(unittest.TestCase):

    def test_floats(self) -> None:
        o1 = ('hi', 3, 3.4)
        o2 = ('hi', 3, 3.400001)
        self.assertTrue(are_almost_equal(o1, o2, 0.0001, 0.0001))
        self.assertFalse(are_almost_equal(o1, o2, 0.00000001, 0.00000001))

    def test_ratio_only(self):
        o1 = ['hey', 10000, 123.12]
        o2 = ['hey', 10000, 123.80]
        self.assertTrue(are_almost_equal(o1, o2, 0.01, -1))
        self.assertFalse(are_almost_equal(o1, o2, 0.001, -1))

    def test_diff_only(self):
        o1 = ['hey', 10000, 1234567890.12]
        o2 = ['hey', 10000, 1234567890.80]
        self.assertTrue(are_almost_equal(o1, o2, -1, 1))
        self.assertFalse(are_almost_equal(o1, o2, -1, 0.1))

    def test_both_ignored(self):
        o1 = ['hey', 10000, 1234567890.12]
        o2 = ['hey', 10000, 0.80]
        o3 = ['hi', 10000, 0.80]
        self.assertTrue(are_almost_equal(o1, o2, -1, -1))
        self.assertFalse(are_almost_equal(o1, o3, -1, -1))

    def test_different_lengths(self):
        o1 = ['hey', 1234567890.12, 10000]
        o2 = ['hey', 1234567890.80]
        self.assertFalse(are_almost_equal(o1, o2, 1, 1))

    def test_classes(self):
        class A:
            d = 12.3

            def __init__(self, a, b, c):
                self.a = a
                self.b = b
                self.c = c

        o1 = A(2.34, 'str', {1: 'hey', 345.23: [123, 'hi', 890.12]})
        o2 = A(2.34, 'str', {1: 'hey', 345.231: [123, 'hi', 890.121]})
        self.assertTrue(are_almost_equal(o1, o2, 0.1, 0.1))
        self.assertFalse(are_almost_equal(o1, o2, 0.0001, 0.0001))

        o2.hello = 'hello'
        self.assertFalse(are_almost_equal(o1, o2, -1, -1))

    def test_namedtuples(self):
        B = namedtuple('B', ['x', 'y'])
        o1 = B(3.3, 4.4)
        o2 = B(3.4, 4.5)
        self.assertTrue(are_almost_equal(o1, o2, 0.2, 0.2))
        self.assertFalse(are_almost_equal(o1, o2, 0.001, 0.001))

    def test_classes_with_slots(self):
        class C(object):
            __slots__ = ['a', 'b']

            def __init__(self, a, b):
                self.a = a
                self.b = b

        o1 = C(3.3, 4.4)
        o2 = C(3.4, 4.5)
        self.assertTrue(are_almost_equal(o1, o2, 0.3, 0.3))
        self.assertFalse(are_almost_equal(o1, o2, -1, 0.01))

    def test_dataclasses(self):
        @dataclass
        class D:
            s: str
            i: int
            f: float

        @dataclass
        class E:
            f2: float
            f4: str
            d: D

        o1 = E(12.3, 'hi', D('hello', 34, 20.01))
        o2 = E(12.1, 'hi', D('hello', 34, 20.0))
        self.assertTrue(are_almost_equal(o1, o2, -1, 0.4))
        self.assertFalse(are_almost_equal(o1, o2, -1, 0.001))

        o3 = E(12.1, 'hi', D('ciao', 34, 20.0))
        self.assertFalse(are_almost_equal(o2, o3, -1, -1))

    def test_ordereddict(self):
        o1 = OrderedDict({1: 'hey', 345.23: [123, 'hi', 890.12]})
        o2 = OrderedDict({1: 'hey', 345.23: [123, 'hi', 890.0]})
        self.assertTrue(are_almost_equal(o1, o2, 0.01, -1))
        self.assertFalse(are_almost_equal(o1, o2, 0.0001, -1))

Upvotes: 2

Karl Rosaen

Reputation: 4644

An alternative approach is to convert your data into a comparable form, e.g. by turning each float into a string with fixed precision.

def comparable(data):
    """Converts `data` to a comparable structure by converting any floats to a string with fixed precision."""
    if isinstance(data, float):
        return '{:.4f}'.format(data)
    if isinstance(data, list):
        return [comparable(el) for el in data]
    if isinstance(data, tuple):
        return tuple(comparable(el) for el in data)
    if isinstance(data, (set, frozenset)):
        return {comparable(el) for el in data}
    if isinstance(data, dict):
        return {k: comparable(v) for k, v in data.items()}
    # ints, strings and any other type are returned unchanged
    # (otherwise unknown types would all become None and compare equal)
    return data

Then you can:

self.assertEqual(comparable(value1), comparable(value2))

Upvotes: 0

As of Python 3.5 you can compare using

math.isclose(a, b, rel_tol=1e-9, abs_tol=0.0)

as described in PEP 485. The implementation should be equivalent to

abs(a-b) <= max( rel_tol * max(abs(a), abs(b)), abs_tol )
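
Applied to the question's x and y, a small sketch (isclose_by_formula is a hypothetical helper that just writes out the formula above; it is not a library function). The last line shows one way to apply math.isclose to flat lists of equal length.

```python
import math

def isclose_by_formula(a, b, rel_tol=1e-9, abs_tol=0.0):
    # The PEP 485 condition, written out by hand for comparison
    return abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)

x, y = 0.1234567890, 0.1234567891

# |x - y| is about 1e-10, below 1e-9 * 0.123..., so both report True
print(math.isclose(x, y), isclose_by_formula(x, y))      # True True
# A tighter relative tolerance rejects the pair
print(math.isclose(x, y, rel_tol=1e-10))                 # False
# Near zero a relative tolerance never helps; abs_tol is needed
print(math.isclose(0.0, 1e-12), math.isclose(0.0, 1e-12, abs_tol=1e-9))  # False True
# Elementwise over two flat lists of equal length
print(all(map(math.isclose, [x, x], [y, y])))            # True
```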

Upvotes: 14

DJCowley

Reputation: 83

If you don't mind using the numpy package, then numpy.testing has the assert_array_almost_equal function.

This works for array_like objects, so it is fine for arrays, lists and tuples of floats, but it does not work for sets and dictionaries.

The documentation is here.

Upvotes: 6

Pierre GM

Reputation: 20339

If you don't mind using NumPy (which comes with your Python(x,y)), you may want to look at the np.testing module, which defines, among others, an assert_almost_equal function.

The signature is np.testing.assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True)

>>> import numpy as np
>>> x = 1.000001
>>> y = 1.000002
>>> np.testing.assert_almost_equal(x, y)
AssertionError: 
Arrays are not almost equal to 7 decimals
ACTUAL: 1.000001
DESIRED: 1.000002
>>> np.testing.assert_almost_equal(x, y, 5)
>>> np.testing.assert_almost_equal([x, x, x], [y, y, y], 5)
>>> np.testing.assert_almost_equal((x, x, x), (y, y, y), 5)
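
The NumPy documentation states the check as abs(desired - actual) < 1.5 * 10**(-decimal). A plain-Python restatement of that documented criterion for scalars (almost_equal_decimal is a hypothetical name, not NumPy's implementation) shows why the 7-decimal call above fails while the 5-decimal calls pass:

```python
def almost_equal_decimal(actual, desired, decimal=7):
    # Documented criterion of numpy.testing.assert_almost_equal, scalars only
    return abs(desired - actual) < 1.5 * 10 ** (-decimal)

x, y = 1.000001, 1.000002
print(almost_equal_decimal(x, y))      # False: |y - x| is about 1e-6, not < 1.5e-7
print(almost_equal_decimal(x, y, 5))   # True: about 1e-6 < 1.5e-5
print(all(almost_equal_decimal(a, b, 5) for a, b in zip([x, x, x], [y, y, y])))  # True
```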

Upvotes: 96

snakile

Reputation: 54521

Here's how I've implemented a generic is_almost_equal(first, second) function:

First, duplicate the objects you need to compare (first and second), but don't make an exact copy: cut the insignificant decimal digits of any float you encounter inside the object.

Now that you have copies of first and second from which the insignificant decimal digits are gone, just compare the two copies using the == operator.

Let's assume we have a cut_insignificant_digits_recursively(obj, places) function which duplicates obj but leaves only the places most significant decimal digits of each float in the original obj. Here's a working implementation of is_almost_equal(first, second, places):

from insignificant_digit_cutter import cut_insignificant_digits_recursively

def is_almost_equal(first, second, places):
    '''returns True if first and second equal. 
    returns true if first and second aren't equal but have exactly the same
    structure and values except for a bunch of floats which are just almost
    equal (floats are almost equal if they're equal when we consider only the
    [places] most significant digits of each).'''
    if first == second: return True
    cut_first = cut_insignificant_digits_recursively(first, places)
    cut_second = cut_insignificant_digits_recursively(second, places)
    return cut_first == cut_second

And here's a working implementation of cut_insignificant_digits_recursively(obj, places):

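def cut_insignificant_digits(number, places):
    '''cut the least significant decimal digits of a number, 
    leave only [places] decimal digits'''
    if  type(number) != float: return number
    number_as_str = str(number)
    # str() may give scientific notation ('1e-10', '1.5e-10') or 'inf'/'nan';
    # slicing those at the decimal point would produce an invalid float string
    if 'e' in number_as_str or '.' not in number_as_str: return number
    end_of_number = number_as_str.find('.')+places+1
    if end_of_number > len(number_as_str): return number
    return float(number_as_str[:end_of_number])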

def cut_insignificant_digits_lazy(iterable, places):
    for obj in iterable:
        yield cut_insignificant_digits_recursively(obj, places)

def cut_insignificant_digits_recursively(obj, places):
    '''return a copy of obj except that every float loses its least significant 
    decimal digits remaining only [places] decimal digits'''
    t = type(obj)
    if t == float: return cut_insignificant_digits(obj, places)
    if t in (list, tuple, set):
        return t(cut_insignificant_digits_lazy(obj, places))
    if t == dict:
        return {cut_insignificant_digits_recursively(key, places):
                cut_insignificant_digits_recursively(val, places)
                for key,val in obj.items()}
    return obj

The code and its unit tests are available here: https://github.com/snakile/approximate_comparator. I welcome any improvement and bug fix.
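
One consequence of truncating rather than rounding: two floats that are extremely close but sit on either side of a truncation boundary end up different. A self-contained sketch using the same string-slicing idea as cut_insignificant_digits (truncate_decimal is a hypothetical name, and like the guarded version it leaves scientific notation unchanged):

```python
def truncate_decimal(number, places):
    # Keep at most `places` digits after the decimal point by slicing repr()
    s = repr(number)
    if 'e' in s or '.' not in s:
        return number  # scientific notation, inf or nan: leave unchanged
    return float(s[:s.find('.') + places + 1])

print(truncate_decimal(0.123456789, 3))                        # 0.123
# 0.2 and 0.19999 differ by only 1e-5, yet the truncated values differ
print(truncate_decimal(0.2, 3), truncate_decimal(0.19999, 3))  # 0.2 0.199
```

A tolerance-based comparison such as math.isclose does not have these boundaries.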

Upvotes: 9

Samy Vilar

Reputation: 11130

You may have to implement it yourself. While it's true that lists and sets can be iterated the same way, dictionaries are a different story: you iterate their keys, not their values. Also, the third example seems a bit ambiguous to me: do you mean to compare each value within the set, or each value from each set?

Here's a simple code snippet:

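def almost_equal(value_1, value_2, accuracy=10**-8):
    return abs(value_1 - value_2) < accuracy

x = [1,2,3,4]
y = [1,2,4,5]
# zip() silently stops at the shorter list, so compare the lengths first
assert len(x) == len(y)
# For these lists this raises AssertionError: 3 vs 4 and 4 vs 5 differ by more than accuracy
assert all(almost_equal(*values) for values in zip(x, y))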

Upvotes: 2

BrenBarn

Reputation: 251408

There is no such method, you'd have to do it yourself.

For lists and tuples the definition is obvious, but note that the other cases you mention aren't obvious, so it's no wonder such a function isn't provided. For instance, is {1.00001: 1.00002} almost equal to {1.00002: 1.00001}? Handling such cases requires making a choice about whether closeness depends on keys or values or both. For sets you are unlikely to find a meaningful definition, since sets are unordered, so there is no notion of "corresponding" elements.
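
A sketch of one way to make those choices explicit, assuming math.isclose tolerances are acceptable (nested_almost_equal is a hypothetical name). Sequences are compared element by element, dict keys must match exactly while values are compared approximately, and sets fall back to ordinary ==, since there are no "corresponding" elements to pair up.

```python
import math

def nested_almost_equal(a, b, rel_tol=1e-9, abs_tol=0.0):
    """Hypothetical helper:
    - floats: math.isclose
    - lists/tuples: same type and length, compared element by element
    - dicts: identical keys, values compared recursively
    - sets and everything else: ordinary ==
    """
    if isinstance(a, float) or isinstance(b, float):
        return (isinstance(a, (int, float)) and isinstance(b, (int, float))
                and math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol))
    if isinstance(a, (list, tuple)):
        return (type(a) is type(b) and len(a) == len(b)
                and all(nested_almost_equal(p, q, rel_tol, abs_tol) for p, q in zip(a, b)))
    if isinstance(a, dict):
        return (isinstance(b, dict) and a.keys() == b.keys()
                and all(nested_almost_equal(a[k], b[k], rel_tol, abs_tol) for k in a))
    return a == b

x, y = 0.1234567890, 0.1234567891
print(nested_almost_equal([x, x, x], [y, y, y]))                     # True
print(nested_almost_equal({1: x, 2: x, 3: x}, {1: y, 2: y, 3: y}))   # True
print(nested_almost_equal([(x, x)], [(y, y)]))                       # True
print(nested_almost_equal({1.00001: 1.00002}, {1.00002: 1.00001}))   # False: keys differ
```

Inside a test case it could be used as self.assertTrue(nested_almost_equal(first, second)), at the cost of a less informative failure message.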

Upvotes: 4
