Reputation: 11
How to use lru_cache for a static method and with unhashable arguments like lists in Python
I have tried using the lru_cache method from methodtools, but it gives an error saying the call is not working.
Upvotes: 1
Views: 313
Reputation: 226624
Put the @staticmethod decorator just above the cache decorator, so that staticmethod wraps the already-cached function.
from functools import lru_cache

class C:
    @staticmethod
    @lru_cache  # bare @lru_cache (no parentheses) requires Python 3.8+
    def slow_func(a, b, c):
        return len(str(a ** b ** c))
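A minimal check of this decorator order, assuming Python 3.8+ for the bare @lru_cache form. Because staticmethod is outermost, looking up slow_func on the class or on an instance returns the lru_cache wrapper itself, so both call styles share one cache and cache_info() is reachable. The small arguments are chosen only to keep the example fast.

```python
from functools import lru_cache


class C:
    @staticmethod  # outermost: applied last, wraps the cached function
    @lru_cache     # innermost: caches results keyed on (a, b, c)
    def slow_func(a, b, c):
        return len(str(a ** b ** c))


# Calling through the class or through an instance hits the same cache.
first = C.slow_func(2, 3, 4)     # cache miss: computes 2 ** 81 (** is right-associative)
second = C().slow_func(2, 3, 4)  # cache hit: same arguments

# staticmethod hands back the underlying lru_cache wrapper on attribute access,
# so the cache statistics are available directly.
info = C.slow_func.cache_info()
print(first, info.hits, info.misses)
```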
For unhashable arguments such as lists, create an adapter function that converts them to hashable equivalents before calling the cached function.
from functools import lru_cache
from math import gcd
from itertools import permutations
from typing import Iterable

@lru_cache
def slow_function(data: tuple) -> int:
    "Count ordered coprime pairs (each unordered pair is counted twice)."
    return sum(gcd(p, q) == 1 for p, q in permutations(data, 2))

def adapter(data: Iterable) -> int:
    # Convert the (possibly unhashable) iterable to a tuple before the cached call
    return slow_function(tuple(data))

# Call adapter() with lists, which are unhashable
print(adapter([10, 20, 30, 22, 39]))  # Cache miss
print(adapter([11, 23, 35, 45, 52]))  # Cache miss
print(adapter([10, 20, 30, 22, 39]))  # Cache hit
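To confirm the miss/hit comments and to show why the adapter is needed, the sketch below repeats the same two functions and inspects them. cache_info() reports the statistics from the three adapter calls; passing a list straight to the cached function then raises TypeError, because lru_cache needs the arguments to be hashable in order to use them as a cache key.

```python
from functools import lru_cache
from itertools import permutations
from math import gcd
from typing import Iterable


@lru_cache
def slow_function(data: tuple) -> int:
    "Count ordered coprime pairs (each unordered pair is counted twice)."
    return sum(gcd(p, q) == 1 for p, q in permutations(data, 2))


def adapter(data: Iterable) -> int:
    return slow_function(tuple(data))


adapter([10, 20, 30, 22, 39])  # miss
adapter([11, 23, 35, 45, 52])  # miss
adapter([10, 20, 30, 22, 39])  # hit: equal tuple, same cache key
info = slow_function.cache_info()

# Calling the cached function directly with a list fails.
try:
    slow_function([10, 20, 30])
    error = None
except TypeError as exc:
    error = str(exc)  # e.g. "unhashable type: 'list'"

print(info.hits, info.misses, error)
```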
Typical conversions:
unhashable = [10, 20, 30] # list
hashable = tuple(unhashable)
unhashable = {10, 20, 30} # set
hashable = frozenset(unhashable)
unhashable = {10: 'ten', 20: 'twenty'}  # dict
hashable = tuple(sorted(unhashable.items()))  # keys must be sortable, values hashable
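The conversions above only handle the top level; a list of lists, for example, still becomes an unhashable tuple of lists. A minimal recursive sketch that applies the same three conversions at every level is below. make_hashable is a hypothetical helper name, and it assumes dict keys are mutually sortable and that any other leaf value (int, str, ...) is already hashable.

```python
from collections.abc import Hashable
from typing import Any


def make_hashable(value: Any) -> Hashable:
    """Hypothetical helper: recursively convert lists/tuples to tuples,
    sets to frozensets, and dicts to sorted tuples of (key, value) pairs."""
    if isinstance(value, dict):
        # Dict keys are unique, so sorting only ever compares keys.
        return tuple(sorted((k, make_hashable(v)) for k, v in value.items()))
    if isinstance(value, (list, tuple)):
        return tuple(make_hashable(item) for item in value)
    if isinstance(value, (set, frozenset)):
        return frozenset(make_hashable(item) for item in value)
    return value  # assumed hashable already


nested = [10, [20, 30], {"b": [1, 2], "a": {3, 4}}]
key = make_hashable(nested)
print(key, hash(key))
```

One caveat of this design: different inputs can map to the same key, for example [1, 2] and (1, 2), or {'a': 1} and [('a', 1)], so the cached function should not depend on that distinction.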
Upvotes: 0
Reputation: 13438
You could use a separate decorator that sanitizes the inputs, converting them to hashable types before they reach the cache. Something like this:
import functools

def sanitize_args(*expected_types):
    def make_decorator(decorated):
        @functools.wraps(decorated)
        def decorator(*args):
            if len(args) != len(expected_types):
                raise TypeError("Wrong number of arguments")
            # Convert each argument with its expected type, e.g. tuple([1, 2, 3])
            args = tuple(type_(arg) for type_, arg in zip(expected_types, args))
            return decorated(*args)
        return decorator
    return make_decorator

class Foo:
    @staticmethod
    @sanitize_args(tuple, str)
    @functools.lru_cache
    def bar(sequence, label):
        return (label, sum(sequence))

print(Foo.bar([1, 2, 3], "foo"))  # ('foo', 6)
However, getting this right (and fast) in a generic fashion is a bit tedious. Note that I left out keyword arguments for simplicity.
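One way to also cover keyword arguments is to drop the per-parameter type list and convert any list, set or dict argument, positional or keyword. The sketch below assumes a shallow conversion is enough (nested contents must already be hashable); _to_hashable and hashable_args are hypothetical names, not part of any library.

```python
import functools
from typing import Any, Callable


def _to_hashable(value: Any) -> Any:
    # Shallow conversion of common unhashable containers (assumption:
    # their contents, including dict values, are already hashable).
    if isinstance(value, list):
        return tuple(value)
    if isinstance(value, set):
        return frozenset(value)
    if isinstance(value, dict):
        return frozenset(value.items())
    return value


def hashable_args(func: Callable) -> Callable:
    """Hypothetical decorator: convert positional *and* keyword arguments
    to hashable types before calling func (typically an lru_cache wrapper)."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        args = tuple(_to_hashable(a) for a in args)
        kwargs = {k: _to_hashable(v) for k, v in kwargs.items()}
        return func(*args, **kwargs)
    return wrapper


class Foo:
    @staticmethod
    @hashable_args
    @functools.lru_cache
    def bar(sequence, label):
        return (label, sum(sequence))


r1 = Foo.bar([1, 2, 3], "foo")                 # miss
r2 = Foo.bar([1, 2, 3], "foo")                 # hit: list became an equal tuple
r3 = Foo.bar(sequence=[1, 2, 3], label="foo")  # keyword call also works
# functools.wraps exposes the cached function as __wrapped__.
info = Foo.bar.__wrapped__.cache_info()
print(r1, r3, info)
```

Note that lru_cache may store a positional call and a keyword call with the same values as separate cache entries, so this still normalizes types but not calling conventions.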
An easier solution is to use an uncached public interface with an lru-cached private implementation. Something like this:
import functools

class Foo:
    @staticmethod
    @functools.lru_cache
    def _implement_bar(sequence, label):
        return (label, sum(sequence))

    @staticmethod
    def bar(sequence, label):
        # Normalize to hashable types, then call the cached implementation
        return Foo._implement_bar(tuple(sequence), str(label))
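A usage sketch of this public/private split: because bar always passes _implement_bar a tuple and a str positionally, calls made with a list, with a tuple, or with keyword arguments all land on one cache entry, which cache_info() on the private method confirms.

```python
import functools


class Foo:
    @staticmethod
    @functools.lru_cache
    def _implement_bar(sequence, label):
        # Private and cached: only ever receives hashable arguments.
        return (label, sum(sequence))

    @staticmethod
    def bar(sequence, label):
        # Public and uncached: normalizes the arguments, then delegates.
        return Foo._implement_bar(tuple(sequence), str(label))


result = Foo.bar([1, 2, 3], "foo")              # miss
again = Foo.bar((1, 2, 3), "foo")               # hit: same normalized key
third = Foo.bar(label="foo", sequence=[1, 2])   # different sum, so a miss
fourth = Foo.bar(label="foo", sequence=[1, 2, 3])  # hit: bar passes args positionally
info = Foo._implement_bar.cache_info()
print(result, info.hits, info.misses)
```

This also sidesteps the keyword-argument problem, since only the public method sees keywords and it always calls the cached one positionally.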
Upvotes: 0