PhE

Reputation: 16644

Cache object instances with lru_cache and __hash__

I don't understand how functools.lru_cache works with object instances. I assumed the class has to provide a __hash__ method, so that any instance with the same hash would hit the cache.

Here is my test :

from functools import lru_cache

class Query:
    def __init__(self, id: int):
        self.id = id

    def __hash__(self):
        return hash(self.id)

@lru_cache()
def fetch_item(item):
    return 'data'

o1 = Query(33)
o2 = Query(33)
o3 = 33

assert hash(o1) == hash(o2) == hash(o3)

fetch_item(o1)  # <-- expecting miss
fetch_item(o1)  # <-- expecting hit
fetch_item(o2)  # <-- expecting hit BUT get a miss !
fetch_item(o3)  # <-- expecting hit BUT get a miss !
fetch_item(o3)  # <-- expecting hit

info = fetch_item.cache_info()
print(info)

assert info.hits == 4
assert info.misses == 1
assert info.currsize == 1

How can I cache calls made with object instances that have the same hash?

Upvotes: 3

Views: 6160

Answers (2)

sj95126

Reputation: 6898

Short answer: in order to get a cache hit on o2 when o1 is already in the cache, the class needs to define an __eq__() method, so that Query objects are compared by value.

For example:

def __eq__(self, other):
    return isinstance(other, Query) and self.id == other.id
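
Putting it together, a minimal sketch with both methods (reusing the fetch_item wrapper from the question):

from functools import lru_cache

class Query:
    def __init__(self, id: int):
        self.id = id

    def __hash__(self):
        return hash(self.id)

    def __eq__(self, other):
        return isinstance(other, Query) and self.id == other.id

@lru_cache()
def fetch_item(item):
    return 'data'

fetch_item(Query(33))  # miss
fetch_item(Query(33))  # hit: equal hash *and* equal value
print(fetch_item.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=128, currsize=1)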

Update: one detail worth mentioning up front rather than buried below: the behavior described here also applies to the functools.cache wrapper introduced in Python 3.9, since @cache is simply a shortcut for @lru_cache(maxsize=None).

Long answer (including o3):

There's a good explanation here of the exact mechanism behind dictionary lookups, so I won't recreate it all. Suffice it to say: since the LRU cache is stored as a dict, objects must compare equal to count as already present in the cache, because of the way dictionary keys are compared.

You can see this in a quick example with a regular dictionary, with two versions of the class where one uses __eq__() and the other doesn't:
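
For reference, the two variants might look like this (a minimal sketch; only __eq__ differs, the names match the session below):

class Query_with_eq:
    def __init__(self, id):
        self.id = id

    def __hash__(self):
        return hash(self.id)

    def __eq__(self, other):
        return isinstance(other, Query_with_eq) and self.id == other.id

class Query_without_eq:
    def __init__(self, id):
        self.id = id

    def __hash__(self):
        return hash(self.id)
    # no __eq__, so equality falls back to object identity

With those definitions: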

>>> o1 = Query_with_eq(33)
>>> o2 = Query_with_eq(33)
>>> {o1: 1, o2: 2}
{<__main__.Query_with_eq object at 0x6fffffea9430>: 2}

which results in one item in the dictionary, because the keys are equal, whereas:

>>> o1 = Query_without_eq(33)
>>> o2 = Query_without_eq(33)
>>> {o1: 1, o2: 2}
{<__main__.Query_without_eq object at 0x6fffffea9cd0>: 1, <__main__.Query_without_eq object at 0x6fffffea9c70>: 2}

results in two items (unequal keys).

Why int doesn't result in a cache hit when a Query object exists:

o3 is a regular int object. Even if its value compared equal to Query(33) (which would require Query.__eq__() to handle non-Query types), lru_cache has an optimization that bypasses that comparison entirely.

Normally, lru_cache creates a dictionary key (as a tuple) from the arguments to the wrapped function. Optionally, if the cache was created with the typed=True argument, it also stores the type of each argument, so that values are only equal if they are also of the same type.

The optimization is that if there is only one argument to the wrapped function and its type is int or str, the single argument is used directly as the dictionary key instead of being wrapped in a tuple. Therefore, (Query(33),) and 33 don't compare as equal, even though they effectively store the same value. (Note that I'm not saying int objects aren't cached, only that they don't match an existing value of a non-int type. From your example, you can see that fetch_item(o3) gets a cache hit on the second call.)
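
You can see the fast path directly with the private functools._make_key helper (an internal, undocumented function, shown here purely for illustration):

from functools import _make_key

# A single positional int (or str) argument is used as the key itself...
print(_make_key((33,), {}, False))    # 33

# ...while any other type gets wrapped in a _HashedSeq (a list subclass)
print(_make_key((33.0,), {}, False))  # [33.0]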

You can get cache hits if the argument is of a different type than int. For example, 33.0 would match, again presuming that Query.__eq__() takes types into account and returns True. For that, you could do something like:

def __eq__(self, other):
    if isinstance(other, Query):
        return self.id == other.id
    else:
        return self.id == other
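
With that looser __eq__ (together with the original __hash__), a float argument can hit an entry cached for a Query, since 33.0 skips the int/str fast path and hash(33.0) == hash(33). A quick sketch:

from functools import lru_cache

class Query:
    def __init__(self, id: int):
        self.id = id

    def __hash__(self):
        return hash(self.id)

    def __eq__(self, other):
        if isinstance(other, Query):
            return self.id == other.id
        else:
            return self.id == other

@lru_cache()
def fetch_item(item):
    return 'data'

fetch_item(Query(33))  # miss: cache is empty
fetch_item(33.0)       # hit: 33.0 == Query(33) via Query.__eq__
print(fetch_item.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=128, currsize=1)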

Upvotes: 9

Ashwini Chaudhary

Reputation: 250871

Even though lru_cache() expects its arguments to be hashable, it doesn't rely on their hash values alone: the cache is a dictionary, so keys must also compare equal. That's why you're getting those misses.

The function _make_key uses _HashedSeq to make sure all the items in the key are hashable, but the lookup later on in _lru_cache_wrapper is an ordinary dict lookup, for which a matching hash value alone isn't enough.

(_HashedSeq is skipped if there's only one argument and it is of type int or str.)

class _HashedSeq(list):
    """ This class guarantees that hash() will be called no more than once
        per element.  This is important because the lru_cache() will hash
        the key multiple times on a cache miss.
    """

    __slots__ = 'hashvalue'

    def __init__(self, tup, hash=hash):
        self[:] = tup
        self.hashvalue = hash(tup)

    def __hash__(self):
        return self.hashvalue

Tracing your example:

fetch_item(o1)  # misses and stores (o1,) in the cache dictionary
fetch_item(o1)  # finds (o1,) in the cache dictionary
fetch_item(o2)  # looks for (o2,), misses, and stores (o2,)
fetch_item(o3)  # looks for 33 (the int fast path), misses, and stores 33
fetch_item(o3)  # finds 33 in the cache dictionary

Unfortunately, there isn't a documented way to provide a custom make_key function, so one way to achieve this is by monkey-patching the internal _make_key function (within a context manager). Note that this relies on lru_cache picking up the Python-level _make_key; if the C-accelerated _lru_cache_wrapper from _functools is in use, the patch may not take effect.

import functools
from contextlib import contextmanager


def make_key(*args, **kwargs):
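    # _make_key is called as _make_key(args, kwds, typed), so args[0] here is
    # the tuple of positional arguments passed to the wrapped function; using
    # the first argument's hash as the key makes Query(33) and 33 share one
    # cache entry (any other hash collision would share an entry too)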
    return hash(args[0][0])


def fetch_item(item):
    return 'data'

@contextmanager
def lru_cached_fetch_item():
    _make_key_og = functools._make_key
    functools._make_key = make_key
    try:
        yield functools.lru_cache()(fetch_item)
    finally:
        functools._make_key = _make_key_og


class Query:
    def __init__(self, id: int):
        self.id = id

    def __hash__(self):
        return hash(self.id)


o1 = Query(33)
o2 = Query(33)
o3 = 33

assert hash(o1) == hash(o2) == hash(o3)

with lru_cached_fetch_item() as func:
    func(o1)  # miss
    func(o1)  # hit
    func(o2)  # hit: thanks to the patched key function
    func(o3)  # hit: same hash as o1 and o2
    func(o3)  # hit

info = func.cache_info()
print(info) # CacheInfo(hits=4, misses=1, maxsize=128, currsize=1)
assert info.hits == 4
assert info.misses == 1
assert info.currsize == 1

Upvotes: 4
