Reputation: 101
I am trying to save a TensorFlow model that includes some post-processing for the labels.
Given some categorical labels, I am interested in training a model (for instance, a tf.keras.Sequential) after applying one-hot encoding to the labels. This is what the model would look like:
    model = tf.keras.Sequential([
        tf.keras.layers.DenseFeatures(transform_features),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(1)
    ])
    model.compile(loss='categorical_crossentropy', optimizer='adam')
    history = model.fit(train_data, epochs=5)
Here, transform_features is a list of tf.feature_columns, and train_data is a tf.data.Dataset that contains the training data (train_X, train_y).
Once the model is trained, I would like to apply some post-processing. I would like to add this post-processing inside a new (or the same) TensorFlow model, so that when I request predictions from this model (making predictions with imported TensorFlow models in BigQuery, for instance), it gives me the decoded final label.
I was thinking of building a first model like the one shown above and, after training it, adding the following layer to the model:
    from tensorflow.keras.layers import Lambda
    model.add(Lambda(lambda x: tf.argmax(x, axis=-1)))
But I don't know how I could "merge" these two different models and save them in the same TensorFlow SavedModel format (using tf.saved_model.save(model, MODEL_PATH)). Is there any way in which one could do this post-processing in TensorFlow?
Thanks
Upvotes: 9
Views: 2642
Reputation: 1016
TensorFlow provides Lambda layers, which let you build custom layers that run custom functions. For an example of an argmax layer, see this answer.
However, another way is to subclass keras.layers.Layer, which is more advanced and gives more flexibility.
Examples of subclass layers:
Multiplies by a scaling factor:
    class ScaleLayer(tf.keras.layers.Layer):
        def __init__(self):
            super(ScaleLayer, self).__init__()
            self.scale = tf.Variable(1.)

        def call(self, inputs):
            return inputs * self.scale
Retrieves the index of the highest value:
    class argmax_layer(tf.keras.layers.Layer):
        def __init__(self):
            super(argmax_layer, self).__init__()

        def call(self, inputs):
            return tf.math.argmax(inputs, axis=1)
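As a quick sanity check, here is a minimal sketch that applies the argmax_layer from above to a small batch. The values in probs are made up for illustration; each row stands for the class scores of one sample.

```python
import tensorflow as tf

class argmax_layer(tf.keras.layers.Layer):
    def __init__(self):
        super(argmax_layer, self).__init__()

    def call(self, inputs):
        # Index of the largest value along the class axis (axis 1)
        return tf.math.argmax(inputs, axis=1)

# Hypothetical scores for 2 samples and 3 classes
probs = tf.constant([[0.1, 0.7, 0.2],
                     [0.8, 0.1, 0.1]])

indices = argmax_layer()(probs)
print(indices.numpy())  # one class index per row
```

The layer has no weights, so it only changes the shape of the output: a batch of score vectors becomes a batch of integer class indices.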
This is the full code of a CNN architecture for a classification task, to which I add an argmax_layer.
    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import (Layer, Conv2D, BatchNormalization,
                                         MaxPooling2D, Flatten, Dropout, Dense)
    from tensorflow.keras import optimizers

    class argmax_layer(Layer):
        def __init__(self):
            super(argmax_layer, self).__init__()

        def call(self, inputs):
            # Index of the highest score along the class axis
            return tf.math.argmax(inputs, axis=1)

    def cnn_model(image_x=100, image_y=100, num_classes=10):
        model = Sequential()
        model.add(Conv2D(32, (5, 5), input_shape=(image_x, image_y, 1), activation='relu'))
        model.add(BatchNormalization())
        model.add(MaxPooling2D(pool_size=(10, 10), strides=(10, 10), padding='same'))
        model.add(Flatten())
        model.add(Dense(1024, activation='relu'))
        model.add(BatchNormalization())
        model.add(Dropout(0.6))
        model.add(Dense(num_classes, activation='softmax'))
        sgd = optimizers.SGD(learning_rate=1e-2)
        model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
        return model

    ########################
    # Model summary
    model = cnn_model()
    model.add(argmax_layer())  # appended after compile; the model is not trained here
    model.summary()

    #################### TEST
    samples = np.random.random((5, 100, 100, 1))  # 5 samples
    print("Output:", model.predict(samples))
Model summary:
    Model: "sequential"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #
    =================================================================
    conv2d (Conv2D)              (None, 96, 96, 32)        832
    _________________________________________________________________
    batch_normalization (BatchNo (None, 96, 96, 32)        128
    _________________________________________________________________
    max_pooling2d (MaxPooling2D) (None, 10, 10, 32)        0
    _________________________________________________________________
    flatten (Flatten)            (None, 3200)              0
    _________________________________________________________________
    dense (Dense)                (None, 1024)              3277824
    _________________________________________________________________
    batch_normalization_1 (Batch (None, 1024)              4096
    _________________________________________________________________
    dropout (Dropout)            (None, 1024)              0
    _________________________________________________________________
    dense_1 (Dense)              (None, 10)                10250
    _________________________________________________________________
    argmax_layer (argmax_layer)  (None,)                   0
    =================================================================
    Total params: 3,293,130
    Trainable params: 3,291,018
    Non-trainable params: 2,112
    _________________________________________________________________
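To "merge" an already trained classifier with the argmax step into a single model, one option is to wrap it with the Keras functional API. The following is only a sketch under assumptions: base_model is a small, untrained stand-in for your trained model, the 4 input features and 3 classes are arbitrary, and serving_model is a name I made up.

```python
import numpy as np
import tensorflow as tf

class argmax_layer(tf.keras.layers.Layer):
    def call(self, inputs):
        # Index of the highest score along the class axis
        return tf.math.argmax(inputs, axis=1)

# Hypothetical stand-in for a classifier that has already been trained
# (assumed here: 4 input features, 3 classes).
base_model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation='relu'),
    tf.keras.layers.Dense(3, activation='softmax'),
])

# Wrap the classifier: same inputs, but the output is the class index.
inputs = tf.keras.Input(shape=(4,))
probabilities = base_model(inputs)
labels = argmax_layer()(probabilities)
serving_model = tf.keras.Model(inputs=inputs, outputs=labels)

# Both paths should agree on the predicted class for every sample.
x = np.random.random((5, 4)).astype('float32')
expected = np.argmax(base_model.predict(x, verbose=0), axis=1)
predicted = serving_model.predict(x, verbose=0)
print(predicted)
```

Keeping the argmax out of the model you train is deliberate: argmax has no useful gradient and the loss expects the class scores, so base_model is what you fit and serving_model is what you export. Since serving_model is an ordinary Keras model, it should be possible to pass it to tf.saved_model.save(serving_model, MODEL_PATH) as in the question, although I have not checked this on every TensorFlow version.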
Upvotes: 1