Mika

Neural net inside for loop

I have something like this:

from keras.layers import Input, Dense, LeakyReLU
from keras.models import Model

for q in range(10):
    # generate some samples (the training data x, y)
    inputs = Input(batch_shape=(n_batch, xx.shape[1]))
    h = Dense(20)(inputs)
    h = LeakyReLU(alpha=0.001)(h)
    outputs = Dense(1)(h)
    outputs = LeakyReLU(alpha=0.001)(outputs)
    model = Model(inputs=inputs, outputs=outputs)
    model.compile(loss='mean_squared_error', optimizer='Adam', metrics=['accuracy'])
    for i in range(10):
        model.fit(x, y, epochs=1, batch_size=n_batch, verbose=0, shuffle=False)
        model.reset_states()

Is the neural net built from scratch for every q, or does it retain everything from the previous q? If it retains, how do I reset it so that the net is built, compiled and fit separately for every q?


Answers (1)

Mete Han Kahraman

When you create a layer with Keras or TensorFlow (in graph mode, as with the TensorFlow 1.x backend used here), TensorFlow adds one or more nodes to its graph. The same happens every time you add an optimizer, a loss function or an activation function: each one adds nodes of its own.

When you call model.fit(), TensorFlow runs the part of its graph needed to compute the loss and update the weights. If you add nodes in a loop, the previous ones are not deleted: they keep taking up memory and can degrade performance.
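To make the accumulation concrete, here is a toy, pure-Python model of a global graph registry. The `ToyGraph` class and `build_model` helper are hypothetical stand-ins, not TensorFlow APIs; they only mimic the graph-mode behaviour described above, where every build appends nodes to one default graph that is not cleared automatically.

```python
class ToyGraph:
    """Hypothetical stand-in for TensorFlow 1.x's default graph: a node list that only grows."""

    def __init__(self):
        self.nodes = []

    def add_node(self, name):
        self.nodes.append(name)
        return name


def build_model(graph):
    """Mimic the question's model: each call adds fresh nodes for layers, loss and optimizer."""
    for name in ["input", "dense_20", "leaky_relu", "dense_1", "leaky_relu", "mse_loss", "adam"]:
        graph.add_node(name)


# Pattern from the question: rebuild the model on every iteration of q.
default_graph = ToyGraph()
for q in range(10):
    build_model(default_graph)
print(len(default_graph.nodes))  # 70: nodes from all 10 builds are still in the graph

# Build once and reuse the same nodes on every iteration.
reused_graph = ToyGraph()
build_model(reused_graph)
for q in range(10):
    pass  # re-initialize weights and train here; no new nodes are added
print(len(reused_graph.nodes))  # 7
```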

What to do instead? It's very simple: re-initialize your weights and reuse the same nodes. Your code won't change much. Build and compile the model once before the loop, keep only the sample generation inside the for loop, and define a function that re-initializes the weights.
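The idea of "re-initialize, don't rebuild" can be sketched without TensorFlow: keep the same weight storage and overwrite its values with fresh draws. `ToyDense` below is a hypothetical stand-in for a Dense layer, not a Keras class. It assumes Keras's documented Dense defaults (Glorot-uniform kernel, zero bias) and resets both kernel and bias.

```python
import math
import random


class ToyDense:
    """Hypothetical stand-in for a Keras Dense layer: a kernel matrix plus a bias vector."""

    def __init__(self, fan_in, fan_out, rng):
        self.fan_in, self.fan_out, self.rng = fan_in, fan_out, rng
        self.kernel = [[0.0] * fan_out for _ in range(fan_in)]
        self.bias = [0.0] * fan_out
        self.reinitialize()

    def reinitialize(self):
        """Overwrite the existing weight storage in place instead of creating a new layer."""
        # Glorot/Xavier uniform: draw from U(-limit, limit), limit = sqrt(6 / (fan_in + fan_out)).
        limit = math.sqrt(6.0 / (self.fan_in + self.fan_out))
        for row in self.kernel:
            for j in range(self.fan_out):
                row[j] = self.rng.uniform(-limit, limit)
        for j in range(self.fan_out):
            self.bias[j] = 0.0  # zeros, like Keras's default bias initializer


rng = random.Random(0)
layer = ToyDense(fan_in=3, fan_out=20, rng=rng)
kernel_before = layer.kernel              # keep a reference to the storage itself
values_before = [row[:] for row in layer.kernel]

layer.kernel[0][0] = 99.0                 # pretend training changed some weights
layer.bias[0] = 5.0
layer.reinitialize()

print(layer.kernel is kernel_before)      # True: the same storage (the "node") is reused
print(layer.kernel != values_before)      # True: the values are freshly drawn
print(layer.bias[0])                      # 0.0
```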

I also removed the inner for loop and just increased the number of epochs to 10; you can put that loop back if you have a reason to keep it.
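Collapsing the inner loop works because, with `shuffle=False` and the parameters carried over between calls, ten one-epoch training calls apply the same updates in the same order as one ten-epoch call. The sketch below shows this with a hypothetical `train` function doing plain per-sample SGD on a toy linear fit (standing in for Keras's mini-batch Adam); it does not model things that could differ in real Keras, such as callbacks or per-call epoch counters.

```python
def train(w, b, data, epochs, lr=0.01):
    """Hypothetical toy trainer: per-sample SGD on y ~ w*x + b, samples in fixed order (no shuffling)."""
    for _ in range(epochs):
        for x, y in data:
            err = (w * x + b) - y
            w -= lr * 2 * err * x   # gradient of err**2 with respect to w
            b -= lr * 2 * err       # gradient of err**2 with respect to b
    return w, b


data = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0), (3.0, 7.0)]  # toy samples from y = 2x + 1

# Ten calls with one epoch each, carrying the parameters forward (like repeated fit calls).
w1, b1 = 0.0, 0.0
for i in range(10):
    w1, b1 = train(w1, b1, data, epochs=1)

# One call with ten epochs.
w2, b2 = train(0.0, 0.0, data, epochs=10)

print((w1, b1) == (w2, b2))  # True: identical updates in identical order
```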

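from keras.layers import Input, Dense, LeakyReLU
from keras.models import Model
from keras import backend as K

def reset_weights(model):
    # Graph-mode (TensorFlow 1.x) Keras: re-run each layer's variable initializers
    # in the current session. The optimizer's state (e.g. Adam's moments) is not reset.
    session = K.get_session()
    for layer in model.layers:
        if hasattr(layer, 'kernel_initializer'):
            layer.kernel.initializer.run(session=session)
        if getattr(layer, 'bias', None) is not None:
            layer.bias.initializer.run(session=session)

# Build and compile the model once, outside the loop.
inputs = Input(batch_shape=(n_batch, xx.shape[1]))
h = Dense(20)(inputs)
h = LeakyReLU(alpha=0.001)(h)
outputs = Dense(1)(h)
outputs = LeakyReLU(alpha=0.001)(outputs)
model = Model(inputs=inputs, outputs=outputs)
model.compile(loss='mean_squared_error', optimizer='Adam', metrics=['accuracy'])

for q in range(10):
    # generate some samples (the training data x, y)
    model.fit(x, y, epochs=10, batch_size=n_batch, verbose=1, shuffle=False)
    model.reset_states()
    reset_weights(model)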
