beaver789

Custom Keras Loss Function is getting tensors of different shapes

I've created a basic custom loss function (it returns either the squared error or the absolute error of y_true and y_pred, depending on which of the two is larger). However, I am getting the following error:

InvalidArgumentError: Input to reshape is a tensor with 32 values, but the requested shape has 1 [[node custom_loss/Reshape

The shape of the train_features tensor I pass into model.fit() is (44,906,1), and my Sequential model is:

import tensorflow as tf

# train_features has shape (44, 906, 1): 44 samples, 906 steps, 1 channel
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.LocallyConnected1D(4, kernel_size=4, strides=2,
          activation='relu', input_shape=train_features.shape[1:], padding='valid'))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(1, activation='relu'))  # one output per sample

I'm unsure why a tensor with 32 values is being reshaped (and what that has to do with my custom loss function, since my model works when I am not using my custom loss function):

from tensorflow import keras

def custom_loss(y_true, y_pred):
    if y_true >= y_pred:
        return keras.backend.square(y_pred - y_true)
    else:
        return keras.backend.abs(y_pred - y_true)

My understanding is that the output tensors of my model's layers should be fine; otherwise the model wouldn't train when I simply set loss='mse' in model.compile().
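A minimal sketch of the shapes involved, using random stand-in tensors instead of real model outputs (an assumption for illustration only): model.fit() uses batch_size=32 unless told otherwise, and Dense(1) produces one prediction per sample, so the loss function receives (32, 1) tensors and the comparison y_true >= y_pred is 32 booleans rather than one.

```python
import tensorflow as tf

# Stand-ins for one training batch (illustrative random values, not real data):
# batch_size defaults to 32 in model.fit(), and Dense(1) gives 1 output per sample.
y_true = tf.random.uniform((32, 1))
y_pred = tf.random.uniform((32, 1))

mask = y_true >= y_pred          # elementwise comparison
print(mask.shape)                # (32, 1)
print(tf.size(mask).numpy())     # 32 values, not a single True/False

# A Python `if` needs exactly one truth value; run eagerly, this raises.
raised = False
try:
    if mask:
        pass
except ValueError as err:
    raised = True
    print("ValueError:", err)
```

Inside model.fit() the loss is most likely traced into a graph rather than run eagerly, so the same if fails differently there, which would explain the Reshape error quoted above.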

Answers (1)

BestDogeStackoverflow

Your loss function is called on a whole batch at once: y_true and y_pred both have shape (32, 1), because 32 is the default batch_size of model.fit() and your last Dense layer has one unit. Operations such as subtraction, square, abs and >= are elementwise, so their results have that same shape: 32 values. A Keras loss function should reduce over the last axis and return one value per sample, which Keras then averages over the batch. For example, this works:

from tensorflow.keras import backend as K

def my_mse(y_true, y_pred):
    # axis=-1 averages over the last axis, giving one loss value per sample
    return K.mean(K.square(y_pred - y_true), axis=-1)
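To check what my_mse actually returns, here is a sketch that uses plain TensorFlow ops (tf.reduce_mean, tf.square) in place of the K. backend functions, an assumption made so the snippet does not depend on the Keras backend module; the toy values are made up. axis=-1 collapses the last axis, so a (3, 1) batch gives three per-sample losses, and Keras then averages those over the batch.

```python
import tensorflow as tf

def my_mse(y_true, y_pred):
    # Same computation as the K.mean/K.square version, using core TF ops.
    # axis=-1 averages over the last axis: one loss value per sample.
    return tf.reduce_mean(tf.square(y_pred - y_true), axis=-1)

# Toy batch of 3 samples with one target each (made-up values).
y_true = tf.constant([[1.0], [2.0], [3.0]])
y_pred = tf.constant([[1.5], [2.0], [1.0]])

per_sample = my_mse(y_true, y_pred)
print(per_sample.shape)              # (3,)
print(per_sample.numpy().tolist())   # [0.25, 0.0, 4.0]

# Keras's default reduction then averages the per-sample losses over the
# batch, giving the single scalar that the optimizer minimises.
batch_loss = tf.reduce_mean(per_sample)
print(float(batch_loss))             # about 1.4167
```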

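Another note: don't put a Python if statement in your loss function. y_true >= y_pred is a tensor of 32 booleans (one per sample), but an if needs a single True/False. When Keras traces the loss into a graph, TensorFlow tries to reshape that comparison into a single value, which is the Reshape error you are seeing. Remove the if, choose between the two errors elementwise instead (for example with tf.where), lower the output to one value per sample, and it will work.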
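One way to keep the question's rule (squared error when y_true >= y_pred, absolute error otherwise) without a Python if is tf.where, which picks between the two branches element by element. This is a sketch, not the only possible fix, and the test values are made up. The GradientTape part shows that gradients still flow through whichever branch was selected for each element.

```python
import tensorflow as tf

def custom_loss(y_true, y_pred):
    diff = y_pred - y_true
    # Elementwise selection instead of a Python `if`: squared error where
    # y_true >= y_pred, absolute error everywhere else.
    per_element = tf.where(y_true >= y_pred, tf.square(diff), tf.abs(diff))
    # Collapse the last axis so Keras gets one loss value per sample.
    return tf.reduce_mean(per_element, axis=-1)

y_true = tf.constant([[2.0], [1.0]])
y_pred = tf.constant([[0.0], [4.0]])
print(custom_loss(y_true, y_pred).numpy().tolist())  # [4.0, 3.0]

# Gradients flow through whichever branch tf.where selected for each element.
y_pred_var = tf.Variable([[0.0], [4.0]])
with tf.GradientTape() as tape:
    loss = tf.reduce_mean(custom_loss(y_true, y_pred_var))
grad = tape.gradient(loss, y_pred_var)
print(grad.numpy().tolist())  # [[-2.0], [0.5]]
```

In the question's setup it would then be passed as model.compile(optimizer='adam', loss=custom_loss); 'adam' is only a placeholder, since the question does not say which optimizer is used.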
