GoingMyWay

Reputation: 17468

Tensorflow, update weights in multiprocessing

I defined a network in which each name scope holds the weights for one process, and each process assigns its own weights. Here is my demo code:

from multiprocessing import Process

import tensorflow as tf


def init_network(name):
    with tf.name_scope(name):
        x = tf.Variable(int(name))
        return x


def f(name, sess):
    print('step into f()')
    # filter the trainable variables by their scope prefix
    vars = tf.trainable_variables(name)
    print(sess.run(vars[0]))
    sess.run(vars[0].assign(int(name) + 10))


if __name__ == '__main__':
    sess = tf.Session()
    x1 = init_network('1')
    x2 = init_network('2')
    sess.run(tf.global_variables_initializer())
    p1 = Process(target=f, args=('1', sess))
    p2 = Process(target=f, args=('2', sess))

    p1.start()
    p2.start()

    p1.join()
    p2.join()
    print(sess.run([x1, x2]))

The demo code gets stuck, and it seems that sess can't be shared across processes. How can I update weights in a multiprocessing setting?

Upvotes: 2

Views: 551

Answers (1)

GoingMyWay

Reputation: 17468

After googling for a while, I found that a TensorFlow session can't be shared across processes spawned by multiprocessing, so I use threading instead.

from threading import Thread

import tensorflow as tf

def init_network(name):
    with tf.name_scope(name):
        x = tf.Variable(int(name))
        return x

def f(name, sess):
    # the default session and default graph are thread-local,
    # so each new thread must re-enter them explicitly
    with sess.as_default(), sess.graph.as_default():
        print('step into f()')
        vars = tf.trainable_variables(name)
        print(vars)
        sess.run(vars[0].assign(int(name) + 10))
        print(sess.run(vars[0]))


if __name__ == '__main__':
    sess = tf.Session()
    coord = tf.train.Coordinator()

    x1 = init_network('1')
    x2 = init_network('2')
    sess.run(tf.global_variables_initializer())
    print(sess.run([x1, x2]))

    p1 = Thread(target=f, args=('1', sess))
    p2 = Thread(target=f, args=('2', sess))
    p1.start()
    p2.start()
    coord.join([p1, p2])
    print(sess.run([x1, x2]))

It works now. The default session is a property of the current thread: if you create a new thread and wish to use the default session in it, you must explicitly add a with sess.as_default(): block in that thread's function, and you must also explicitly enter a with sess.graph.as_default(): block to make sess.graph the default graph in that thread.

tf.train.Coordinator is a convenient way to join threads. One can also use the plain Thread.join() method to join them.
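For the joining pattern on its own, here is a minimal sketch using plain Thread.join() instead of a Coordinator (TensorFlow omitted; the worker just writes an illustrative `int(name) + offset` result, mirroring the assign above):

```python
from threading import Thread

results = {}

def worker(name, offset):
    # each thread writes its result under its own key,
    # analogous to each thread assigning its own variable
    results[name] = int(name) + offset

threads = [Thread(target=worker, args=(n, 10)) for n in ('1', '2')]
for t in threads:
    t.start()
for t in threads:
    # Thread.join() blocks until that thread finishes,
    # with no Coordinator needed
    t.join()

print(results)  # {'1': 11, '2': 12}
```

After the joins, the main thread can safely read everything the workers produced, just as the final sess.run([x1, x2]) above sees the assigned values.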

Upvotes: 2
