Reputation: 803
I want to set my LSTM hidden state inside the generator. However, setting the state only works outside the generator:
K.set_value(model.layers[0].states[0], np.random.randn(batch_size,num_outs)) # this works
def gen_data():
    x = np.zeros((batch_size, num_steps, num_input))
    y = np.zeros((batch_size, num_steps, num_output))
    while True:
        for i in range(batch_size):
            K.set_value(model.layers[0].states[0], np.random.randn(batch_size,num_outs)) # error
            x[i, :, :] = X_train[gen_data.current_idx]
            y[i, :, :] = Y_train[gen_data.current_idx]
            gen_data.current_idx += 1
        yield x, y
gen_data.current_idx = 0
The generator is invoked in the fit_generator function:
model.fit_generator(gen_data(), len(X_train)//batch_size, 1, validation_data=None)
This is the result when I print the state:
print(model.layers[0].states[0])
<tf.Variable 'lstm/Variable:0' shape=(1, 2) dtype=float32>
This is the error that occurs in the generator:
ValueError: Tensor("Placeholder_1:0", shape=(1, 2), dtype=float32) must be from the same graph as Tensor("lstm/Variable:0", shape=(), dtype=resource)
What am I doing wrong?
Upvotes: 5
Views: 199
Reputation: 16876
Keras runs generators in their own threads, so the code inside the generator executes in a different thread than the one that created the graph. Accessing the model from the generator therefore hits a different default graph. A simple (but not ideal) workaround is to force the generator to run in the same thread as the one that created the graph by setting workers=0:
model.fit_generator(gen_data(), len(X_train)//batch_size, 1, validation_data=None, workers=0)
Debug Code:
def gen_data():
    print("-->", tf.get_default_graph())
    while True:
        for i in range(1):
            yield (np.random.randn(batch_size, num_steps, num_input),
                   np.random.randn(batch_size, num_steps, 8))

model = get_model()
print(tf.get_default_graph())
model.fit_generator(gen_data(), 8, 1)
print(tf.get_default_graph())
Output
<tensorflow.python.framework.ops.Graph object at 0x1228a5e80>
--><tensorflow.python.framework.ops.Graph object at 0x14388e5c0>
Epoch 1/1
8/8 [==============================] - 4s 465ms/step - loss: 1.0198 - acc: 0.1575
<tensorflow.python.framework.ops.Graph object at 0x1228a5e80>
You can see that the graph objects are different. Setting workers=0 forces the generator to run single-threaded.
Using
model.fit_generator(gen_data(), 8, 1, workers=0)
results in
<tensorflow.python.framework.ops.Graph object at 0x1228a5e80>
--> <tensorflow.python.framework.ops.Graph object at 0x1228a5e80>
Epoch 1/1
8/8 [==============================] - 4s 466ms/step - loss: 1.0373 - acc: 0.0975
<tensorflow.python.framework.ops.Graph object at 0x1228a5e80>
Now the single-threaded generator has access to the same graph.
However, if you want to keep the generator multithreaded, a more elegant approach is to save the graph to a variable in the main process (the one that creates the graph) and pass it to the generator, which then uses the passed graph as its default graph.
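A minimal sketch of that idea, reusing the names from the question (untested; this only helps with threaded workers, not with use_multiprocessing=True, since a graph cannot be shared across processes, and depending on your Keras/TF version you may also need to share the session):
import numpy as np
import tensorflow as tf
from keras import backend as K

graph = tf.get_default_graph()  # capture the model's graph in the main thread

def gen_data(graph):
    x = np.zeros((batch_size, num_steps, num_input))
    y = np.zeros((batch_size, num_steps, num_output))
    while True:
        for i in range(batch_size):
            # Re-enter the model's graph before touching its variables,
            # so K.set_value creates its placeholder in the right graph.
            with graph.as_default():
                K.set_value(model.layers[0].states[0],
                            np.random.randn(batch_size, num_outs))
            x[i, :, :] = X_train[gen_data.current_idx]
            y[i, :, :] = Y_train[gen_data.current_idx]
            gen_data.current_idx += 1
        yield x, y
gen_data.current_idx = 0

model.fit_generator(gen_data(graph), len(X_train)//batch_size, 1)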
Upvotes: 2