I am new to TensorFlow. I have trained an encoder-decoder model with TensorFlow. The program takes a word as input and prints out its phonemes.
For example: Hello World -> ['h', 'E', 'l', '"', '@U', ' ', 'w', '"', '3`', 'r', '5', 'd']
I would like to have access to the prediction probability of each phoneme chosen.
In the prediction section, the code I am using is the following:
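def predict(words, sess):
    # Handle at most hp.batch_size words per call; the rest are predicted recursively
    if len(words) > hp.batch_size:
        after = predict(words[hp.batch_size:], sess)
        words = words[:hp.batch_size]
    else:
        after = []
    # Encode each word (with "E" appended to mark its end) as grapheme indices;
    # unknown characters fall back to index 2
    x = np.zeros((len(words), hp.maxlen), np.int32)  # 0: <PAD>
    for i, w in enumerate(words):
        for j, g in enumerate((w + "E")[:hp.maxlen]):
            x[i][j] = g2idx.get(g, 2)
    # Autoregressive decoding: the phonemes predicted so far are fed back as decoder input
    preds = np.zeros((len(x), hp.maxlen), np.int32)
    for j in range(hp.maxlen):
        xpreds = sess.run(graph.preds, {graph.x: x, graph.y: preds})
        preds[:, j] = xpreds[:, j]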
Thank you in advance!
My main problem is finding where these probabilities are "hidden" and how to access them. For example, the letter "o" in the word "Hello" was mapped to the phoneme "@U". I would like to find out with what probability "@U" was chosen as the best phoneme.
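To see where such a probability comes from, here is a toy NumPy sketch of the same greedy autoregressive loop. `fake_logits` is a hypothetical stand-in for running the model's logits tensor through `sess.run` (it is not part of the trained model); the point is that the network scores every candidate phoneme at every position, a softmax turns those scores into probabilities, and an argmax keeps only the winning index, which is where the probability gets discarded.

```python
import numpy as np

def softmax(z):
    # Subtract the max for numerical stability, then normalise over the last axis
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def fake_logits(x, y, vocab_size=5):
    # Hypothetical stand-in for sess.run(logits, {graph.x: x, graph.y: y}).
    # Returns deterministic scores of shape (batch, maxlen, vocab_size).
    batch, maxlen = y.shape
    scores = np.zeros((batch, maxlen, vocab_size))
    for j in range(maxlen):
        scores[:, j, (j + 1) % vocab_size] = 2.0
    return scores

def greedy_decode(x, maxlen):
    preds = np.zeros((len(x), maxlen), np.int32)    # chosen phoneme indices
    probs = np.zeros((len(x), maxlen), np.float32)  # probability of each choice
    for j in range(maxlen):
        p = softmax(fake_logits(x, preds))          # (batch, maxlen, vocab)
        preds[:, j] = p[:, j].argmax(axis=-1)       # what tf.argmax keeps
        probs[:, j] = p[:, j].max(axis=-1)          # what tf.argmax throws away
    return preds, probs

x = np.zeros((2, 4), np.int32)
preds, probs = greedy_decode(x, maxlen=4)
print(preds[0])  # [1 2 3 4]
print(probs[0])
```

In the real model the scores come from the decoder, so each step's probability depends on the input word and on the phonemes already chosen.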
Following the discussion, I think I can point you to where the code should be changed. In train.py, line 104:
self.preds = tf.to_int32(tf.argmax(logits, -1))
This sets preds to the index of the highest-scoring phoneme at each position, so the scores themselves are thrown away. To get the softmax probabilities instead, you can change the line as follows:
self.preds = tf.nn.softmax(logits)
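I think that should do it: graph.preds then returns a probability for every candidate phoneme at each position instead of a single index, and the chosen phoneme is still the one with the highest probability.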
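Because softmax preserves the ordering of the scores, switching from logits to softmax does not change which phoneme wins; it only rescales the scores into probabilities that sum to 1. A small NumPy check using made-up logits (not real model output), which also shows one way to pull out the probability of each chosen phoneme with `np.take_along_axis`:

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax over the last axis
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Made-up logits for 1 word, 3 positions, 4 candidate phonemes
logits = np.array([[[1.0, 3.0, 0.5, 0.0],
                    [2.0, 0.1, 0.1, 0.1],
                    [0.0, 0.0, 0.0, 4.0]]])
probs = softmax(logits)

idx = probs.argmax(axis=-1)                   # same indices as tf.argmax(logits, -1)
assert (idx == logits.argmax(axis=-1)).all()  # softmax keeps the ranking
# Probability of the chosen phoneme at each position, shape (batch, maxlen)
chosen_p = np.take_along_axis(probs, idx[..., None], axis=-1)[..., 0]

print(idx)                 # [[1 0 3]]
print(chosen_p)
print(probs.sum(axis=-1))  # every distribution sums to 1
```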
How to view the probabilities:
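preds = np.zeros((len(x), hp.maxlen), np.int32)    # decoder input must stay integer phoneme indices
probs = np.zeros((len(x), hp.maxlen), np.float32)  # probability of each chosen phoneme
for j in range(hp.maxlen):
    xpreds = sess.run(graph.preds, {graph.x: x, graph.y: preds})
    # print shape of output -> (batch_size, max_length, number_of_output_options)
    print(xpreds.shape)
    # print the full probability distribution for the first word at the first position
    print(xpreds[0, 0])
    # print the probability of the phoneme the network picked there
    print(xpreds[0, 0, np.argmax(xpreds[0, 0])])
    # accumulate the chosen index and its probability for step j
    preds[:, j] = np.argmax(xpreds[:, j], axis=-1)
    probs[:, j] = np.max(xpreds[:, j], axis=-1)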
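Once preds and probs are filled in, each phoneme can be paired with the probability it was chosen with, which answers the "@U" question directly. A sketch, assuming a hypothetical index-to-phoneme table `idx2p` and an `<EOS>` end marker (your project's vocabulary will differ); it also multiplies the per-phoneme probabilities (excluding the `<EOS>` step) into a probability for the whole greedy sequence, summing logs to avoid underflow on long words:

```python
import math

# Hypothetical vocabulary; replace with the project's own index-to-phoneme mapping.
idx2p = {0: "<PAD>", 1: "<EOS>", 2: "h", 3: "E", 4: "l", 5: "@U"}

def phonemes_with_probs(pred_row, prob_row, idx2p):
    """Pair each decoded phoneme with its probability, stopping at <EOS>."""
    out = []
    for idx, p in zip(pred_row, prob_row):
        ph = idx2p[int(idx)]
        if ph == "<EOS>":
            break
        out.append((ph, float(p)))
    return out

# Made-up decoder output for one word (one row of preds and probs)
pred_row = [2, 3, 4, 5, 1, 0]
prob_row = [0.98, 0.91, 0.97, 0.85, 0.99, 0.5]

pairs = phonemes_with_probs(pred_row, prob_row, idx2p)
print(pairs)  # [('h', 0.98), ('E', 0.91), ('l', 0.97), ('@U', 0.85)]
seq_logp = sum(math.log(p) for _, p in pairs)
print(math.exp(seq_logp))  # probability of the phoneme sequence, excluding <EOS>
```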