Matt Camp

Reputation: 1528

How do you decode one-hot labels in TensorFlow?

I've been looking, but can't find any examples of how to decode a one-hot value back into a single integer in TensorFlow.

I used tf.one_hot and was able to train my model, but I'm a bit confused about how to make sense of the label after classification. My data is fed in via a TFRecords file that I created. I thought about storing a text label in the file but wasn't able to get it to work. It appeared that TFRecords couldn't store text strings, or maybe I was mistaken.

Upvotes: 16

Views: 33078

Answers (4)

mrk

Reputation: 10366

tf.argmax now lives under tf.math (the old documentation links in the answers on this page are thus 404), so tf.math.argmax is the name to look up. tf.argmax remains available as an alias.

Usage:

import tensorflow as tf

a = [1, 10, 26.9, 2.8, 166.32, 62.3]
b = tf.math.argmax(input=a)
# Assuming TensorFlow 2.x (eager execution), the tensor converts directly to NumPy.
c = b.numpy()
# c = 4
# here a[4] = 166.32, which is the largest element of a across axis 0

Note: You can also do this with numpy.
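A minimal NumPy sketch of the same idea, assuming the values are available as plain arrays rather than tensors. The names labels, one_hot and decoded are illustrative only and do not come from the question.

```python
import numpy as np

# Same values as in the TensorFlow snippet above.
a = np.array([1, 10, 26.9, 2.8, 166.32, 62.3])
c = int(np.argmax(a))  # index of the largest element

# Round trip for one-hot labels: rows of an identity matrix encode,
# argmax along the class axis decodes.
labels = np.array([2, 0, 3])            # hypothetical integer labels
one_hot = np.eye(4, dtype=int)[labels]  # shape (3, 4)
decoded = np.argmax(one_hot, axis=1)    # back to integer labels

print(c, decoded)
```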

Upvotes: 0

Rochan

Reputation: 1542

import numpy as np
from tensorflow.keras.utils import to_categorical

data = np.array([1, 5, 3, 8])
print(data)


def encode(data):
    print('Shape of data (BEFORE encode): %s' % str(data.shape))
    encoded = to_categorical(data)
    print('Shape of data (AFTER  encode): %s\n' % str(encoded.shape))
    return encoded


encoded_data = encode(data)
print(encoded_data)

def decode(datum):
    return np.argmax(datum)

decoded_Y = []
print("****************************************")
for i in range(encoded_data.shape[0]):
    datum = encoded_data[i]
    print('index: %d' % i)
    print('encoded datum: %s' % datum)
    decoded_datum = decode(encoded_data[i])
    print('decoded datum: %s' % decoded_datum)
    decoded_Y.append(decoded_datum)


print("****************************************")

print(decoded_Y)

Upvotes: 1

martianwars

Reputation: 6500

You can find the index of the largest element using tf.argmax. Since a one-hot vector is one-dimensional, with a single 1 and 0s everywhere else, that index is the original label. The following works assuming you are dealing with a single vector.

index = tf.argmax(one_hot_vector, axis=0)

For the more standard matrix of batch_size * num_classes, use axis=1 to get a result of shape (batch_size,), with one index per row.
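To make the axis choice concrete, here is a small sketch using np.argmax, which takes an axis argument in the same way as tf.argmax when the axis is given explicitly. The arrays single and batch are made-up inputs.

```python
import numpy as np

single = np.array([0, 0, 1, 0])      # one one-hot vector
batch = np.array([[0, 1, 0, 0],      # batch_size = 3, num_classes = 4
                  [1, 0, 0, 0],
                  [0, 0, 0, 1]])

index = np.argmax(single, axis=0)    # a single integer
indices = np.argmax(batch, axis=1)   # one integer per row

print(index, indices, indices.shape)
```

Using axis=0 on the batch would instead return one index per class column, which is not the label of any example.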

Upvotes: 27

mrry

Reputation: 126154

Since a one-hot encoding is typically just a matrix with batch_size rows and num_classes columns, and each row is all zero with a single non-zero corresponding to the chosen class, you can use tf.argmax() to recover a vector of integer labels:

import tensorflow as tf

BATCH_SIZE = 3
NUM_CLASSES = 4
one_hot_encoded = tf.constant([[0, 1, 0, 0],
                               [1, 0, 0, 0],
                               [0, 0, 0, 1]])

# Compute the argmax across the columns.
decoded = tf.argmax(one_hot_encoded, axis=1)

# Assuming TensorFlow 2.x (eager execution); in 1.x, use sess.run(decoded) instead.
print(decoded.numpy())  # ==> [1 0 3]
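If the goal is a readable text label rather than an integer, one option is to keep the class names in an ordinary Python list and index into it with the decoded values, instead of storing strings alongside the data. This is only a sketch: CLASS_NAMES is a hypothetical list, and its order is assumed to match the integer labels that were passed to tf.one_hot.

```python
# Hypothetical class names; position i is the name of class i.
CLASS_NAMES = ["cat", "dog", "bird", "fish"]

# Integer labels as recovered by the argmax step above.
decoded = [1, 0, 3]

# Look up the text label for each decoded index.
names = [CLASS_NAMES[i] for i in decoded]
print(names)  # ['dog', 'cat', 'fish']
```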

Upvotes: 13
