Tomáš Hons

Reputation: 375

Storing arguments of function with later evaluation

I have a pretty specific question.

The user gives me a function f and "variables" as arguments of the function. Note that the signature of f may be arbitrary (i.e. it can use any legal combination of the 5 kinds of parameters), and some arguments may have a default value. The variables are, in a simplified way, keys to a dictionary which translates variables to real values.
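For reference, the "5 kinds of parameters" are the kinds reported by inspect.Parameter.kind. A small sketch listing them for an illustrative function (the names example and parameter_kinds are hypothetical; the / marker needs Python 3.8+):

```python
import inspect


def example(a, /, b, *args, c, d=4, **kwargs):
    """Illustrative function that uses all five parameter kinds."""
    return a, b, args, c, d, kwargs


def parameter_kinds(func):
    """Map each parameter name of func to the name of its kind."""
    sig = inspect.signature(func)
    return {name: param.kind.name for name, param in sig.parameters.items()}


if __name__ == "__main__":
    # Prints one line per parameter, e.g. "a: POSITIONAL_ONLY"
    for name, kind in parameter_kinds(example).items():
        print(f"{name}: {kind}")
```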

I would like to do the following things:

My overall goal is to determine whether two calls of f return the same result (provided that f is pure). Therefore, it is crucial to determine (for example) whether two argument sets differ only in a default value of f (passed explicitly in one of the cases).
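To see why comparing the raw captured arguments is not enough, here is a small sketch with a hypothetical pure f that has one default: both calls return the same value, yet their (args, kwargs) pairs compare unequal.

```python
def f(a, b, c=5):
    """Hypothetical pure function with one default."""
    return a + b + c


# Two ways of calling f that must be treated as equivalent.
call_1 = ((1, 2), {})          # f(1, 2): c takes its default
call_2 = ((1, 2), {"c": 5})    # f(1, 2, c=5): default passed explicitly

# Both calls produce the same value ...
same_result = f(*call_1[0], **call_1[1]) == f(*call_2[0], **call_2[1])
# ... but the raw captured arguments differ.
same_raw_args = call_1 == call_2
print(same_result, same_raw_args)  # True False
```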

Example

There will be roughly the following class:

class VariableCapture:
    translation_table = ...  # translates "variables" to real values

    def __init__(self, func):
        self._func = func

    def set_args(self, *args, **kwargs):  # the exact signature is part of the question
        # capture arguments
        # compare with signature of self._func
        # save the actual arguments
        ...

    def call(self):
        # call self._func with the saved arguments
        ...

Here is an example of f:

def f(a, b):
    ...  # some magic here, not really my business

vc = VariableCapture(f)
vc.set_args(1, 2, 3) # error -- too many args
vc.set_args(a=1, c=2) # error -- keyword arg with wrong name
vc.set_args(1, 2) # ok
vc.set_args(a=1, b=2) # ok
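The expected outcomes above match what Python itself does when the same arguments are passed to f directly. A quick check, with a hypothetical body for f and a hypothetical helper try_call:

```python
def f(a, b):
    return a + b  # hypothetical body; the real f is user-supplied


def try_call(*args, **kwargs):
    """Call f directly and report whether Python accepts the arguments."""
    try:
        f(*args, **kwargs)
    except TypeError:
        return "error"
    return "ok"


results = [
    try_call(1, 2, 3),      # too many positional arguments
    try_call(a=1, c=2),     # unexpected keyword argument 'c'
    try_call(1, 2),
    try_call(a=1, b=2),
]
print(results)  # ['error', 'error', 'ok', 'ok']
```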

Note, however, that the last two lines may differ for the following f:

def f(*args, **kwargs):
    ...  # some magic here, not really my business

After vc.set_args(1, 2), the later call of f should pass both arguments via *args; after vc.set_args(a=1, b=2), it should pass them via **kwargs.
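The difference is observable inside f. A minimal sketch where f simply returns what it received:

```python
def f(*args, **kwargs):
    # Return what f actually received, so the two call styles can be compared.
    return args, kwargs


print(f(1, 2))        # ((1, 2), {})
print(f(a=1, b=2))    # ((), {'a': 1, 'b': 2})
```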

My guesses

I think that it is possible to receive any kind of argument specification via set_args(self, *args, **kwargs). However, I am not sure about that.
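A method declared with *args and **kwargs does receive the arguments exactly as the caller wrote them: positionals as a tuple, keywords as a dict. A sketch with a hypothetical RawCapture class:

```python
class RawCapture:
    """Hypothetical helper: record arguments exactly as the caller passed them."""

    def __init__(self):
        self.args = ()
        self.kwargs = {}

    def set_args(self, *args, **kwargs):
        # Positional arguments arrive as a tuple, keyword arguments as a dict,
        # so the caller's choice of passing style is preserved.
        self.args = args
        self.kwargs = kwargs


rc = RawCapture()
rc.set_args(1, 2)
print(rc.args, rc.kwargs)   # (1, 2) {}
rc.set_args(a=1, b=2)
print(rc.args, rc.kwargs)   # () {'a': 1, 'b': 2}
```

One edge case: a keyword argument literally named self collides with the method's own self parameter; declaring def set_args(self, /, *args, **kwargs) (Python 3.8+) avoids that.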

Also, I don't know whether I can distinguish between ways of passing the parameters that are not equivalent (from the perspective of the particular f).

And, closely related to the previous point: whether calling f via the statement `self._func(*self.saved_args, **self.saved_kwargs)` reproduces the original call.
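Replaying stored arguments works as long as the keyword dictionary is unpacked with **; a single * iterates over the dict and passes only its keys positionally. A minimal sketch with an echoing f (saved_args and saved_kwargs are illustrative values):

```python
def f(*args, **kwargs):
    # Echo what was received so the replayed call can be inspected.
    return args, kwargs


saved_args = (1, 2)
saved_kwargs = {"x": 3}

# Correct replay: ** unpacks the dict as keyword arguments.
replayed = f(*saved_args, **saved_kwargs)
print(replayed)   # ((1, 2), {'x': 3})

# Mistaken replay: a single * unpacks only the dict's keys as positionals.
wrong = f(*saved_args, *saved_kwargs)
print(wrong)      # ((1, 2, 'x'), {})
```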

I admit it's a rather broad question, but any tips or suggestions are appreciated. References to libraries that might help are also welcome (I suppose the key library here is inspect). Thank you!

Upvotes: 3

Views: 502

Answers (2)

Yam Mesicka

Reputation: 6581

You can use inspect, as you mentioned, to check the number of arguments.

We should also take into account the possibility of default arguments.
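For orientation, these are the inspect.getfullargspec fields the checks rely on, shown for an illustrative function; the minimum/maximum computation mirrors the one in _is_args_ok:

```python
import inspect


def func(a, b, c=5, *args, d, e=1, **kwargs):
    return None  # body irrelevant; only the signature is inspected


spec = inspect.getfullargspec(func)
print(spec.args)            # ['a', 'b', 'c']
print(spec.defaults)        # (5,)
print(spec.varargs)         # args
print(spec.kwonlyargs)      # ['d', 'e']
print(spec.kwonlydefaults)  # {'e': 1}
print(spec.varkw)           # kwargs

# Positional arguments without a default are required.
maximum_args = len(spec.args)
minimum_args = maximum_args - len(spec.defaults or ())
print(minimum_args, maximum_args)  # 2 3
```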

import inspect
import pytest


class VariableCapture:
    def __init__(self, func):
        self._func = func
        self._args = []
        self._kwargs = {}

    @property
    def _argspec(self):
        return inspect.getfullargspec(self._func)

    def _is_args_ok(self, args):
        maximum_args = len(self._argspec.args)
        minimum_args = maximum_args - len(self._argspec.defaults or [])
        return (
            minimum_args <= len(args) <= maximum_args
            or (minimum_args <= len(args) and self._argspec.varargs)
        )

    def _is_kwargs_ok(self, args, kwargs):
        real_kwargs = set(self._argspec.kwonlyargs)
        required_kwargs = real_kwargs - set(self._argspec.kwonlydefaults or {})
        our_kwargs = set(kwargs) - set(self._argspec.args)
        missing_kwargs = required_kwargs - our_kwargs
        excessive_kwargs = our_kwargs - real_kwargs
        extra_kwargs = bool(self._argspec.varkw)
        return (
            not missing_kwargs and (not excessive_kwargs or extra_kwargs)
        )

    def _get_excessive_kwargs(self, args, kwargs):
        # Names of the parameters already filled positionally by args
        provided_args = set(self._argspec.args[:len(args)])
        return set(kwargs) & provided_args

    def set_args(self, *args, **kwargs):
        if not self._is_args_ok(args):
            raise TypeError(
                f"Setting {self._func.__name__} with wrong number of args. "
                f"Expected {len(self._argspec.args)}: {self._argspec.args}."
            )
        if not self._is_kwargs_ok(args, kwargs):
            raise TypeError(
                f"Setting {self._func.__name__} with wrong kwargs. "
                f"Got {list(kwargs)}, "
                f"instead of {self._argspec.kwonlyargs} "
                f"(defaults: {self._argspec.kwonlydefaults})."
            )
        if excessive_kwargs := self._get_excessive_kwargs(args, kwargs):
            raise TypeError(f"Some kwargs are duplicate: {excessive_kwargs}.")
        self._args = args
        self._kwargs = kwargs

    def call(self):
        return self._func(*self._args, **self._kwargs)

Some tests:

def test_regular():
    def func(a, b):
        return a + b

    v = VariableCapture(func)
    with pytest.raises(TypeError):
        v.set_args(1)
    with pytest.raises(TypeError):
        v.set_args(1, 2, 3)

    v.set_args(1, 2)
    assert v.call() == 3  # Good to go


def test_args():
    def func(a, b, *args):
        return a + b + sum(args)

    v = VariableCapture(func)
    with pytest.raises(TypeError):
        v.set_args(1)

    v.set_args(1, 2)
    assert v.call() == 3
    v.set_args(1, 2, 3, 4)
    assert v.call() == 10


def test_default():
    def func(a, b, c=5, d=1):
        return a + b + c + d

    v = VariableCapture(func)
    with pytest.raises(TypeError):
        v.set_args(1)
    v.set_args(1, 2)
    assert v.call() == 9
    v.set_args(1, 2, 3)
    assert v.call() == 7
    v.set_args(1, 2, 3, 4)
    assert v.call() == 10
    v.set_args(1, 2, d=3)
    assert v.call() == 11
    with pytest.raises(TypeError):
        v.set_args(1, c=3)
    with pytest.raises(TypeError):
        v.set_args(1, 2, e=3)
    with pytest.raises(TypeError):
        v.set_args(1, 2, 3, 4, 5)


def test_kwargs():
    def func(a, b, *, c, d=1, **kwargs):
        return a + b + c + d + kwargs.get('e', 0)

    v = VariableCapture(func)
    with pytest.raises(TypeError):
        v.set_args(1)
    with pytest.raises(TypeError):
        v.set_args(1, 2)
    with pytest.raises(TypeError):
        v.set_args(1, 2, 3)
    with pytest.raises(TypeError):
        v.set_args(1, 2, d=3)
    v.set_args(1, 2, c=3)
    assert v.call() == 7
    v.set_args(1, 2, c=3, d=4)
    assert v.call() == 10
    v.set_args(1, 2, c=3, d=4, e=5)
    assert v.call() == 15
    v.set_args(1, 2, c=3, e=5)
    assert v.call() == 12
    with pytest.raises(TypeError):
        v.set_args(1, c=3)
    with pytest.raises(TypeError):
        v.set_args(1, 2, b=2, c=5)  # A tough one, identify duplications
    with pytest.raises(TypeError):
        v.set_args(1, 2, e=3)
    with pytest.raises(TypeError):
        v.set_args(1, 2, 3, 4, 5)
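For comparison, Python already rejects the "tough one" duplicate on its own, both in a real call and in inspect.Signature.bind. A small sketch (the helper type_error_message is hypothetical, only for the demonstration):

```python
import inspect


def func(a, b, *, c, d=1, **kwargs):
    return a + b + c + d + kwargs.get('e', 0)


def type_error_message(callable_, *args, **kwargs):
    """Return the TypeError message raised by callable_(*args, **kwargs), or None."""
    try:
        callable_(*args, **kwargs)
    except TypeError as exc:
        return str(exc)
    return None


# b is given both positionally (2) and by keyword (b=2).
call_message = type_error_message(func, 1, 2, b=2, c=5)
bind_message = type_error_message(inspect.signature(func).bind, 1, 2, b=2, c=5)
# Both messages mention "multiple values for argument 'b'".
print(call_message)
print(bind_message)
```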

Upvotes: 3

MegaIng

Reputation: 7886

I am actually surprised how short this answer is:

import inspect


class VariableCapture:
    def __init__(self, func):
        self._func = func
        self._sign = inspect.signature(func)
        self._bound = None

    def set_args(self, *args, **kwargs):
        self._bound = self._sign.bind(*args, **kwargs)
        self._bound.apply_defaults()

    def call(self):
        assert self._bound is not None
        return self._func(*self._bound.args, **self._bound.kwargs)

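This does everything you want: bind raises a TypeError for arguments that don't match the signature, just as a real call would. To do something with the actual arguments (including defaults), use self._bound.arguments, which is a dictionary containing all passed and default values.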
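A short sketch of how the bound arguments serve the question's actual goal (deciding whether two calls are equivalent). It uses inspect.signature(...).bind(...) and apply_defaults() directly; the helper name normalized_arguments and the example functions f and g are hypothetical:

```python
import inspect


def normalized_arguments(func, *args, **kwargs):
    """Hypothetical helper: bind arguments to func's signature, fill in defaults,
    and return the resulting name -> value mapping.

    Raises TypeError, like a real call would, if the arguments don't fit.
    """
    bound = inspect.signature(func).bind(*args, **kwargs)
    bound.apply_defaults()
    return dict(bound.arguments)


def f(a, b, c=5):
    return a + b + c


def g(*args, **kwargs):
    return args, kwargs


# Explicitly passing a default does not change the normalized arguments ...
print(normalized_arguments(f, 1, 2) == normalized_arguments(f, 1, 2, c=5))  # True
# ... nor does passing a positional-or-keyword parameter by name.
print(normalized_arguments(f, 1, 2) == normalized_arguments(f, a=1, b=2))   # True
# For *args/**kwargs the two call styles stay distinguishable.
print(normalized_arguments(g, 1, 2))      # {'args': (1, 2), 'kwargs': {}}
print(normalized_arguments(g, a=1, b=2))  # {'args': (), 'kwargs': {'a': 1, 'b': 2}}
```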

(Note that there is one theoretical difference, which should be undetectable from anything but 'malicious' C code: the replayed call passes the defaults explicitly instead of letting Python fill them in automatically.)
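The difference mentioned in that note is visible in BoundArguments.args: after apply_defaults() the default value becomes part of the replayed call. A minimal illustration with a hypothetical f:

```python
import inspect


def f(a, b=5):
    return a + b


sig = inspect.signature(f)

bound = sig.bind(1)
before = bound.args    # only what the caller passed
bound.apply_defaults()
after = bound.args     # the default is now included explicitly
print(before)  # (1,)
print(after)   # (1, 5)
```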

Upvotes: 4
