Etenedrac

Reputation: 101

How to add post-processing into a Tensorflow Model?

I am trying to save a TensorFlow model which includes some post-processing for the labels.

Given some categorical labels, I am interested in training a model (for instance, a tf.keras.Sequential) on labels to which I have previously applied a one-hot encoding. This is what the model would look like:

model = tf.keras.Sequential([
    tf.keras.layers.DenseFeatures(transform_features),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(1)
])

model.compile(loss='categorical_crossentropy', optimizer='adam')

history = model.fit(train_data, epochs=5)

Here transform_features is a list of tf.feature_column objects, and train_data is a tf.data.Dataset that contains the training data (train_X, train_y).

Once the model is trained, I would like to apply some post-processing. I would like to add this post-processing inside a new (or the same) TensorFlow model, so that when I ask this model for predictions (when making predictions with imported TensorFlow models in BigQuery, for instance), it gives me the decoded final label.

I was thinking of making a first model like the one shown previously and, after training it, adding the following layer to the model:

from tensorflow.keras.layers import Lambda

# Decode the one-hot/softmax output back to a class index
model.add(Lambda(lambda x: tf.argmax(x, axis=-1)))

But I don't know how I could "merge" these two different models and save them together in the TensorFlow SavedModel format (using tf.saved_model.save(model, MODEL_PATH)). Is there any way one could do this post-processing in TensorFlow?

Thanks

Upvotes: 9

Views: 2642

Answers (1)

Reda El Hail

Reputation: 1016

TensorFlow provides Lambda layers, a way to build custom layers that run arbitrary functions. For an example of an argmax layer, see this answer.
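As a minimal sketch of that approach (the model, layer sizes, and names below are purely illustrative, not taken from the question): an already trained classifier can be wrapped together with an argmax Lambda layer in a new Sequential model, so that predict() returns the decoded class index directly.

import tensorflow as tf

# Stand-in for an already trained classifier (sizes are illustrative).
base_model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(3, activation='softmax'),
])
# ... base_model.compile(...) and base_model.fit(...) would go here ...

# Wrap the trained model plus an argmax Lambda layer in one model,
# so predictions are decoded class indices rather than probabilities.
serving_model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.Lambda(lambda x: tf.argmax(x, axis=-1)),
])

print(serving_model.predict(tf.random.uniform((2, 4))).shape)  # (2,)

Since the Lambda body only uses TensorFlow ops, the combined model can still be exported with tf.saved_model.save.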

However, another way is to subclass keras.layers.Layer, which is more advanced and gives more flexibility.

Examples of subclass layers:

  • Scaling layer:

Multiplies the input by a learnable scaling factor.

class ScaleLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(ScaleLayer, self).__init__()
        self.scale = tf.Variable(1.)

    def call(self, inputs):
        return inputs * self.scale

  • Argmax layer:

Retrieves the index of the highest value.

class argmax_layer(tf.keras.layers.Layer):
    def __init__(self):
        super(argmax_layer, self).__init__()

    def call(self, inputs):
        return tf.math.argmax(inputs, axis=1)

This is the full code of a CNN architecture for a classification task, to which I add an argmax layer.

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Layer, Conv2D, BatchNormalization, MaxPooling2D, Flatten, Dropout, Dense
from tensorflow.keras import optimizers
import numpy as np


class argmax_layer(Layer):
  def __init__(self):
    super(argmax_layer, self).__init__()

  def call(self, inputs):
    return tf.math.argmax(inputs, axis=1)


def cnn_model(image_x=100, image_y=100, num_classes=10):
    # Simple CNN: one conv block followed by a dense classification head.
    model = Sequential()
    model.add(Conv2D(32, (5, 5), input_shape=(image_x, image_y, 1), activation='relu'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(10, 10), strides=(10, 10), padding='same'))
    model.add(Flatten())
    model.add(Dense(1024, activation='relu'))
    model.add(BatchNormalization())
    model.add(Dropout(0.6))
    model.add(Dense(num_classes, activation='softmax'))
    sgd = optimizers.SGD(learning_rate=1e-2)
    model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
    return model


########################
# Model summary
model = cnn_model()
model.add(argmax_layer())  # append post-processing: output becomes the class index
model.summary()

#################### TEST
samples = np.random.random((5, 100, 100, 1))  # 5 random samples
print("Output:", model.predict(samples))

Model summary:

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d (Conv2D)              (None, 96, 96, 32)        832
_________________________________________________________________
batch_normalization (BatchNo (None, 96, 96, 32)        128
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 10, 10, 32)        0
_________________________________________________________________
flatten (Flatten)            (None, 3200)              0
_________________________________________________________________
dense (Dense)                (None, 1024)              3277824
_________________________________________________________________
batch_normalization_1 (Batch (None, 1024)              4096
_________________________________________________________________
dropout (Dropout)            (None, 1024)              0
_________________________________________________________________
dense_1 (Dense)              (None, 10)                10250
_________________________________________________________________
argmax_layer (argmax_layer)  (None,)                   0
=================================================================
Total params: 3,293,130
Trainable params: 3,291,018
Non-trainable params: 2,112
________________________________
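To connect this back to the saving part of the question, here is a minimal sketch, assuming model is the trained network with the argmax layer already appended, and using a hypothetical export path. The whole pipeline (network + post-processing) is exported as a single SavedModel, so a serving system such as BigQuery ML receives the decoded class index.

import tensorflow as tf

MODEL_PATH = "./cnn_with_argmax"   # hypothetical export path

# Export the whole pipeline (CNN + argmax post-processing) as one SavedModel;
# its serving signature now returns the class index instead of probabilities.
tf.saved_model.save(model, MODEL_PATH)

# Quick sanity check: reload and list the exported signatures.
reloaded = tf.saved_model.load(MODEL_PATH)
print(list(reloaded.signatures.keys()))   # e.g. ['serving_default']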

Upvotes: 1
