Arian Yambao

Reputation: 13

Tensorflow: multiple one-hot outputs with one input?

Is there a way to feed a feed-forward TensorFlow model labels that contain multiple one-hot encodings for a single input?

For example:

x[0]: [15, 32, 3], y[0]: [[0,0,1], [0,1,0], [1,0,0]]

x[1]: [23, 2, 7], y[1]: [[1,0,0], [0,0,1], [0,1,0]]

Each x has shape (1, 3), while each y has shape (3, 3).
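Stacked over N samples, that layout gives an x array of shape (N, 3) and a y array of shape (N, 3, 3). A quick numpy check with the values above:

import numpy as np

x = np.array([[15, 32, 3],
              [23, 2, 7]])
y = np.array([[[0, 0, 1], [0, 1, 0], [1, 0, 0]],
              [[1, 0, 0], [0, 0, 1], [0, 1, 0]]])

print(x.shape)  # (2, 3)
print(y.shape)  # (2, 3, 3)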

Is there a way to train on data like this with a simple feed-forward neural network?

Upvotes: 0

Views: 939

Answers (1)

Nicolas Gervais

Reputation: 36624

You could do that with a custom training loop and a model subclassed from tf.keras.Model. You can then have three outputs with three losses (potentially even a mix of categorical and continuous).

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model
from functools import partial
import numpy as np

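# Synthetic data: 100 samples of 3 features, plus three integer label columns in {0, 1, 2}.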
X = np.random.randint(0, 100, (100, 3)).astype(np.float32)
w = np.random.randint(0, 3, 100)
y = np.random.randint(0, 3, 100)
z = np.random.randint(0, 3, 100)

onehot = partial(tf.one_hot, depth=3)

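# Pack the features with the three label columns; shuffle, batch, and one-hot encode on the fly.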
dataset = tf.data.Dataset.from_tensor_slices((X, w, y, z)).\
    shuffle(100).\
    batch(4).\
    map(lambda a, b, c, d: (a, onehot(b), onehot(c), onehot(d)))


print(next(iter(dataset)))


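# Shared two-layer trunk feeding three separate 3-unit logit heads, one per label set.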
class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.d0 = Dense(16, activation='relu')
        self.d1 = Dense(32, activation='relu')
        self.d2 = Dense(3)
        self.d3 = Dense(3)
        self.d4 = Dense(3)

    def call(self, x, training=None, **kwargs):
        x = self.d0(x)
        x = self.d1(x)
        out1 = self.d2(x)
        out2 = self.d3(x)
        out3 = self.d4(x)
        return out1, out2, out3

model = MyModel()

loss_object = tf.losses.categorical_crossentropy

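# Forward pass plus one categorical cross-entropy per head (the heads emit logits, hence from_logits=True).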
def compute_loss(model, x, w, y, z, training):
    # Call the model positionally; out1..out3 are the logits of the three heads.
    out1, out2, out3 = model(x, training=training)
    loss1 = loss_object(y_true=w, y_pred=out1, from_logits=True)
    loss2 = loss_object(y_true=y, y_pred=out2, from_logits=True)
    loss3 = loss_object(y_true=z, y_pred=out3, from_logits=True)
    return loss1, loss2, loss3

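# Passing a list of losses to tape.gradient differentiates their sum w.r.t. the weights.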
def get_grad(model, x, w, y, z):
    with tf.GradientTape() as tape:
        loss1, loss2, loss3 = compute_loss(model, x, w, y, z, training=True)
    gradients = tape.gradient([loss1, loss2, loss3], model.trainable_variables)
    return (loss1, loss2, loss3), gradients

optimizer = tf.optimizers.Adam()

verbose = "Epoch {:2d} Loss1: {:.3f} Loss2: {:.3f} Loss3: {:.3f}"

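# Train for 10 epochs, tracking a running mean of each head's loss across batches.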
for epoch in range(1, 10 + 1):
    loss1 = tf.metrics.Mean()
    loss2 = tf.metrics.Mean()
    loss3 = tf.metrics.Mean()

    for X, w, y, z in dataset:
        losses, grads = get_grad(model, X, w, y, z)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        for current_loss, running_loss in zip(losses, [loss1, loss2, loss3]):
            running_loss.update_state(current_loss)

    print(verbose.format(epoch,
                         loss1.result(),
                         loss2.result(),
                         loss3.result()))

Input (one batch of the dataset):

(<tf.Tensor: shape=(4, 3), dtype=float32, numpy=
array([[45., 35., 46.],
       [64., 95., 55.],
       [90., 41., 12.],
       [98., 17., 81.]], dtype=float32)>, <tf.Tensor: shape=(4, 3), dtype=float32, numpy=
array([[0., 0., 1.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.]], dtype=float32)>, <tf.Tensor: shape=(4, 3), dtype=float32, numpy=
array([[0., 0., 1.],
       [0., 1., 0.],
       [0., 0., 1.],
       [0., 1., 0.]], dtype=float32)>, <tf.Tensor: shape=(4, 3), dtype=float32, numpy=
array([[0., 1., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.]], dtype=float32)>)

Output:

Epoch  1 Loss1: 0.980 Loss2: 1.051 Loss3: 1.035
Epoch  2 Loss1: 0.886 Loss2: 0.961 Loss3: 0.934
Epoch  3 Loss1: 0.814 Loss2: 0.872 Loss3: 0.847
Epoch  4 Loss1: 0.762 Loss2: 0.804 Loss3: 0.786
Epoch  5 Loss1: 0.732 Loss2: 0.755 Loss3: 0.747
Epoch  6 Loss1: 0.718 Loss2: 0.731 Loss3: 0.723
Epoch  7 Loss1: 0.709 Loss2: 0.715 Loss3: 0.714
Epoch  8 Loss1: 0.700 Loss2: 0.708 Loss3: 0.705
Epoch  9 Loss1: 0.699 Loss2: 0.697 Loss3: 0.701
Epoch 10 Loss1: 0.703 Loss2: 0.702 Loss3: 0.698
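As mentioned above, the same loop extends to a mix of categorical and continuous targets. A minimal sketch of that variant, assuming the third target z is kept as a float and regressed with mean squared error instead of being one-hot encoded (MixedModel and compute_mixed_loss are hypothetical names; everything else follows the code above):

# Hypothetical variant: two classification heads plus one regression head.
# z would stay a float column in the dataset (no one-hot map applied to it).
class MixedModel(Model):
    def __init__(self):
        super(MixedModel, self).__init__()
        self.d0 = Dense(16, activation='relu')
        self.d1 = Dense(32, activation='relu')
        self.d2 = Dense(3)  # logits for the first categorical target
        self.d3 = Dense(3)  # logits for the second categorical target
        self.d4 = Dense(1)  # single unit for the continuous target

    def call(self, x, training=None, **kwargs):
        x = self.d0(x)
        x = self.d1(x)
        return self.d2(x), self.d3(x), self.d4(x)

def compute_mixed_loss(model, x, w, y, z, training):
    out1, out2, out3 = model(x, training=training)
    loss1 = tf.losses.categorical_crossentropy(w, out1, from_logits=True)
    loss2 = tf.losses.categorical_crossentropy(y, out2, from_logits=True)
    # Expand z to (batch, 1) so the per-sample MSE matches the (batch,) shape
    # of the two cross-entropy losses.
    loss3 = tf.losses.mean_squared_error(tf.expand_dims(z, -1), out3)
    return loss1, loss2, loss3

The rest of the loop (get_grad, the optimizer, and the metrics) would carry over unchanged, since tape.gradient already differentiates the sum of whatever losses it is given.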

Upvotes: 1
