deepLearner

Reputation: 45

Extract features from 2 auto-encoders and feed them into an MLP

I understand that features extracted from an auto-encoder can be fed into an MLP for classification or regression; I have done this before.
But what if I have 2 auto-encoders? Can I extract the features from the bottleneck layers of both and feed them into an MLP that classifies based on these features? If so, how? I am not sure how to concatenate the two feature sets. numpy.hstack() gives me an 'unhashable slice' error, whereas tf.concat() gives me the error 'Input tensors to a Model must be Keras tensors.' The bottleneck layers of the two auto-encoders each have dimension (None, 100), so stacking them horizontally should give (None, 200). The hidden layer of the MLP may contain some number of neurons (num_hidden=100). Could anyone please help?

x1 = autoencoder1.get_layer('encoder2').output
x2 = autoencoder2.get_layer('encoder2').output

#inp = np.hstack((x1, x2))
inp = tf.concat([x1, x2], 1)
x = tf.concat([x1, x2], 1)
h = Dense(num_hidden, activation='relu', name='hidden')(x)
y = Dense(1, activation='sigmoid', name='prediction')(h)
mymlp = Model(inputs=inp, outputs=y)

# Compile model
mymlp.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# Train model
mymlp.fit(x_train, y_train, epochs=20, batch_size=8)

Updated as per @twolffpiggott's suggestion:

from keras.layers import Input, Dense, Dropout
from keras import layers
from keras.models import Model
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
import numpy as np

x1 = Data1
x2 = Data2
y = Data3

num_neurons1 = x1.shape[1]
num_neurons2 = x2.shape[1]

# Train-test split
x1_train, x1_test, x2_train, x2_test, y_train, y_test = train_test_split(x1, x2, y, test_size=0.2)

# scale data within [0-1] range
scalar = MinMaxScaler()
x1_train = scalar.fit_transform(x1_train)
x1_test = scalar.transform(x1_test)

x2_train = scalar.fit_transform(x2_train)
x2_test = scalar.transform(x2_test)

x_train = np.concatenate([x1_train, x2_train], axis =-1)
x_test = np.concatenate([x1_test, x2_test], axis =-1)

# Auto-encoder1

encoding_dim1 = 500
encoding_dim2 = 100

input_data = Input(shape=(num_neurons1,))
encoded = Dense(encoding_dim1, activation='relu', name='encoder1')(input_data)
encoded1 = Dense(encoding_dim2, activation='relu', name='encoder2')(encoded)
decoded = Dense(encoding_dim2, activation='relu', name='decoder1')(encoded1)
decoded = Dense(num_neurons1, activation='sigmoid', name='decoder2')(decoded)

# this model maps an input to its reconstruction
autoencoder1 = Model(inputs=input_data, outputs=decoded)
autoencoder1.compile(optimizer='sgd', loss='mse')

# training
autoencoder1.fit(x1_train, x1_train,
                    epochs=100,
                    batch_size=8,
                    shuffle=True,
                    validation_data=(x1_test, x1_test))

# Auto-encoder2

encoding_dim1 = 500
encoding_dim2 = 100

input_data = Input(shape=(num_neurons2,))
encoded = Dense(encoding_dim1, activation='relu', name='encoder1')(input_data)
encoded2 = Dense(encoding_dim2, activation='relu', name='encoder2')(encoded)
decoded = Dense(encoding_dim2, activation='relu', name='decoder1')(encoded2)
decoded = Dense(num_neurons2, activation='sigmoid', name='decoder2')(decoded)


# this model maps an input to its reconstruction
autoencoder2 = Model(inputs=input_data, outputs=decoded)
autoencoder2.compile(optimizer='sgd', loss='mse')

# training
autoencoder2.fit(x2_train, x2_train,
                    epochs=100,
                    batch_size=8,
                    shuffle=True,
                    validation_data=(x2_test, x2_test))

# MLP

num_hidden = 100

encoded1.trainable = False
encoded2.trainable = False

encoded1 = autoencoder1(autoencoder1.inputs)
encoded2 = autoencoder2(autoencoder2.inputs)

concatenated = layers.concatenate([encoded1, encoded2], axis=-1)
x = Dropout(0.2)(concatenated)
h = Dense(num_hidden, activation='relu', name='hidden')(x)
h = Dropout(0.5)(h)
y = Dense(1, activation='sigmoid', name='prediction')(h)
myMLP = Model(inputs=[autoencoder1.inputs, autoencoder2.inputs], outputs=y)

# Compile model
myMLP.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# Training
myMLP.fit(x_train, y_train, epochs=200, batch_size=8)

# Testing
myMLP.predict(x_test)

This gives me an error, unhashable type: 'list', from the line: myMLP = Model(inputs=[autoencoder1.inputs, autoencoder2.inputs], outputs=y)

Upvotes: 1

Views: 1095

Answers (2)

Daniel Möller

Reputation: 86600

The problem is that you're mixing numpy arrays with Keras tensors. That won't work.

There are two approaches.

  • Predict numpy arrays from each autoencoder, concat the arrays, send them to the third model
  • Connect all models, probably make the autoencoders untrainable, fit with one input for each autoencoder.

Personally, I'd go for the first. (Assuming the autoencoders are already trained and don't need change).

First approach

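# Note: predict() on a full autoencoder returns its reconstruction;
# for bottleneck features, call predict() on an encoder-only model instead.
numpyOutputFromAuto1 = autoencoder1.predict(numpyInputs1)
numpyOutputFromAuto2 = autoencoder2.predict(numpyInputs2)

inputDataForThird = np.concatenate([numpyOutputFromAuto1, numpyOutputFromAuto2], axis=-1)

inputTensorForMlp = Input(inputDataForThird.shape[1:])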
h = Dense(num_hidden, activation='relu', name='hidden')(inputTensorForMlp)
y = Dense(1, activation='sigmoid', name='prediction')(h)

mymlp = Model(inputs=inputTensorForMlp, outputs=y)

....
mymlp.fit(inputDataForThird ,someY)
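
To make the first approach return bottleneck features rather than reconstructions, one option is to build an encoder-only model that shares layers with each trained autoencoder and call predict() on that. The following is a minimal sketch under stated assumptions: the helper build_autoencoder, the layer names, the feature counts (30 and 20) and the random data are all hypothetical stand-ins, and the autoencoder training step is omitted.

```python
import numpy as np
from keras.layers import Input, Dense
from keras.models import Model


def build_autoencoder(num_features, prefix):
    """Hypothetical stand-in autoencoder; sizes are illustrative only."""
    inp = Input(shape=(num_features,), name=prefix + '_input')
    code = Dense(100, activation='relu', name=prefix + '_bottleneck')(inp)
    out = Dense(num_features, activation='sigmoid', name=prefix + '_output')(code)
    autoencoder = Model(inputs=inp, outputs=out)
    # The encoder reuses the same layers, so it shares the autoencoder's weights.
    encoder = Model(inputs=inp, outputs=code)
    return autoencoder, encoder


autoencoder1, encoder1 = build_autoencoder(30, 'ae1')
autoencoder2, encoder2 = build_autoencoder(20, 'ae2')
# (each autoencoder would be compiled and fitted on its own data here)

# Random placeholder data: 16 samples per input source.
x1 = np.random.rand(16, 30).astype('float32')
x2 = np.random.rand(16, 20).astype('float32')

# Bottleneck features come back as plain numpy arrays.
features1 = encoder1.predict(x1, verbose=0)
features2 = encoder2.predict(x2, verbose=0)

# Concatenate along the feature axis: (16, 100) + (16, 100) -> (16, 200).
features = np.concatenate([features1, features2], axis=-1)
```

The resulting features array can then be passed to the MLP's fit() as a single numpy input, exactly as inputDataForThird is used above.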

Second Approach

This is a little more complicated, and at first glance I don't see much reason to do it. (Of course, there may be cases where it's a good choice.)

Now we're totally forgetting numpy and working with keras tensors.

Creating the mlp on its own (good if you will use it later without the autoencoders):

inputTensorForMlp = Input(input_shape_compatible_with_concatenated_encoder_outputs)
x = Dropout(0.2)(inputTensorForMlp)
h = Dense(num_hidden, activation='relu', name='hidden')(x)
h = Dropout(0.5)(h)
y = Dense(1, activation='sigmoid', name='prediction')(h)
myMLP = Model(inputs=inputTensorForMlp, outputs=y)

We probably want the bottleneck features of the autoencoders, right? If you built each autoencoder from a separate encoder model and decoder model joined together, it's easier to use just the encoder model. Otherwise:

encodedOutput1 = autoencoder1.layers[bottleneckLayer].output  # or encoder1.output
encodedOutput2 = autoencoder2.layers[bottleneckLayer].output  # or encoder2.output

Creating a joined model. The concatenation must use a keras layer (we're working with keras tensors):

concatenated = Concatenate()([encodedOutput1,encodedOutput2])
output = myMLP(concatenated)

joinedModel = Model([autoencoder1.input,autoencoder2.input],output)
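
The second approach can be sketched end to end as follows. This is an illustration under assumptions, not the asker's exact setup: build_encoder, the model names, the feature counts and the random data are hypothetical, and the encoders here are untrained placeholders for already-trained ones. Each encoder is wrapped as its own named model, which also sidesteps layer-name clashes that can arise when two autoencoders reuse names such as 'encoder1' and 'encoder2' inside one joined model.

```python
import numpy as np
from keras.layers import Input, Dense, Concatenate
from keras.models import Model


def build_encoder(num_features, name):
    """Hypothetical stand-in for the encoder half of a trained autoencoder."""
    inp = Input(shape=(num_features,))
    code = Dense(100, activation='relu')(inp)
    return Model(inputs=inp, outputs=code, name=name)


encoder1 = build_encoder(30, 'encoder_a')
encoder2 = build_encoder(20, 'encoder_b')

# Freeze the encoders so only the MLP head is updated during training.
encoder1.trainable = False
encoder2.trainable = False

# One Keras input per data source, kept in separate variables.
input1 = Input(shape=(30,), name='input1')
input2 = Input(shape=(20,), name='input2')

# Concatenation is done with a Keras layer because these are Keras tensors.
concatenated = Concatenate(axis=-1)([encoder1(input1), encoder2(input2)])
h = Dense(100, activation='relu', name='hidden')(concatenated)
y = Dense(1, activation='sigmoid', name='prediction')(h)

joined = Model(inputs=[input1, input2], outputs=y)
joined.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# Random placeholder data; inputs are passed as a list, one array per input.
x1 = np.random.rand(16, 30).astype('float32')
x2 = np.random.rand(16, 20).astype('float32')
labels = np.random.randint(0, 2, size=(16, 1)).astype('float32')

joined.fit([x1, x2], labels, epochs=1, batch_size=8, verbose=0)
predictions = joined.predict([x1, x2], verbose=0)
```

With both encoders frozen, the only trainable weights left are the kernel and bias of the 'hidden' and 'prediction' layers.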

Upvotes: 3

twolffpiggott

Reputation: 1103

I'd also go with Daniel's first approach (for simplicity and efficiency). If you're interested in the second, for instance because you want to run the network end-to-end, you'd approach it like this:

# make autoencoders not trainable
autoencoder1.trainable = False
autoencoder2.trainable = False

# input_data1 and input_data2 are the Input tensors of autoencoder1 and autoencoder2
encoded1 = autoencoder1(input_data1)
encoded2 = autoencoder2(input_data2)

concatenated = layers.concatenate([encoded1, encoded2], axis=-1)
h = Dense(num_hidden, activation='relu', name='hidden')(concatenated)
y = Dense(1, activation='sigmoid', name='prediction')(h)

myMLP = Model([input_data1, input_data2], y)

myMLP.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# Training
myMLP.fit([x1_train, x2_train], y_train, epochs=200, batch_size=8)

# Testing
myMLP.predict([x1_test, x2_test])

Key edits

  1. The weights of both autoencoders should be frozen end-to-end (otherwise early-stage gradient updates from the randomly initialized MLP will likely result in the loss of much of their learning).
  2. The autoencoder input layers should be assigned to separate variables, input_data1 and input_data2, one per autoencoder (instead of both to input_data). autoencoder1.inputs returns a list of tensors rather than a single tensor, so [autoencoder1.inputs, autoencoder2.inputs] is a nested list; this is the source of the unhashable type: 'list' exception, and replacing it with [input_data1, input_data2] solves the issue.
  3. When fitting the MLP for the end-to-end model, the input should be a list of x1_train and x2_train rather than the concatenated inputs. Same when predicting.
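
The unhashable type: 'list' message in point 2 can be reproduced in plain Python. The sketch below assumes, without having confirmed it against the Keras source, that the Model constructor hashes its inputs (for example, to check for duplicates); the strings are hypothetical stand-ins for tensors.

```python
# Stand-ins for autoencoder1.inputs and autoencoder2.inputs, each a list.
inputs_a = ['tensor_a']
inputs_b = ['tensor_b']

# Wrapping two lists in another list gives a nested list.
nested = [inputs_a, inputs_b]

try:
    set(nested)  # hashing each element fails because the elements are lists
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# A flat list of the individual items hashes without trouble.
flat = inputs_a + inputs_b
unique_inputs = set(flat)

print(error_message)
```

Passing the individual input tensors in one flat list, as in [input_data1, input_data2], avoids the nesting.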

Upvotes: 0
