Perschi

Reputation: 21

NaN values with SGD optimizer in Keras for regression NN

I am trying to train a NN for regression. When using the SGD optimizer class from Keras, my network suddenly returns NaN values as predictions after the first step. Previously I ran trainings with the Adam optimizer class and everything worked fine. I have already tried changing the learning rate of SGD, but NaN values still occur as model predictions after the first step and after compiling.

Since my training worked with the Adam optimizer, I don't believe my inputs are causing the NaNs. I already checked my input values for NaNs and removed all of them. So what could cause this behavior?

Here is my code:

from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.optimizers import Adam
from keras.optimizers import SGD

model = Sequential()

model.add(Dense(300, input_shape=(50,), kernel_initializer='glorot_uniform', activation='relu'))
model.add(Dropout(0.3))
model.add(Dense(300, kernel_initializer='glorot_uniform', activation='relu'))
model.add(Dropout(0.3))
model.add(Dense(500, kernel_initializer='glorot_uniform', activation='relu'))
model.add(Dropout(0.3))
model.add(Dense(400, kernel_initializer='glorot_uniform', activation='relu'))
model.add(Dense(1, kernel_initializer='glorot_uniform', activation='linear'))

opt = SGD(lr=0.001, decay=1e-6)

model.compile(loss='mse', optimizer=opt)

model.fit(x_train, y_train, epochs=100, batch_size=32, verbose=0, validation_data=(x_test, y_test))

# print(type(x_train))  # <class 'pandas.core.frame.DataFrame'>
# print(x_train.shape)  # (10000, 50)

Upvotes: 1

Views: 1444

Answers (1)

Frightera

Reputation: 5079

Using ANNs for regression is a bit tricky, as the outputs don't have an upper bound.

The NaNs in the loss are most likely caused by exploding gradients. The reason you don't see NaNs with Adam is that Adam adapts the learning rate per parameter. Adam works well most of the time, so avoid SGD unless you have a specific reason to use it.
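If you do want to keep SGD, one common mitigation for exploding gradients is gradient clipping, which Keras optimizers support through the clipnorm (or clipvalue) argument. This is a general technique, not something from the question, and the threshold of 1.0 is an illustrative assumption:

from keras.optimizers import SGD

# Rescale any gradient whose L2 norm exceeds 1.0, so a single
# oversized update cannot push the weights to NaN.
opt = SGD(lr=0.001, decay=1e-6, clipnorm=1.0)
model.compile(loss='mse', optimizer=opt)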

I am not sure what your dataset contains, but you can try the following (a combined sketch is shown after the list):

  • Adding L2 regularization
  • Normalizing the inputs
  • Increasing the batch size
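Here is a minimal sketch of those three changes applied to the model from the question. The l2 factor of 1e-4, the StandardScaler, and batch_size=128 are illustrative assumptions, not values from the question:

from sklearn.preprocessing import StandardScaler
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.optimizers import SGD
from keras.regularizers import l2

# Normalize the inputs: fit the scaler on the training data only,
# then apply the same transform to the test data.
scaler = StandardScaler()
x_train_scaled = scaler.fit_transform(x_train)
x_test_scaled = scaler.transform(x_test)

model = Sequential()
# L2 regularization on the hidden layers penalizes large weights.
model.add(Dense(300, input_shape=(50,), kernel_initializer='glorot_uniform',
                activation='relu', kernel_regularizer=l2(1e-4)))
model.add(Dropout(0.3))
model.add(Dense(300, kernel_initializer='glorot_uniform',
                activation='relu', kernel_regularizer=l2(1e-4)))
model.add(Dense(1, kernel_initializer='glorot_uniform', activation='linear'))

model.compile(loss='mse', optimizer=SGD(lr=0.001, decay=1e-6))

# A larger batch averages the gradient over more samples,
# which smooths the updates.
model.fit(x_train_scaled, y_train, epochs=100, batch_size=128, verbose=0,
          validation_data=(x_test_scaled, y_test))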

Upvotes: 1
