deathholes

Reputation: 131

Using shared variables across sessions in tensorflow

I want to train a model and at the same time use the results of the model for further actions. The training can be done in the background, but I need the prediction model to be available all the time.

I have an idea of how to do this, but I'm not sure whether it is possible in TensorFlow. I'm thinking of creating separate threads/processes for prediction and training. Each process would run its own session, and the two sessions would share the same variables. That way the training model can update the variables in its own time, and the prediction model can use the latest weights for better predictions.

Is there any way to share variables across sessions, or is there a better way to do this? I've heard that running multiple sessions in TensorFlow is discouraged.

Upvotes: 3

Views: 744

Answers (1)

Dmitry Bufistov

Reputation: 121

If both run on the same machine, can you share a single session between the "predict" and "train" threads? tf.Session().run() calls are thread-safe. Here is a working example:

# Written against TensorFlow 1.x (tf.contrib was removed in 2.x).
import tensorflow as tf
import numpy as np
import time
import threading

# Toy binary task: the label is 1 when the input row sums to >= 0.
N = 128
inputs = tf.placeholder(tf.float32, shape=(None, N))
labels = tf.cast(
    tf.greater_equal(tf.reduce_sum(inputs, axis=-1, keepdims=True), 0),
    tf.float32)

# Two hidden layers followed by a sigmoid output unit.
l1size = 1024
fc1 = tf.contrib.layers.fully_connected(inputs, l1size)
l2size = 128
fc2 = tf.contrib.layers.fully_connected(fc1, l2size)
predictions = tf.contrib.layers.fully_connected(fc2, 1,
                                    activation_fn=tf.nn.sigmoid)
loss = tf.losses.mean_squared_error(labels, predictions)
train_op = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

# One session, and therefore one set of variables, shared by both threads.
session = tf.Session()
session.run(tf.global_variables_initializer())

# Module-level flag that tells both threads when to stop.
keep_going = True

def predict_thread(session):
    # Evaluate the loss on fixed test data once per second.
    test_data = np.random.randn(10, N)
    while keep_going:
        current_loss = session.run(loss, feed_dict={inputs: test_data})
        print("Current loss: %f" % current_loss)
        time.sleep(1.)

def train_thread(session):
    # Run training steps back to back, updating the shared variables.
    train_data = np.random.randn(1024, N)
    while keep_going:
        session.run(train_op, feed_dict={inputs: train_data})

t1 = threading.Thread(target=train_thread, args=(session,))
t2 = threading.Thread(target=predict_thread, args=(session,))
t1.start()
t2.start()
time.sleep(10)   # let both threads run for ten seconds
keep_going = False
t1.join()
t2.join()
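
The example above stops both threads through a module-level flag. As a side note, here is a minimal standard-library sketch of the same two-thread layout that uses threading.Event for shutdown instead. Everything in it is hypothetical and for illustration only: the `state` dict stands in for the variables held by the shared session, and the lock is needed only because a plain dict is used here (it is not something session.run() requires).

```python
import threading
import time

stop = threading.Event()      # replaces the module-level keep_going flag
lock = threading.Lock()       # guards the stand-in state below
state = {"steps": 0}          # hypothetical stand-in for the shared variables
seen = []                     # values observed by the predict worker

def train_worker():
    # Keep "training" until asked to stop.
    while not stop.is_set():
        with lock:
            state["steps"] += 1
        time.sleep(0.001)

def predict_worker():
    # Periodically read the latest state.
    while not stop.is_set():
        with lock:
            seen.append(state["steps"])
        # Like time.sleep(0.01), but returns early once stop is set.
        stop.wait(0.01)

threads = [threading.Thread(target=train_worker),
           threading.Thread(target=predict_worker)]
for t in threads:
    t.start()
time.sleep(0.2)
stop.set()
for t in threads:
    t.join()
```

With Event.wait() in place of time.sleep(), the predict worker exits as soon as the stop signal is set rather than finishing its current sleep interval.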

You can also save/restore your model from time to time if training and prediction are on different machines. This question might be related.
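
A minimal sketch of that save/restore hand-off, assuming the TF 1.x Saver API reached through tf.compat.v1 (so it also runs on TensorFlow 2.x). The single variable `w`, the helper `build_graph`, and the temporary checkpoint directory are hypothetical placeholders, and both sides run in one process here only to keep the sketch short. In practice the trainer and the predictor would be separate programs pointing at a shared checkpoint directory.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()

def build_graph():
    # Both sides must build the same variables so the checkpoint matches.
    graph = tf.Graph()
    with graph.as_default():
        w = tf1.get_variable("w", shape=(),
                             initializer=tf1.zeros_initializer())
        bump = tf1.assign_add(w, 1.0)   # stand-in for a training step
        init = tf1.global_variables_initializer()
        saver = tf1.train.Saver()
    return graph, w, bump, init, saver

ckpt_dir = tempfile.mkdtemp()   # hypothetical shared checkpoint location

# "Training" side: update the variable, then write a checkpoint.
graph, w, bump, init, saver = build_graph()
with tf1.Session(graph=graph) as sess:
    sess.run(init)
    for _ in range(3):
        sess.run(bump)
    saver.save(sess, os.path.join(ckpt_dir, "model"), global_step=3)

# "Prediction" side: a separate graph and session load the newest checkpoint.
graph2, w2, _, _, saver2 = build_graph()
with tf1.Session(graph=graph2) as sess2:
    latest = tf1.train.latest_checkpoint(ckpt_dir)
    saver2.restore(sess2, latest)
    restored = sess2.run(w2)
```

The predictor could call latest_checkpoint() and restore() again on a timer to pick up newer weights. How often to save and reload depends on how stale the prediction model is allowed to be.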

Upvotes: 1
