Ophir Yoktan

Reputation: 8449

Is it possible to visualize Keras embeddings in TensorBoard?

Keras can export some of its training data in a TensorBoard-compatible format using keras.callbacks.TensorBoard.

However, it doesn't support embedding visualization in TensorBoard.
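
For reference, this is the basic usage I mean (a minimal sketch; model, x_train and y_train are placeholders):

from keras.callbacks import TensorBoard

# Minimal sketch: writes scalar summaries and the graph, but no embedding data.
tb = TensorBoard(log_dir='/tmp/logs', histogram_freq=1, write_graph=True)
model.fit(x_train, y_train, callbacks=[tb])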

Is there a way around this?

Upvotes: 14

Views: 10721

Answers (3)

Michael Litvin

Reputation: 4126

This is now possible directly with the keras.callbacks.TensorBoard callback:

from keras import callbacks

model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=10,
          callbacks=[
              callbacks.TensorBoard(batch_size=batch_size,
                                    embeddings_freq=3,  # Store embeddings every 3 epochs (this can be time consuming)
                                    embeddings_layer_names=['fc1', 'fc2'],  # Take embeddings from the layers named fc1 and fc2
                                    embeddings_metadata='metadata.tsv',  # This file describes the embedding data (see below)
                                    embeddings_data=x_test),  # Data fed through the model to produce the embeddings
          ])


# Before the model is trained, write a metadata.tsv that labels each embedded point:
with open("metadata.tsv", 'w') as f:
    f.write("label\tidx\n")
    f.write('\n'.join(["{}\t{}".format(class_names[int(y.argmax())], i)
                       for i, y in enumerate(y_test)]))


# After the model is trained, you can update the metadata file to include more information, such as the predicted labels and the mistakes:
y_pred = model.predict(x_test)
with open("metadata.tsv", 'w') as f:
    f.write("label\tidx\tpredicted\tcorrect\n")
    f.write('\n'.join(["{}\t{}\t{}\t{}".format(class_names[int(y.argmax())],
                                               i,
                                               class_names[int(y_pred[i].argmax())],
                                               class_names[int(y.argmax())]==class_names[int(y_pred[i].argmax())])
                       for i, y in enumerate(y_test)]))

Note: TensorBoard usually looks for your metadata.tsv in the logs directory. If it doesn't find it there, it reports the path where it looked; copy the file to that path and refresh TensorBoard.
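
If you want to script that copy, a minimal sketch (both paths are placeholders: use the log directory you passed to the callback, or the path TensorBoard reports):

import shutil

# Copy the metadata file to where TensorBoard said it was looking.
shutil.copy('metadata.tsv', '/tmp/logs/metadata.tsv')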

Upvotes: 1

Dmitry Ziolkovskiy

Reputation: 4054

There is a pull request with this functionality: https://github.com/fchollet/keras/pull/5247. It extends the TensorBoard callback to create visualizations for specific embedding layers.
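
A minimal sketch of the extended interface (the layer name 'embedding' and the metadata file are placeholders; Michael Litvin's answer above shows a fuller example):

from keras.callbacks import TensorBoard

tb = TensorBoard(log_dir='/tmp/logs',
                 embeddings_freq=1,  # write embeddings every epoch
                 embeddings_layer_names=['embedding'],  # layers to visualize
                 embeddings_metadata='metadata.tsv')  # labels for the embedded rows
model.fit(x_train, y_train, callbacks=[tb])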

Upvotes: 5

Ophir Yoktan

Reputation: 8449

Found a solution:

import os

import keras
import tensorflow

ROOT_DIR = '/tmp/tfboard'

os.makedirs(ROOT_DIR, exist_ok=True)

OUTPUT_MODEL_FILE_NAME = os.path.join(ROOT_DIR, 'tf.ckpt')

# get the keras model
model = get_model()

# get the tensor name from the embedding layer
# (.W is the Keras 1.x weights attribute; in Keras 2 the embedding
# weights live on layer.embeddings instead)
tensor_name = next(filter(lambda x: x.name == 'embedding', model.layers)).W.name

# the vocabulary: with columns=[] and header=False, to_csv writes only the
# DataFrame index (the words), one per line
metadata_file_name = os.path.join(ROOT_DIR, tensor_name)

embedding_df = get_embedding()
embedding_df.to_csv(metadata_file_name, header=False, columns=[])

# save the weights as a TensorFlow checkpoint so the projector can load them
saver = tensorflow.train.Saver()
saver.save(keras.backend.get_session(), OUTPUT_MODEL_FILE_NAME)

# note: tf.train.SummaryWriter was renamed to tf.summary.FileWriter in TF 1.0
summary_writer = tensorflow.train.SummaryWriter(ROOT_DIR)

# point the TensorBoard projector at the embedding tensor and its metadata
config = tensorflow.contrib.tensorboard.plugins.projector.ProjectorConfig()
embedding = config.embeddings.add()
embedding.tensor_name = tensor_name
embedding.metadata_path = metadata_file_name
tensorflow.contrib.tensorboard.plugins.projector.visualize_embeddings(summary_writer, config)
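
Then launch TensorBoard against the same directory (tensorboard --logdir=/tmp/tfboard) and open the embeddings/projector tab.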

Upvotes: 9
