Mark Lavin

Reputation: 1242

Why is the loss during fitting of a TensorFlow Probability model always `nan`?

I am working on a TensorFlow Probability model to recognize digits in the MNIST dataset. During fitting, the model always reports the loss as `nan`. Here is the definition of the loss function:

def nll(y_true, y_pred):
    # This function should return the negative log-likelihood of each sample
    # in y_true given the predicted distribution y_pred.
    result =  - y_pred.log_prob( y_true )
    return result
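To see why `log_prob` can return `nan` here, consider a hedged NumPy sketch (independent of TensorFlow): for a one-hot target, the categorical negative log-likelihood is `-log(p[class])`. The Dense layer above has `activation=None`, so its outputs are unconstrained logits that can be negative, and the log of a negative number is `nan`:

```python
import numpy as np

# Raw outputs of a Dense layer with activation=None are unconstrained
# logits, not probabilities: they can be negative and need not sum to 1.
logits = np.array([2.0, -1.5, 0.3])

# One-hot target selecting class 1.
y_true = np.array([0.0, 1.0, 0.0])

# Treating the raw logits as if they were probabilities yields nan,
# because log(-1.5) is undefined.
with np.errstate(invalid="ignore"):
    nll_bad = -np.sum(y_true * np.log(logits))
assert np.isnan(nll_bad)

# Normalizing with a softmax first gives a finite, valid loss.
probs = np.exp(logits) / np.sum(np.exp(logits))
nll_good = -np.sum(y_true * np.log(probs))
assert np.isfinite(nll_good)
```

This is the same failure mode the model hits: `probs=t` receives values that are not valid probabilities.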

Here's the code that generates and compiles the model:

def get_probabilistic_model(input_shape, loss, optimizer, metrics):
    # This function should return the compiled probabilistic model
    model = Sequential([
        Conv2D( filters = 8, kernel_size = ( 5, 5 ), activation = "relu", padding = "valid", input_shape = input_shape ),
        MaxPooling2D( pool_size = ( 6, 6 ) ),
        Flatten(),
        Dense( units = 10, activation = None ),
        tfpl.DistributionLambda( lambda t:  tfd.OneHotCategorical( probs = t ) )
    ])
    model.compile( loss = loss, optimizer = optimizer, metrics = metrics )
    return model

tf.random.set_seed(0)
probabilistic_model = get_probabilistic_model(
    input_shape=(28, 28, 1), 
    loss=nll, 
    optimizer=RMSprop(), 
    metrics=['accuracy']
)

`probabilistic_model.summary()` confirms that the model is built and compiled successfully. However, when I try to fit the model, I get the following:

probabilistic_model.fit(x_train, y_train_oh, epochs=5)

Train on 60000 samples
Epoch 1/5
20544/60000 [=========>....................] - ETA: 40s - loss: nan - accuracy: 0.0978

The problem is that the loss is `nan` (not a number), and the accuracy is close to chance (about 1 in 10), which is not surprising given that the loss is broken.

I've been trying to figure out whether the problem lies in the definition of the loss function or in the construction/compilation of the model.

Upvotes: 1

Views: 309

Answers (1)

user11530462

Reputation:

From comments:

You need `tfpl.OneHotCategorical(num_classes)` as the last layer; `OneHotCategorical` already inherits from `DistributionLambda`. (Paraphrased from Frightera.)

The underlying issue is that the `Dense` layer has no activation, so its outputs are unconstrained logits; passing them as `probs=` to `tfd.OneHotCategorical` produces an invalid distribution, and `log_prob` returns `nan`.
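A minimal sketch of two possible fixes for the final layers (assuming the `tfpl`/`tfd` aliases from the question, i.e. `tfp.layers` and `tfp.distributions`; fragment, not a full model):

```python
# Option 1: use the probabilistic layer, which interprets its
# inputs as logits, so the raw Dense outputs are valid.
Dense(units=tfpl.OneHotCategorical.params_size(10), activation=None),
tfpl.OneHotCategorical(10)

# Option 2: keep DistributionLambda, but pass the Dense outputs
# as logits rather than probs.
Dense(units=10, activation=None),
tfpl.DistributionLambda(lambda t: tfd.OneHotCategorical(logits=t))
```

Either way, the distribution is parameterized by logits instead of being handed unnormalized values as probabilities, so `log_prob` stays finite.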

Upvotes: 2
