Hari Krishnan

Reputation: 2079

Python crashing when I'm running the program with a custom loss function

Whenever I run this code, Python stops working. To make sure this isn't an issue with my system, I tried running it in Google Colab, which crashed too. The crash happens when execution reaches the model.fit line.

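# Imports were not shown in the original post; these are the assumed ones.
import tensorflow as tf
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam

data = load_iris()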
X = data['data']
y = data['target']
X_train, X_test, y_train, y_test = train_test_split(X,y, test_size = 0.33, random_state =23)

def energy(x):
    val,vec = tf.linalg.eigh(x)
    en = tf.reduce_sum(tf.math.square(val))
    return en

def energy_loss(y_actual,y_predicted):
    mtm_actual = tf.linalg.matmul(y_actual,tf.transpose(y_actual))
    ptp_actual = tf.linalg.matmul(y_predicted,tf.transpose(y_predicted))
    actual_energy = energy(y_actual)
    predicted_energy = energy(y_predicted)
    return tf.math.abs(actual_energy - predicted_energy)

model = Sequential()
model.add(Dense(32,input_dim=4))
model.add(Dense(64,activation='relu'))
model.add(Dense(128,activation='relu'))
model.add(Dense(64,activation='relu'))
model.add(Dense(3,activation='relu'))
opt = Adam(lr = 1)
model.compile(optimizer = opt, loss=energy_loss,metrics=['accuracy'])

model.fit(X_train,X_train,epochs=25,verbose = 1,batch_size = 5)

I am using TensorFlow 1.15.0 when running this code. Any ideas what is causing this problem?

Upvotes: 1

Views: 114

Answers (1)

abhishek

Reputation: 88

The error is in the energy function, where you calculate the eigenvalues and eigenvectors.

tf.linalg.eigh

Computes the eigenvalues and eigenvectors of the innermost N-by-N matrices in tensor such that tensor[...,:,:] * v[..., :,i] = e[..., i] * v[...,:,i], for i=0...N-1.

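This is from the official docs. tf.linalg.eigh expects square (self-adjoint) matrices, but your energy function receives y_actual and y_predicted directly. Those are batch-size-by-features matrices and are generally not square. The products mtm_actual and ptp_actual are computed in energy_loss but never used.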

The issue can be fixed by passing the square products, rather than the raw tensors, to the energy function:

def energy_loss(y_actual,y_predicted):
    # Y @ Y^T is square and symmetric, which is what tf.linalg.eigh expects
    mtm_actual = tf.linalg.matmul(y_actual,tf.transpose(y_actual))
    ptp_predicted = tf.linalg.matmul(y_predicted,tf.transpose(y_predicted))
    # Pass the square products, not the raw batch tensors
    actual_energy = energy(mtm_actual)
    predicted_energy = energy(ptp_predicted)
    return tf.math.abs(actual_energy - predicted_energy)
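
To see why the products work where the raw tensors do not, here is a minimal NumPy sketch (not the TensorFlow code from the question). It assumes np.linalg.eigh enforces the same square-matrix requirement as tf.linalg.eigh; the batch shape (5, 3) and the random data are made up for illustration.

```python
import numpy as np

def energy(x):
    # Sum of squared eigenvalues; eigh requires a square symmetric matrix.
    val, _ = np.linalg.eigh(x)
    return float(np.sum(val ** 2))

rng = np.random.default_rng(0)
y = rng.normal(size=(5, 3))  # hypothetical batch: 5 samples, 3 outputs

# A raw (5, 3) batch is not square, so eigh rejects it.
try:
    energy(y)
    raised = False
except np.linalg.LinAlgError:
    raised = True

# The product y @ y.T is (5, 5) and symmetric, so eigh accepts it.
gram = y @ y.T
e = energy(gram)

# For a symmetric matrix, the sum of squared eigenvalues equals
# the sum of its squared entries (squared Frobenius norm).
frob = float(np.sum(gram ** 2))

print(raised, gram.shape, np.isclose(e, frob))
```

Because the two quantities agree for symmetric matrices, the same energy could probably be computed without an eigendecomposition at all, for example with tf.reduce_sum(tf.square(m)) on the product matrix. That is a suggestion, not something tested against the code in the question.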

Upvotes: 1
