John Scolaro

Reputation: 702

Pass additional variables to a function variable

I am using a function in TensorFlow which maps a set of tensors to another arrangement of tensors. For example, you might write:

def _function(a, b, c):
    return (a + 1, b, c)

data = data.map(_function)

So here, you pass _function itself to map as a function object. map then calls it with the three tensors that make up each element, which are transformed in some way (here, just adding one to a) and returned as a new tuple.


My question is: Is there a way to pass in additional variables to _function?

If I want to compute a + x instead of a + 1, how can I pass in the additional variable x?

You can't do something like data.map(_function(x)), because that calls _function immediately and passes its result to map, not the function itself.

I've experimented with *args, but I can't find a way. Any help is greatly appreciated.

Upvotes: 0

Views: 63

Answers (1)

meili

Reputation: 301

You can do something like:

def extra_func(x):
    def _function(a, b, c):
        return (a + x, b, c)
    return _function

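Then you can call data.map(extra_func(x)): extra_func(x) returns the inner _function, which still takes exactly a, b and c, so map can call it as before while x is captured from the enclosing scope.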
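A runnable sketch of the closure approach follows. It reuses a hypothetical `fake_map` helper as a stand-in for `Dataset.map` (assumed to unpack each tuple element into positional arguments); with a real dataset the call would be `data.map(extra_func(5))`.

```python
def fake_map(func, elements):
    # Hypothetical stand-in for Dataset.map: unpack each tuple into args.
    return [func(*element) for element in elements]


def extra_func(x):
    # The outer function captures x; the inner function keeps the
    # three-argument signature that map expects.
    def _function(a, b, c):
        return (a + x, b, c)
    return _function


data = [(1, "x", True), (10, "y", False)]
add_five = extra_func(5)          # this is a function, not a result
print(fake_map(add_five, data))   # [(6, 'x', True), (15, 'y', False)]
```

Each call to extra_func produces an independent function, so extra_func(1) and extra_func(5) can be used side by side.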

Alternatively, you can use functools.partial to fix some of a function's parameters in advance.

Upvotes: 3
