Tim

Reputation: 375

Transfer Learning on MNIST: wrong labels error

I trained a multi-layer perceptron in TensorFlow on the MNIST dataset, but only on the digits 0 through 4. Then I made a new model with all the same layers and weights, except for a new output layer that also has 5 output nodes. I want to train this new model to classify the digits 5 to 9.
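
Roughly, the new model was set up like this (a sketch; base_model stands for the model trained on 0-4, and the optimizer choice here is illustrative):

import tensorflow as tf

# reuse every layer of the old model except its output head, freeze them,
# and attach a fresh 5-node softmax output layer
transfer_model = tf.keras.Sequential(base_model.layers[:-1])  # base_model: trained on digits 0-4
for layer in transfer_model.layers:
    layer.trainable = False
transfer_model.add(tf.keras.layers.Dense(5, activation='softmax'))

transfer_model.compile(optimizer='adam',
                       loss='sparse_categorical_crossentropy',
                       metrics=['accuracy'])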

I selected only the digits 5 to 9 from x_train and y_train and ran

transfer_model.fit(x_train[train_filter], y_train[train_filter], epochs=5)

where train_filter is defined as np.where(np.logical_and(y_train >= 5, y_train <= 9)).

At the very first step of training, I get this error:

InvalidArgumentError: Received a label value of 9 which is outside the valid range of [0, 5). Label values: 5 9 7 8 9 8 7 6 8 7 6 9 5 5 8 7 6 9 9 7 6 7 6 8 7 7 9 7 6 8 5 6

This makes sense because I originally trained the network to classify the range [0, 5), but now I want it to handle the range [5, 10). Did I miss a step here? How do I define which digit each output neuron corresponds to?

Here is my model summary:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten_7 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_49 (Dense)             (None, 100)               78500     
_________________________________________________________________
batch_normalization_10 (Batc (None, 100)               400       
_________________________________________________________________
dropout_5 (Dropout)          (None, 100)               0         
_________________________________________________________________
dense_50 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_11 (Batc (None, 100)               400       
_________________________________________________________________
dropout_6 (Dropout)          (None, 100)               0         
_________________________________________________________________
dense_51 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_12 (Batc (None, 100)               400       
_________________________________________________________________
dropout_7 (Dropout)          (None, 100)               0         
_________________________________________________________________
dense_52 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_13 (Batc (None, 100)               400       
_________________________________________________________________
dropout_8 (Dropout)          (None, 100)               0         
_________________________________________________________________
dense_53 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_14 (Batc (None, 100)               400       
_________________________________________________________________
dropout_9 (Dropout)          (None, 100)               0         
_________________________________________________________________
dense_55 (Dense)             (None, 5)                 505       
=================================================================
Total params: 121,405
Trainable params: 505
Non-trainable params: 120,900
_________________________________________________________________

Upvotes: 0

Views: 356

Answers (2)

Since you use NumPy, you can try the following:

import tensorflow as tf
import numpy as np

arr = np.array([5, 6, 7, 8, 9, 8, 7, 6, 5])   # labels in the range 5-9
arr = tf.one_hot(arr, 10, axis=0).numpy()     # one-hot encode with depth 10, classes along axis 0
arr = arr[5:]                                 # keep only the rows for classes 5-9

tf.argmax(arr).numpy()  # returns array([0, 1, 2, 3, 4, 3, 2, 1, 0])

or, using tf.map_fn:

arr = np.array([5, 6, 7, 8, 9, 8, 7, 6, 5])

tf.map_fn(lambda x: x - 5, arr).numpy()  # subtract 5 from every label: array([0, 1, 2, 3, 4, 3, 2, 1, 0])
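
Applied to the training labels from the question, the same shift can be done on the array right before fitting (a sketch reusing the question's variable names):

y_shifted = y_train[train_filter] - 5                          # 5-9 becomes 0-4
transfer_model.fit(x_train[train_filter], y_shifted, epochs=5)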

Upvotes: 0

BernieFeynman

Reputation: 121

You need to map 5-9 to 0-4. The class labels are probably one-hot encoded: you have 5 unique labels, so a vector of length 5 is enough to represent them. But since your labels run from 5 to 9, they fall outside that range. You do not need to adjust the model, just add a map to the label values.
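
For example, a sketch of such a map using the variable names from the question (the dictionary is illustrative; simply subtracting 5 works just as well):

import numpy as np

label_map = {digit: digit - 5 for digit in range(5, 10)}         # {5: 0, 6: 1, 7: 2, 8: 3, 9: 4}
y_mapped = np.vectorize(label_map.get)(y_train[train_filter])    # remap every label to 0-4
transfer_model.fit(x_train[train_filter], y_mapped, epochs=5)

# to read predictions back as digits, add 5 to the predicted class index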

Upvotes: 1
