MerklT

Reputation: 803

Setting Keras Variables in Generator

I want to set my LSTM hidden state in the generator. However, setting the state only works outside the generator:

K.set_value(model.layers[0].states[0], np.random.randn(batch_size,num_outs)) # this works

def gen_data():
    x = np.zeros((batch_size, num_steps, num_input))
    y = np.zeros((batch_size, num_steps, num_output))
    while True:
        for i in range(batch_size):
            K.set_value(model.layers[0].states[0], np.random.randn(batch_size,num_outs)) # error
            x[i, :, :] = X_train[gen_data.current_idx]
            y[i, :, :] = Y_train[gen_data.current_idx]
            gen_data.current_idx += 1
        yield x, y
gen_data.current_idx = 0

The generator is invoked by the fit_generator function:

model.fit_generator(gen_data(), len(X_train)//batch_size, 1, validation_data=None)

This is the result when I print the state:

print(model.layers[0].states[0])
<tf.Variable 'lstm/Variable:0' shape=(1, 2) dtype=float32>

This is the error that occurs in the generator:

ValueError: Tensor("Placeholder_1:0", shape=(1, 2), dtype=float32) must be from the same graph as Tensor("lstm/Variable:0", shape=(), dtype=resource)

What am I doing wrong?

Upvotes: 5

Views: 199

Answers (1)

mujjiga

Reputation: 16876

Keras runs the generator in worker threads, so the code inside the generator executes in a different thread than the one that created the graph. Accessing the model from the generator therefore hits a different default graph. A simple (but bad) solution is to force the generator to run in the same thread as the one that created the graph by setting workers=0:

model.fit_generator(gen_data(), len(X_train)//batch_size, 1, validation_data=None, workers=0)

Debug Code:

def gen_data():
    print("-->", tf.get_default_graph())
    while True:
        for i in range(1):
            yield (np.random.randn(batch_size, num_steps, num_input),
                   np.random.randn(batch_size, num_steps, 8))

model = get_model()
print(tf.get_default_graph())
model.fit_generator(gen_data(), 8, 1)
print(tf.get_default_graph())

Output

<tensorflow.python.framework.ops.Graph object at 0x1228a5e80>
--><tensorflow.python.framework.ops.Graph object at 0x14388e5c0>
Epoch 1/1 
8/8 [==============================] - 4s 465ms/step - loss: 1.0198 - acc: 0.1575
<tensorflow.python.framework.ops.Graph object at 0x1228a5e80>

You can see that the graph objects are different. Setting workers=0 forces the generator to run single-threaded.

Using

model.fit_generator(gen_data(), 8, 1, workers=0)

results in

<tensorflow.python.framework.ops.Graph object at 0x1228a5e80>
--> <tensorflow.python.framework.ops.Graph object at 0x1228a5e80>
Epoch 1/1
8/8 [==============================] - 4s 466ms/step - loss: 1.0373 - acc: 0.0975
<tensorflow.python.framework.ops.Graph object at 0x1228a5e80>

Now the single-threaded generator has access to the same graph.

However, to keep a multi-threaded generator, a more elegant method is to save the graph to a variable in the main thread that creates it, pass it to the generator, and have the generator use the passed graph as the default graph, as sketched below.
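A minimal sketch of that approach, assuming TF 1.x graph mode and reusing the names from the question (model, X_train, Y_train, batch_size, num_steps, num_input, num_output, num_outs are assumed to be defined as above); graph.as_default() re-enters the main thread's graph from inside the worker thread:

import numpy as np
import tensorflow as tf
from keras import backend as K

# Capture the default graph in the main thread, where the model was built.
graph = tf.get_default_graph()

def gen_data(graph):
    x = np.zeros((batch_size, num_steps, num_input))
    y = np.zeros((batch_size, num_steps, num_output))
    while True:
        # Re-enter the main thread's graph inside the worker thread, so
        # K.set_value targets the model's variables rather than a new graph.
        with graph.as_default():
            for i in range(batch_size):
                K.set_value(model.layers[0].states[0],
                            np.random.randn(batch_size, num_outs))
                x[i, :, :] = X_train[gen_data.current_idx]
                y[i, :, :] = Y_train[gen_data.current_idx]
                gen_data.current_idx += 1
        yield x, y
gen_data.current_idx = 0

model.fit_generator(gen_data(graph), len(X_train)//batch_size, 1)

Note that the yield sits outside the with block, so the graph context is not held across the suspension point; each batch re-enters the graph fresh. This keeps the generator multi-threaded while every graph operation inside it still targets the graph the model lives in.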

Upvotes: 2
