freeideas

Reputation: 141

How to merge classification models horizontally

I have many models already trained, which each answer a simple yes/no question. Pseudocode:

model_dog = keras.load('is_dog')
model_cat = keras.load('is_cat')
model_rat = keras.load('is_rat')

image = load_photo_as_numpy_array('photo.jpg')

multi_class = [ m.predict(image) for m in (model_dog,model_cat,model_rat) ]

This works fine, but (a) it is slow, because inference runs sequentially instead of in parallel (I have several hundred such models, not just 3), and (b) it is much more complex to use than a single model that does multi-class classification.

What I want is:

model = keras.concat_horizontal([ model_dog, model_cat, model_rat ])
model.save('combined_model')

Then whenever I want to use the combined model, it is as simple as:

model = keras.load('combined_model')
multi_class = model.predict(image)

This way, I can add a new class to the combined model just by training one simple model, for example one that recognizes a fish.

Upvotes: 1

Views: 380

Answers (1)

Kaveh

Reputation: 4960

As I suggested in the comments, you can merge multiple models into one new model and predict with that new model.

First, here is a function that merges the models and returns a new combined model, which is what you want:

import tensorflow as tf

def concat_horizontal(models, input_shape):
  # One shared Input feeds every sub-model
  inputs = tf.keras.layers.Input(shape=input_shape)
  # Call each model on that same input
  outputs = [m(inputs) for m in models]
  # Join the per-model outputs side by side (Concatenate needs at least two tensors)
  output = outputs[0] if len(outputs) == 1 else tf.keras.layers.concatenate(outputs)
  return tf.keras.Model(inputs=inputs, outputs=output)
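For the question's use case, where every model is a yes/no classifier with a single sigmoid output, the merged model returns one probability per question. A minimal sketch, assuming small stand-in models with random weights (the names `is_dog`, `is_cat`, `is_rat`, the helper `make_yes_no_model` and the 8-feature input are illustrative only), with `concat_horizontal` repeated so the block runs on its own:

```python
import numpy as np
import tensorflow as tf

def concat_horizontal(models, input_shape):
    inputs = tf.keras.layers.Input(shape=input_shape)
    outputs = [m(inputs) for m in models]
    return tf.keras.Model(inputs=inputs, outputs=tf.keras.layers.concatenate(outputs))

def make_yes_no_model(name, input_shape=(8,)):
    # Hypothetical stand-in for one pre-trained yes/no classifier.
    # Each model gets a unique name, since Keras expects the layer names
    # inside one model to be unique.
    return tf.keras.Sequential([
        tf.keras.Input(shape=input_shape),
        tf.keras.layers.Dense(4, activation='relu'),
        tf.keras.layers.Dense(1, activation='sigmoid'),
    ], name=name)

models = [make_yes_no_model(n) for n in ('is_dog', 'is_cat', 'is_rat')]
combined = concat_horizontal(models, (8,))

x = np.random.rand(5, 8).astype('float32')      # a batch of 5 fake inputs
merged = combined.predict(x, verbose=0)          # one column per yes/no question
separate = np.hstack([m.predict(x, verbose=0) for m in models])  # the old loop
print(merged.shape)  # (5, 3)
print(np.allclose(merged, separate, atol=1e-5))
```

Whether the single `predict` call actually runs the sub-models concurrently depends on TensorFlow's graph execution and your hardware, but it at least replaces several hundred separate `predict` calls with one.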

Let's look at an example. Say we want to merge two Sequential models like these:

def model_1():
  model = tf.keras.models.Sequential([
                      tf.keras.layers.Flatten(input_shape=(28,28,1)),
                      tf.keras.layers.Dense(150, activation='relu'),
                      tf.keras.layers.Dense(200, activation='relu'),
                      tf.keras.layers.Dense(150, activation='relu'),
                      tf.keras.layers.Dense(10, activation='softmax')], name="model1")
  model.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.sparse_categorical_crossentropy, metrics=['accuracy'])
  return model

def model_2():
  model = tf.keras.models.Sequential([
                      tf.keras.layers.Flatten(input_shape=(28,28,1)),
                      tf.keras.layers.Dense(150, activation='relu'),
                      tf.keras.layers.Dense(150, activation='relu'),
                      tf.keras.layers.Dense(10, activation='softmax')], name="model2")
  model.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.sparse_categorical_crossentropy, metrics=['accuracy'])
  return model

model1 = model_1()
model2 = model_2()

Let's use MNIST to train both models (the first on the train split, the second on the test split):

import tensorflow_datasets as tfds
ds_1 = tfds.load('mnist', split='train', as_supervised=True)
ds_2 = tfds.load('mnist', split='test', as_supervised=True)

def map_fn(image, label):
  image = image / 255
  return image, label

ds_1 = ds_1.map(map_fn).shuffle(1024).batch(32)
ds_2 = ds_2.map(map_fn).shuffle(1024).batch(32)

Now we can train the models, save them, and load them back:

model1.fit(ds_1, epochs=2, validation_data=ds_1)
model2.fit(ds_2, epochs=2, validation_data=ds_2)

model1.save('model1.h5')
model2.save('model2.h5')

model3 = tf.keras.models.load_model('model1.h5')
model4 = tf.keras.models.load_model('model2.h5')

We now have two separate models (model3 and model4) and want to merge them into a new one. Pass them, along with the input shape (here the MNIST image shape, (28, 28, 1)), to the function defined above:

new_model = concat_horizontal([model3,model4],(28,28,1))
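The question also asks about adding a new class later (for example a fish detector) by training only one new model. One way to do that, sketched here under the assumption that the combined model was built by `concat_horizontal` (so its sub-models appear as nested `tf.keras.Model` layers), is to pull the existing sub-models out and merge them again together with the new one. `add_classifier` and `tiny_model` are hypothetical helpers for illustration:

```python
import numpy as np
import tensorflow as tf

def concat_horizontal(models, input_shape):
    inputs = tf.keras.layers.Input(shape=input_shape)
    outputs = [m(inputs) for m in models]
    # Concatenate needs at least two tensors; a single model is passed through
    output = outputs[0] if len(outputs) == 1 else tf.keras.layers.concatenate(outputs)
    return tf.keras.Model(inputs=inputs, outputs=output)

def add_classifier(combined, new_model, input_shape):
    # Hypothetical helper: the nested sub-models are the layers that are themselves Models
    sub_models = [layer for layer in combined.layers if isinstance(layer, tf.keras.Model)]
    return concat_horizontal(sub_models + [new_model], input_shape)

def tiny_model(name):
    # Stand-in for a trained yes/no classifier on 8 input features
    return tf.keras.Sequential([
        tf.keras.Input(shape=(8,)),
        tf.keras.layers.Dense(1, activation='sigmoid'),
    ], name=name)

combined = concat_horizontal([tiny_model('is_dog'), tiny_model('is_cat')], (8,))
bigger = add_classifier(combined, tiny_model('is_fish'), (8,))

x = np.random.rand(3, 8).astype('float32')
print(bigger.predict(x, verbose=0).shape)  # (3, 3)
```

Because the same sub-model objects are reused, `bigger` shares weights with `combined`; the only new weights are those of the model you trained separately.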

We can plot the new model to check its structure:

tf.keras.utils.plot_model(new_model)

[Figure: plot of new_model]
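The question also wants to save the merged model once and later load it as a single model. A minimal round-trip sketch, assuming HDF5 (`.h5`) files as used in this answer and two tiny stand-in models with random weights (`tiny_model` and the temporary path are illustrative only):

```python
import os
import tempfile
import numpy as np
import tensorflow as tf

def concat_horizontal(models, input_shape):
    inputs = tf.keras.layers.Input(shape=input_shape)
    outputs = [m(inputs) for m in models]
    return tf.keras.Model(inputs=inputs, outputs=tf.keras.layers.concatenate(outputs))

def tiny_model(name):
    # Stand-in for a trained yes/no classifier on 8 input features
    return tf.keras.Sequential([
        tf.keras.Input(shape=(8,)),
        tf.keras.layers.Dense(1, activation='sigmoid'),
    ], name=name)

combined = concat_horizontal([tiny_model('is_dog'), tiny_model('is_cat')], (8,))

# Save the whole merged model to one file, then load it back as one model
path = os.path.join(tempfile.mkdtemp(), 'combined_model.h5')
combined.save(path)
restored = tf.keras.models.load_model(path)

x = np.random.rand(4, 8).astype('float32')
before = combined.predict(x, verbose=0)
after = restored.predict(x, verbose=0)
print(np.allclose(before, after, atol=1e-6))
```

Loading an uncompiled model may log a warning about a missing training configuration; that does not affect `predict`.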

Now we can get predictions from the combined model:

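import numpy as np
import matplotlib.pyplot as plt

# Take a single (image, label) pair from the dataset
img, lbl = next(iter(ds_1.unbatch().take(1)))
img = tf.expand_dims(img, axis=0)  # add the batch dimension: (1, 28, 28, 1)
pred = new_model.predict(img)      # shape (1, 20): two 10-class outputs side by side
pred = np.reshape(pred, (2, 10))   # one row per sub-model
results = np.argmax(pred, axis=1)  # predicted digit from each sub-model
print(results)

plt.imshow(np.array(img).squeeze())
plt.show()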

In my case, both sub-models classify the sample as 4:

Output:

[Image: prediction output]
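The `np.reshape(pred, (2, 10))` step above only works because both sub-models have exactly 10 outputs. A small generalization, assuming you know each sub-model's output width, splits the concatenated vector per model; for single-output yes/no models like the ones in the question, a 0.5 threshold turns each probability into an answer. `split_predictions` and `yes_no_answers` are hypothetical helpers, and the numbers are made up:

```python
import numpy as np

def split_predictions(pred_row, widths):
    # Cut one concatenated prediction vector into one slice per sub-model
    bounds = np.cumsum(widths)[:-1]
    return np.split(pred_row, bounds)

def yes_no_answers(probs, names, threshold=0.5):
    # For single-output sigmoid models: True means "yes" for that question
    return {name: bool(p >= threshold) for name, p in zip(names, probs)}

# Two 10-class sub-models, as in the MNIST example
row = np.zeros(20)
row[4] = 0.9        # first model: class 4
row[10 + 4] = 0.8   # second model: class 4
parts = split_predictions(row, [10, 10])
print([int(np.argmax(p)) for p in parts])  # [4, 4]

# Three yes/no models, as in the question
print(yes_no_answers(np.array([0.92, 0.10, 0.55]), ['dog', 'cat', 'rat']))
# {'dog': True, 'cat': False, 'rat': True}
```

Keeping the widths list in the same order as the models passed to `concat_horizontal` is what maps each slice back to its question.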

Upvotes: 1
