I am using TensorFlow's map function, which maps a set of tensors to another arrangement of tensors. For example, you might write:
def _function(a, b, c):
    return (a + 1, b, c)

data = data.map(_function)
Here, you pass _function itself to map, and map calls it with three tensors. The function transforms them in some way (here, just adding one to a) and returns the results.
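To show the calling convention without needing TensorFlow installed, here is a minimal sketch that uses itertools.starmap as a stand-in for Dataset.map. Like Dataset.map on a dataset whose elements are tuples, starmap unpacks each tuple into positional arguments. Plain Python values stand in for tensors, and the records list is hypothetical sample data.

```python
from itertools import starmap

def _function(a, b, c):
    # Receives one element's three components as positional arguments
    # and returns a new tuple; nothing is modified in place.
    return (a + 1, b, c)

# Hypothetical elements; each tuple plays the role of one (a, b, c) element.
records = [(1, "p", 10), (2, "q", 20)]

# starmap calls _function(*element) for each element, similar to how
# Dataset.map unpacks tuple elements before calling its function.
mapped = list(starmap(_function, records))
print(mapped)
```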
My question is: is there a way to pass additional variables to _function?
If I want to compute a + x instead of a + 1, how can I pass in the additional variable x?
You can't do something like data.map(_function(x)), because then you're passing the result of calling the function, not the function itself.
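The eager evaluation is easy to confirm: Python evaluates _function(x) before map ever receives anything, and since _function expects three arguments, the call fails immediately. A minimal check with plain Python values (x = 5 is an arbitrary example value):

```python
def _function(a, b, c):
    return (a + 1, b, c)

x = 5
error_message = None

try:
    # Evaluated right away, before any map() call could receive it.
    _function(x)
except TypeError as err:
    error_message = str(err)

print(error_message)
```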
I've experimented with *args, but I can't find a way. Any help is greatly appreciated.
You can wrap _function in an outer function that captures x (a closure), something like:
def extra_func(x):
    def _function(a, b, c):
        return (a + x, b, c)
    return _function
Then you can call data.map(extra_func(x)), because extra_func(x) returns a function rather than a result.
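A runnable sketch of the closure approach, again using itertools.starmap as a stand-in for Dataset.map so it runs without TensorFlow. The sample tuples and the value 5 are hypothetical. With a real dataset, the call would be data.map(extra_func(x)) as above.

```python
from itertools import starmap

def extra_func(x):
    # The outer function captures x; the inner function keeps the
    # three-argument signature that map expects.
    def _function(a, b, c):
        return (a + x, b, c)
    return _function

# Hypothetical (a, b, c) elements.
records = [(1, "p", 10), (2, "q", 20)]

add_five = extra_func(5)  # a function, not a result
mapped = list(starmap(add_five, records))
print(mapped)
```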
Alternatively, you can use functools.partial to fix some of a function's parameters.
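With functools.partial, x becomes an ordinary parameter of _function that you bind ahead of time. In this sketch x is made keyword-only, which is an assumption of the example rather than something the answer specifies, so the three positional components supplied per element cannot collide with it. As before, itertools.starmap stands in for Dataset.map and the sample data is hypothetical.

```python
from functools import partial
from itertools import starmap

def _function(a, b, c, *, x):
    # x is keyword-only, so partial(..., x=...) binds it without
    # interfering with the positional a, b, c supplied per element.
    return (a + x, b, c)

# Hypothetical (a, b, c) elements.
records = [(1, "p", 10), (2, "q", 20)]

add_five = partial(_function, x=5)
mapped = list(starmap(add_five, records))
print(mapped)
```

With TensorFlow, the equivalent call should be data.map(partial(_function, x=5)), since map accepts any callable.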