Reputation: 1987
I currently have one-hot encodings that I want to replace with embeddings. However, when I call
embed = tf.nn.embedding_lookup(embeddings, train_data)
print(embed.get_shape())
the printed shape is (11, 32, 729, 128).
It should be (11, 32, 128); the extra dimension appears because train_data is one-hot encoded. My attempt to convert it to indices,
train_data2 = tf.matmul(train_data, tf.range(729))
gives me the error:
ValueError: Shape must be rank 2 but is rank 3
Help me out please! Thanks.
Upvotes: 4
Views: 1437
Reputation: 554
A small fix to your example: tf.matmul expects rank-2 operands, and tf.range(729) is rank 1, so turn the range into a column vector first:
import numpy as np
import tensorflow as tf

encoding_size = 4
one_hot_batch = tf.constant([[0, 0, 0, 1], [0, 1, 0, 0], [1, 0, 0, 0]])
# Multiplying by the column vector [0, 1, ..., encoding_size - 1] collapses
# each one-hot row to the index of its single 1.
one_hot_indexes = tf.matmul(one_hot_batch,
                            np.array([range(encoding_size)], dtype=np.int32).T)
with tf.Session() as session:
    print(one_hot_indexes.eval())
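Since your train_data is rank 3, (11, 32, 729), you would flatten it to a matrix before the matmul and restore the leading dimensions afterwards (a sketch, assuming train_data is an int32 one-hot tensor):
flat = tf.reshape(train_data, [-1, 729])                        # (11*32, 729)
ids = tf.matmul(flat, np.arange(729, dtype=np.int32)[:, None])  # (11*32, 1)
ids = tf.reshape(ids, [11, 32])                                 # (11, 32)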
Another way:
batch_size = 3
one_hot_batch = tf.constant([[0, 0, 0, 1], [0, 1, 0, 0], [1, 0, 0, 0]])
# tf.where returns the (row, column) coordinates of every non-zero entry;
# column 1 is the position of the 1 within each row.
one_hot_indexes = tf.where(tf.not_equal(one_hot_batch, 0))
one_hot_indexes = one_hot_indexes[:, 1]
one_hot_indexes = tf.reshape(one_hot_indexes, [batch_size, 1])
with tf.Session() as session:
    print(one_hot_indexes.eval())
Result in both cases:
[[3]
[1]
[0]]
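For your original goal, tf.argmax over the one-hot axis does the same conversion in one step and keeps the leading dimensions, so the lookup comes out as (11, 32, 128) (a sketch, assuming TF 1.x; the shapes and names are taken from your question):
import tensorflow as tf

vocab_size, embed_dim = 729, 128
embeddings = tf.Variable(tf.random_uniform([vocab_size, embed_dim], -1.0, 1.0))
train_data = tf.placeholder(tf.float32, [11, 32, vocab_size])  # one-hot input

# argmax over the last axis recovers the integer ids, shape (11, 32)
ids = tf.argmax(train_data, axis=2)

embed = tf.nn.embedding_lookup(embeddings, ids)  # shape (11, 32, 128)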
Upvotes: 2