George

TensorFlow table lookup int->float

Given a 2D Tensor of unknown dimensions [?, ?] containing integers (representing classes), I would like to obtain a new Tensor of the same shape, but with the values replaced by floats taken from a lookup table (representing class weights).

For example:

inputs = [ [1,3,3], [2,4,2] ]
lookup table: {1: 0.2, 2: 0.25, 3: 0.1, 4: 0.45}
output: [ [0.2, 0.1, 0.1], [0.25, 0.45, 0.25] ]
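To pin down the semantics, here is the desired mapping as a plain-Python reference. It involves no TensorFlow, and the names `class_weights` and `lookup_weights` are illustrative only. It is a per-element dictionary lookup that keeps the nested 2D shape:

```python
from typing import Dict, List

# Lookup table from the example: class id -> class weight.
class_weights: Dict[int, float] = {1: 0.2, 2: 0.25, 3: 0.1, 4: 0.45}


def lookup_weights(inputs: List[List[int]], table: Dict[int, float]) -> List[List[float]]:
    """Replace every class id with its weight, keeping the 2D shape."""
    return [[table[c] for c in row] for row in inputs]


inputs = [[1, 3, 3], [2, 4, 2]]
output = lookup_weights(inputs, class_weights)
print(output)  # [[0.2, 0.1, 0.1], [0.25, 0.45, 0.25]]
```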

I have tried to chain two lambda functions with tf.map_fn, iterating over every row, then over every element:

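elem_iter = lambda y: unknown_lookup_function(y)  # the missing piece: class id -> weight
row_iter = lambda x: tf.map_fn(elem_iter, x, dtype=tf.float32)  # map over the elements of one row
weights = tf.map_fn(row_iter, inputs, dtype=tf.float32)  # map over the rows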

But I could not find a proper way to define the lookup function. Any advice on how to implement this behaviour? Is there a native op that I could use instead of map_fn?

Answers (1)

greeness

I think you want to use tf.gather:

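The idea is to store the lookup table as a dense array: at index i, you store the lookup value for class i, and tf.gather then picks one entry of that array for every element of inputs. This relies on the keys being small non-negative integers. If your keys are strings rather than integers, you would first need to map them to integer indices, e.g. with index_table_from_file.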
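Instead of writing the padded array by hand, it can be derived from the dictionary. Below is a minimal plain-Python sketch of the idea, assuming the keys are small non-negative integers. The helpers `build_dense_table` and `gather` are hypothetical names, and `gather` only mimics what `tf.gather` does along axis 0 for nested index lists; it is not TensorFlow code.

```python
from typing import Dict, List, Union


def build_dense_table(table: Dict[int, float], fill: float = 0.0) -> List[float]:
    """Turn {key: value} into a list where position `key` holds `value`.

    Positions with no key in the dict (index 0 in the example) get the
    dummy `fill` value, which is what the hand-written padding does.
    """
    if any(k < 0 for k in table):
        raise ValueError("keys must be non-negative to be used as indices")
    dense = [fill] * (max(table) + 1)
    for key, value in table.items():
        dense[key] = value
    return dense


def gather(params: List[float], indices: Union[int, list]):
    """Return params[i] for every integer i in a (possibly nested) index list.

    The result has the same nesting as `indices`, like tf.gather with a 1-D
    `params`. Negative indices are rejected explicitly, because plain Python
    list indexing would otherwise wrap around silently.
    """
    if isinstance(indices, list):
        return [gather(params, sub) for sub in indices]
    if indices < 0:
        raise IndexError(f"negative index {indices}")
    return params[indices]


table = build_dense_table({1: 0.2, 2: 0.25, 3: 0.1, 4: 0.45})
weights = gather(table, [[1, 3, 3], [2, 4, 2]])
print(table)    # [0.0, 0.2, 0.25, 0.1, 0.45]
print(weights)  # [[0.2, 0.1, 0.1], [0.25, 0.45, 0.25]]
```

The dense list grows with the largest key, so this suits compact class ids like the ones in the question. For sparse or very large ids, a hash-table style lookup would waste less memory.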

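import tensorflow as tf  # TensorFlow 1.x style (tf.Session)

# Index 0 holds a dummy value because the class ids start at 1.
lookup_table = tf.constant([0.0, 0.2, 0.25, 0.1, 0.45])

inputs = tf.constant([ [1,3,3], [2,4,2] ])
# tf.gather replaces every id in `inputs` with lookup_table[id],
# so `output` has the same shape as `inputs`.
output = tf.gather(lookup_table, inputs)
with tf.Session() as sess:  # in TF 2.x (eager), print(output.numpy()) instead
  print(sess.run(output))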

Output:

[[ 0.2         0.1         0.1       ]
 [ 0.25        0.44999999  0.25      ]]
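The `0.44999999` in the printed result is not an error in the lookup. `tf.constant` stores these Python floats as `float32` by default, and 0.45 has no exact `float32` representation. The following standard-library sketch uses `struct` as a stand-in for TensorFlow's conversion: it rounds 0.45 to single precision and reproduces the printed digits. The helper `to_float32` is a hypothetical name.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 and back."""
    return struct.unpack("f", struct.pack("f", x))[0]


w = to_float32(0.45)
print(f"{w:.8f}")  # 0.44999999
print(w == 0.45)   # False: the stored weight is slightly below 0.45
```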
