I defined a network in which each name scope contains the weights for one process, and each process assigns its corresponding weights. Here is my demo code:
```python
from multiprocessing import Process
import tensorflow as tf

def init_network(name):
    with tf.name_scope(name):
        x = tf.Variable(int(name))
    return x

def f(name, sess):
    print('step into f()')
    vars = tf.trainable_variables(name)
    print(sess.run(vars[0]))
    sess.run(vars[0].assign(int(name)+10))

if __name__ == '__main__':
    sess = tf.Session()
    x1 = init_network('1')
    x2 = init_network('2')
    sess.run(tf.global_variables_initializer())
    p1 = Process(target=f, args=('1', sess))
    p2 = Process(target=f, args=('2', sess))
    p1.start()
    p2.start()
    p1.join()
    p2.join()
    print(sess.run([x1, x2]))
```
The demo code gets stuck, and it seems that `sess` can't be shared between different processes. How can I update weights in a multiprocessing setting?
After googling for a while, I found that `multiprocessing` doesn't work well with TensorFlow here: a `tf.Session` created in the parent process can't be shared with child processes. So I use `threading` instead:
```python
from threading import Thread
import tensorflow as tf

def init_network(name):
    with tf.name_scope(name):
        x = tf.Variable(int(name))
    return x

def f(name, sess):
    # Make sess and its graph the defaults in this thread; neither carries
    # over automatically from the main thread.
    with sess.as_default(), sess.graph.as_default():
        print('step into f()')
        vars = tf.trainable_variables(name)
        print(vars)
        sess.run(vars[0].assign(int(name)+10))
        print(sess.run(vars[0]))

if __name__ == '__main__':
    sess = tf.Session()
    coord = tf.train.Coordinator()
    x1 = init_network('1')
    x2 = init_network('2')
    sess.run(tf.global_variables_initializer())
    print(sess.run([x1, x2]))
    p1 = Thread(target=f, args=('1', sess))
    p2 = Thread(target=f, args=('2', sess))
    p1.start()
    p2.start()
    coord.join([p1, p2])  # block until both threads have finished
    print(sess.run([x1, x2]))
```
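It works now. The default session is a property of the current thread: if you create a new thread and want to use the default session in it, you must explicitly add a `with sess.as_default():` block in that thread's function. Entering that block does not change the default graph, so you should also enter a `with sess.graph.as_default():` block to make `sess.graph` the default graph; otherwise calls such as `tf.trainable_variables(name)` look up the thread's default graph, which need not be `sess.graph`.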
`tf.train.Coordinator` is convenient for joining threads, but you can also simply call each thread's `join()` method (`p1.join()`, `p2.join()`).
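The start-then-join pattern with plain `Thread.join()` can be sketched with only the standard library. In this sketch `update_weights` and `run_threads` are hypothetical names, a plain dict stands in for the per-scope variables, and each thread sets its own entry to `int(name) + 10` just as `f()` does; a `threading.Lock` guards the shared dict.

```python
import threading


def update_weights(name, weights, lock):
    # Hypothetical stand-in for f(): each thread updates only its own entry,
    # mirroring vars[0].assign(int(name) + 10).
    with lock:
        weights[name] = int(name) + 10


def run_threads(names):
    # Initial values mirror init_network(name): tf.Variable(int(name)).
    weights = {name: int(name) for name in names}
    lock = threading.Lock()
    threads = [threading.Thread(target=update_weights, args=(n, weights, lock))
               for n in names]
    for t in threads:
        t.start()
    # Thread.join() blocks until each thread has finished, playing the same
    # role as coord.join([p1, p2]) in the answer's code.
    for t in threads:
        t.join()
    return weights


print(run_threads(['1', '2']))  # {'1': 11, '2': 12}
```

`Coordinator.join` additionally cooperates with `request_stop()` and a stop grace period; for simply waiting on a fixed set of threads, either approach works.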