ElleryL

Reputation: 527

tf.py_function not able to return a list?

Consider the following:

>>> f = lambda x:[x+1,x+2,x+3,x+4]
>>> tf.py_function(f,[1],[tf.int32])
    [<tf.Tensor: id=7370, shape=(), dtype=int32, numpy=2>]

However, when I do this:

>>> f = lambda x,y:([x+1,x+2,x+3,x+4,x+5],[x+1,x+2,x+3,x+4,x+5])
>>> tf.py_function(f,[1,1],[tf.int32,tf.int64])
    [<tf.Tensor: id=7509, shape=(5,), dtype=int32, numpy=array([2, 3, 4, 5, 6])>,
     <tf.Tensor: id=7510, shape=(5,), dtype=int64, numpy=array([2, 3, 4, 5, 6], dtype=int64)>]

I find this weird. In the first example, I thought it didn't return the whole list, only its first value, because my return type is Tout=tf.int32, so it assumes the function returns a single integer rather than a list.

However, in my second example, where Tout=[tf.int32,tf.int64], it returns two lists: the first as tf.int32 and the second as tf.int64. This shows that tf.int32 doesn't mean "return a single integer"; it can still represent a list of integers.
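A minimal sketch of the rule behind both results, assuming TensorFlow 2.x in eager mode: when the function returns a Python list or tuple, `tf.py_function` treats its top-level elements as the separate outputs and pairs them, in order, with the entries of `Tout`. Elements beyond the length of `Tout` appear to be dropped silently (as in the first example above); that is observed behavior, not something the documentation promises. The helper names `four_values` and `two_lists` are hypothetical, and both outputs in the second call use `tf.int32` to avoid any dtype-conversion questions.

```python
import tensorflow as tf

def four_values(x):
    # Returns a plain Python list with four top-level elements.
    return [x + 1, x + 2, x + 3, x + 4]

# Tout lists one dtype, so only the first top-level element (x + 1) becomes an output.
truncated = tf.py_function(four_values, [1], [tf.int32])

def two_lists(x, y):
    # Returns a tuple of two lists; each inner list is paired with one Tout entry.
    return ([x + 1, x + 2], [y + 1, y + 2])

# Tout lists two dtypes, so each inner list is converted to its own shape-(2,) tensor.
paired = tf.py_function(two_lists, [1, 1], [tf.int32, tf.int32])

print(truncated)
print(paired)
```

In other words, the outer list is the list of outputs, not the value of a single output.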

Does anyone know how to fix this so that it returns the full list of values?

Upvotes: 2

Views: 1957

Answers (1)

user11530462


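It is actually pretty simple: wrap the list in an outer list. tf.py_function treats the outer list as the list of outputs (one per entry in Tout), so the inner list becomes a single int32 tensor.

Fixed code: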

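import tensorflow as tf
# The outer list holds one output per Tout entry; the inner list becomes one tensor.
f = lambda x: [[x + 1, x + 2, x + 3, x + 4]]
tf.py_function(f, [1], [tf.int32])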

Output:

[<tf.Tensor: shape=(4,), dtype=int32, numpy=array([2, 3, 4, 5], dtype=int32)>]

Hope this answers your question. Happy Learning.
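An alternative sketch, assuming TensorFlow 2.x in eager mode: return a single tensor instead of a Python list, and pass a single dtype (not a list) as Tout. Because the return value is not a list or tuple, tf.py_function converts it as one output, and with a non-list Tout it returns a single tensor rather than a one-element list. The helper name `four_values_as_tensor` is hypothetical.

```python
import tensorflow as tf

def four_values_as_tensor(x):
    # tf.stack packs the four scalar tensors into one tensor of shape (4,),
    # so py_function sees a single return value rather than a list of outputs.
    return tf.stack([x + 1, x + 2, x + 3, x + 4])

# Tout is a single dtype, so py_function returns one tensor, not a list.
result = tf.py_function(four_values_as_tensor, [1], tf.int32)
print(result)
```

If you do want several outputs, keep Tout a list and return a list or tuple with exactly one top-level element per entry, as in the fixed code above.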

Upvotes: 5
