Aziz Siyaev

How to make a custom loss function that uses previous output from network in Keras?

I am trying to build a custom loss function that takes the previous output of the network (the output from the previous iteration) and uses it together with the current output.

Here is what I am trying to do, but I don't know how to complete it:

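from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam

def l_loss(prev_output):

    def loss(y_true, y_pred):
        # Per-sample pixel-wise mean squared error
        pix_loss = K.mean(K.square(y_pred - y_true), axis=-1)

        # Target amount of movement between consecutive outputs
        pase = K.constant(100.0)

        # Mean absolute change from the previous output to the current one
        diff = K.mean(K.abs(prev_output - y_pred))
        movement_loss = K.abs(pase - diff)
        total_loss = pix_loss + movement_loss

        return total_loss
    return loss

self.model.compile(optimizer=Adam(0.001, beta_1=0.5, beta_2=0.9),
                   loss=l_loss(?))  # what should be passed here?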
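Written out, the loss that this code aims for looks roughly like the following. This assumes 2-D outputs of shape N × D (N samples, D values each), with \hat{y}^{(t)} the current output, \hat{y}^{(t-1)} the output from the previous iteration, and p = 100 the value of `pase`. The MSE term is per sample, while the movement term is a single scalar shared by the whole batch:

```latex
\mathcal{L}_i = \frac{1}{D}\sum_{j=1}^{D}\bigl(\hat{y}^{(t)}_{ij} - y_{ij}\bigr)^2
  + \left|\, p - \frac{1}{ND}\sum_{k=1}^{N}\sum_{j=1}^{D}\bigl|\hat{y}^{(t-1)}_{kj} - \hat{y}^{(t)}_{kj}\bigr| \,\right|
```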

I hope you can help me.

Answers (1)

Pedro Marques

This is what I tried:

import numpy as np
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
from tensorflow.keras import backend as K

class MovementLoss(object):
  def __init__(self):
    # Holds state carried over from the previous batch
    self.var = None

  def __call__(self, y_true, y_pred, sample_weight=None):
    # Per-sample mean squared error
    mse = K.mean(K.square(y_true - y_pred), axis=-1)
    if self.var is None:
      # Created lazily on the first call; the shape assumes a batch size of 32
      z = np.zeros((32,))
      self.var = K.variable(z)
    # Store (mse - previous value) in self.var; returning a value that depends
    # on delta keeps the update op in the graph
    delta = K.update(self.var, mse - self.var)
    return mse + delta


def make_model():
  model = Sequential()
  model.add(Dense(1, input_shape=(4,)))
  loss = MovementLoss()
  model.compile('adam', loss)
  return model

model = make_model()
model.summary()


Using some example test data:

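import numpy as np

# 32 random samples with 4 features each (exactly one batch at the default batch size of 32)
X = np.random.rand(32, 4)

# Targets are a fixed linear combination of the features
POLY = [1.0, 2.0, 0.5, 3.0]
def test_fn(xi):
  return np.dot(xi, POLY)

Y = np.apply_along_axis(test_fn, 1, X)

history = model.fit(X, Y, epochs=4)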

I do see the loss oscillate in a way that appears to be influenced by the previous batch's delta. Note that the details of this loss function do not match your application.
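To see where the oscillation may come from, here is a NumPy replay of the state update in `MovementLoss`, run outside Keras. It is only a sketch: it assumes `K.update` assigns the new value to the variable and returns that new value, and it feeds in a made-up constant per-sample MSE instead of real training losses. Under those assumptions the stored value alternates between two values, so the returned loss alternates too. `simulate_movement_loss` is a hypothetical helper, not part of Keras.

```python
import numpy as np


def simulate_movement_loss(batch_mse):
    """Replay MovementLoss's state update for a sequence of per-sample MSEs.

    Hypothetical helper for illustration. Assumes K.update(var, new) stores
    `new` in `var` and returns it, so each call does:
        var <- mse - var
        loss = mse + var
    """
    # Initial state: zeros, like np.zeros((32,)) in MovementLoss
    var = np.zeros_like(np.asarray(batch_mse[0], dtype=float))
    losses = []
    for mse in batch_mse:
        mse = np.asarray(mse, dtype=float)
        var = mse - var            # value K.update stores and returns (delta)
        losses.append(mse + var)   # value MovementLoss returns
    return losses


# With a constant per-sample MSE of 1.0 the stored value flips between 1 and 0,
# so the loss alternates between 2.0 and 1.0.
losses = simulate_movement_loss([np.ones(3)] * 4)
print([float(l.mean()) for l in losses])  # [2.0, 1.0, 2.0, 1.0]
```

During real training the MSE also changes from batch to batch, so this alternation would be mixed with the normal decrease of the loss.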

The crucial point is that the K.update operation must be part of the graph (as far as I understand it).

That is achieved by using the value returned by K.update in the returned loss:

delta = K.update(self.var, mse - self.var)
return mse + delta
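Applied to the loss from the question, the same "keep state between calls" idea could look like the following framework-free NumPy sketch, which makes the arithmetic easy to check. `PreviousOutputLoss` is a hypothetical name, skipping the movement term on the very first call (when there is no previous output yet) is an assumption, and it assumes every batch has the same shape.

```python
import numpy as np


class PreviousOutputLoss:
    """NumPy sketch of the question's loss, keeping the previous output as state.

    Hypothetical helper for illustration; not a Keras API. Assumes consecutive
    calls receive batches of the same shape.
    """

    def __init__(self, pase=100.0):
        self.pase = pase          # target mean absolute change between outputs
        self.prev_output = None   # y_pred from the previous call

    def __call__(self, y_true, y_pred):
        y_true = np.asarray(y_true, dtype=float)
        y_pred = np.asarray(y_pred, dtype=float)
        # Per-sample MSE, like K.mean(K.square(y_pred - y_true), axis=-1)
        pix_loss = np.mean(np.square(y_pred - y_true), axis=-1)
        if self.prev_output is None:
            # Assumption: no movement term on the very first call
            total_loss = pix_loss
        else:
            # Scalar mean absolute change, like K.mean(K.abs(prev_output - y_pred))
            diff = np.mean(np.abs(self.prev_output - y_pred))
            total_loss = pix_loss + np.abs(self.pase - diff)
        # Store a copy so the next call compares against this call's output
        self.prev_output = y_pred.copy()
        return total_loss


loss = PreviousOutputLoss(pase=1.0)
y_true = np.zeros((2, 3))
print(loss(y_true, np.ones((2, 3))))      # [1. 1.]  (no previous output yet)
print(loss(y_true, 2 * np.ones((2, 3))))  # [4. 4.]  (diff = 1, movement term 0)
print(loss(y_true, 2 * np.ones((2, 3))))  # [5. 5.]  (diff = 0, movement term 1)
```

Moving this into Keras would presumably mean keeping `prev_output` in a backend variable, as `MovementLoss` does with `self.var`, computing `diff` from the old value before overwriting it with the current `y_pred`. The fixed-batch-size caveat applies here as well.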