Max

Reputation: 678

Make built-in lru_cache skip caching when function returns None

Here's a simplified function to which I'm trying to add an lru_cache:

from functools import lru_cache, wraps

@lru_cache(maxsize=1000)
def validate_token(token):
    if token % 3:
        return None
    return True

for x in range(1000):
    validate_token(x)

print(validate_token.cache_info())

Output:

CacheInfo(hits=0, misses=1000, maxsize=1000, currsize=1000)

As we can see, it also caches the arguments and return values of the calls that return None. In the example above, I want the cache size to be 334, i.e. only the calls that return non-None values. In my case, the function takes a large number of arguments and may return a different value on a later call if the previous call returned None, so I want to avoid caching None results.
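To make the motivation concrete, here is a minimal sketch showing that a plain `lru_cache` keeps returning a stale `None` after the underlying data changes. The `AVAILABLE` set and the `lookup` function are hypothetical stand-ins for whatever external state the real function depends on.

```python
from functools import lru_cache

# Hypothetical external state that can change between calls.
AVAILABLE = set()

@lru_cache(maxsize=1000)
def lookup(key):
    # Returns None until the key becomes available.
    return True if key in AVAILABLE else None

first = lookup("abc")    # None, and lru_cache stores it
AVAILABLE.add("abc")     # the underlying data changes
second = lookup("abc")   # still None: served from the cache
print(first, second, lookup.cache_info())
# None None CacheInfo(hits=1, misses=1, maxsize=1000, currsize=1)
```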

I'd like to avoid reinventing the wheel by reimplementing lru_cache from scratch. Is there a good way to do this?

Here are some of my attempts:

1. Implementing my own cache (which is not LRU here):

from functools import wraps

# global cache object
MY_CACHE = {}

def get_func_hash(func):
    # generates unique key for a function. TODO: fix what if function gets redefined?
    return func.__module__ + '|' + func.__name__

def my_lru_cache(func):
    name = get_func_hash(func)
    if name not in MY_CACHE:
        MY_CACHE[name] = {}
    @wraps(func)
    def function_wrapper(*args, **kwargs):
        # include keyword arguments in the key, so f(1, a=2) and f(1, a=3) differ
        key = (args, tuple(sorted(kwargs.items())))
        if key in MY_CACHE[name]:
            return MY_CACHE[name][key]
        value = func(*args, **kwargs)
        if value is not None:
            MY_CACHE[name][key] = value
        return value
    return function_wrapper

@my_lru_cache
def validate_token(token):
    if token % 3:
        return None
    return True

for x in range(1000):
    validate_token(x)

print(get_func_hash(validate_token))
print(len(MY_CACHE[get_func_hash(validate_token)]))

Output:

__main__|validate_token
334

2. I realised that lru_cache does not cache a call when the wrapped function raises an exception:

from functools import wraps, lru_cache

def my_lru_cache(func):
    @wraps(func)
    @lru_cache(maxsize=1000)
    def function_wrapper(*args, **kwargs):
        value = func(*args, **kwargs)
        if value is None:
            # TODO: change this to a custom exception
            raise KeyError
        return value
    return function_wrapper

def handle_exception(func):
    @wraps(func)
    def function_wrapper(*args, **kwargs):
        try:
            value = func(*args, **kwargs)
            return value
        except KeyError:
            return None
    return function_wrapper    

@handle_exception
@my_lru_cache
def validate_token(token):
    if token % 3:
        return None
    return True

for x in range(1000):
    validate_token(x)

print(validate_token.__wrapped__.cache_info())

Output:

CacheInfo(hits=0, misses=334, maxsize=1000, currsize=334)

The above correctly caches only the 334 values, but it requires wrapping the function twice and accessing the cache info awkwardly via func.__wrapped__.cache_info().
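Attempt 1 can also be turned into a real LRU cache with little extra code. The following is a minimal sketch, not a replacement for `functools.lru_cache`: it uses `collections.OrderedDict` for recency ordering, the names `lru_cache_skip_none` and `cache_len` are hypothetical, it is not thread-safe, it has no `typed=` option, and all arguments must be hashable.

```python
from collections import OrderedDict
from functools import wraps

def lru_cache_skip_none(maxsize=128):
    """Hypothetical LRU decorator that never stores None results.

    Sketch only: not thread-safe, no typed= option, hashable args required.
    """
    def decorator(func):
        cache = OrderedDict()  # key -> result, least recently used first

        @wraps(func)
        def wrapper(*args, **kwargs):
            # Include keyword arguments in the key (sorted for a stable order).
            key = (args, tuple(sorted(kwargs.items())))
            if key in cache:
                cache.move_to_end(key)  # mark as most recently used
                return cache[key]
            value = func(*args, **kwargs)
            if value is not None:
                cache[key] = value
                if len(cache) > maxsize:
                    cache.popitem(last=False)  # evict least recently used
            return value

        wrapper.cache_clear = cache.clear
        wrapper.cache_len = lambda: len(cache)  # hypothetical helper
        return wrapper
    return decorator

@lru_cache_skip_none(maxsize=1000)
def validate_token(token):
    if token % 3:
        return None
    return True

for x in range(1000):
    validate_token(x)

print(validate_token.cache_len())  # 334
```

The trade-off is that you now own the eviction and locking logic that `functools.lru_cache` already gets right, which is why the exception-based approaches below may be preferable.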

Is there a better, more Pythonic way to make the built-in lru_cache decorator skip caching when None (or some other specific value) is returned?

Upvotes: 9

Views: 3972

Answers (3)

Raymond Hettinger

Reputation: 226171

One way is to have the cached function raise an exception and then create a wrapper that converts the exception back to a None value.

from functools import lru_cache

@lru_cache(maxsize=1000)
def _validate_token(token):
    if token % 3:
        raise ValueError
    return True

def validate_token(token):
    try:
        return _validate_token(token)
    except ValueError:
        return None

for x in range(1000):
    validate_token(x)

print(_validate_token.cache_info())

This outputs:

CacheInfo(hits=0, misses=1000, maxsize=1000, currsize=334)

If the original function cannot be modified, a decorator can be added to convert None results into exceptions.

from functools import lru_cache

def convert_none_to_value_error(func):
    def inner(*args, **kwargs):
        result = func(*args, **kwargs)
        if result is None:
            raise ValueError  # lru_cache does not cache calls that raise
        return result
    return inner

@lru_cache(maxsize=1000)
@convert_none_to_value_error
def validate_token(token):
    if token % 3:
        return None
    return True

If this needs to be done for more than one function, the work can be done with a pair of decorators.

from functools import lru_cache

class NoneResult(ValueError):
    pass

def convert_none_to_exc(func):
    def inner(*args, **kwargs):
        result = func(*args, **kwargs)
        if result is None:
            raise NoneResult()
        return result
    return inner

def convert_exc_to_none(func):
    def inner(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except NoneResult:
            return None
    # expose the lru_cache helpers on the outer wrapper
    inner.cache_info = func.cache_info
    inner.cache_clear = func.cache_clear
    return inner

@convert_exc_to_none
@lru_cache(maxsize=1000)
@convert_none_to_exc
def validate_token(token):
    if token % 3:
        return None
    return True

for x in range(1000):
    validate_token(x)

print(validate_token.cache_info())
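The practical payoff, and the reason the question cares about this: because a None result is never stored, a later call with the same arguments is recomputed and can pick up a changed value. A minimal sketch reproducing the decorator pair so the block runs on its own (with `**kwargs` passed through and `cache_clear` forwarded as well); `lookup` and `AVAILABLE` are hypothetical.

```python
from functools import lru_cache

class NoneResult(ValueError):
    pass

def convert_none_to_exc(func):
    def inner(*args, **kwargs):
        result = func(*args, **kwargs)
        if result is None:
            raise NoneResult()  # lru_cache does not store calls that raise
        return result
    return inner

def convert_exc_to_none(func):
    def inner(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except NoneResult:
            return None
    inner.cache_info = func.cache_info
    inner.cache_clear = func.cache_clear
    return inner

AVAILABLE = set()  # hypothetical external state

@convert_exc_to_none
@lru_cache(maxsize=1000)
@convert_none_to_exc
def lookup(key):
    return True if key in AVAILABLE else None

first = lookup("abc")    # None: raised internally, so nothing is stored
AVAILABLE.add("abc")
second = lookup("abc")   # recomputed: now True, and stored
third = lookup("abc")    # True, served from the cache
print(first, second, third, lookup.cache_info())
```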

Upvotes: 0

Starlight X

Reputation: 81

Use exceptions to prevent caching:

from functools import lru_cache

@lru_cache(maxsize=None)
def fname(x):
    print('worked')

    raise Exception('')

    return 1  # never reached, so no result is ever cached


for _ in range(10):
    try:
        fname(1)
    except Exception:
        pass

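In the example above, "worked" is printed 10 times: every call raises, so lru_cache never stores a result and the function body runs again on each call.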

Upvotes: 7

aaron

Reputation: 43073

You are missing the two lines marked here:

def handle_exception(func):
    @wraps(func)
    def function_wrapper(*args, **kwargs):
        try:
            value = func(*args, **kwargs)
            return value
        except KeyError:
            return None

    function_wrapper.cache_info = func.cache_info    # Add this
    function_wrapper.cache_clear = func.cache_clear  # Add this
    return function_wrapper

You can do both wrappers in one function:

from functools import lru_cache, wraps

def my_lru_cache(maxsize=128, typed=False):
    class CustomException(Exception):
        pass

    def decorator(func):
        @lru_cache(maxsize=maxsize, typed=typed)
        def raise_exception_wrapper(*args, **kwargs):
            value = func(*args, **kwargs)
            if value is None:
                raise CustomException
            return value

        @wraps(func)
        def handle_exception_wrapper(*args, **kwargs):
            try:
                return raise_exception_wrapper(*args, **kwargs)
            except CustomException:
                return None

        handle_exception_wrapper.cache_info = raise_exception_wrapper.cache_info
        handle_exception_wrapper.cache_clear = raise_exception_wrapper.cache_clear
        return handle_exception_wrapper

    if callable(maxsize):
        user_function, maxsize = maxsize, 128
        return decorator(user_function)

    return decorator
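A quick usage sketch of the combined decorator, reproduced here so the block runs on its own. It exercises both the bare `@my_lru_cache` form and the `@my_lru_cache(maxsize=...)` form, and counts real calls through a hypothetical `CALLS` list to show that None results are recomputed while other results come from the cache.

```python
from functools import lru_cache, wraps

def my_lru_cache(maxsize=128, typed=False):
    class CustomException(Exception):
        pass

    def decorator(func):
        @lru_cache(maxsize=maxsize, typed=typed)
        def raise_exception_wrapper(*args, **kwargs):
            value = func(*args, **kwargs)
            if value is None:
                raise CustomException  # keeps None out of the cache
            return value

        @wraps(func)
        def handle_exception_wrapper(*args, **kwargs):
            try:
                return raise_exception_wrapper(*args, **kwargs)
            except CustomException:
                return None

        handle_exception_wrapper.cache_info = raise_exception_wrapper.cache_info
        handle_exception_wrapper.cache_clear = raise_exception_wrapper.cache_clear
        return handle_exception_wrapper

    if callable(maxsize):
        # Bare @my_lru_cache: the "maxsize" argument is the decorated function.
        user_function, maxsize = maxsize, 128
        return decorator(user_function)

    return decorator

CALLS = []  # hypothetical: records every real call

@my_lru_cache                 # bare form, maxsize defaults to 128
def validate_token(token):
    CALLS.append(token)
    return True if token % 3 == 0 else None

@my_lru_cache(maxsize=2)      # parameterised form
def double(x):
    return 2 * x

for _ in range(2):
    for x in range(10):
        validate_token(x)

for x in (1, 2, 3):
    double(x)

print(validate_token.cache_info().currsize)  # 4 (tokens 0, 3, 6, 9)
print(len(CALLS))  # 16: 10 first-pass calls + 6 None results recomputed
```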

Upvotes: 5
