Adrian Grygutis

Reputation: 479

Convolution and classification model in main model

I need to build a neural network model with the following structure:

convolution --> classification
       \            /
        \          /
        _\|      |/_
         third model
       with one output

The convolution model outputs features that are used as input to the classification model. The convolution and classification outputs are then concatenated and fed to a third model. The third model outputs a prediction in the range 0..1, which is used to train the whole network.

The full error message is: "Graph disconnected: cannot obtain value for tensor Tensor("classification_prediction_Input_2:0", shape=(1, 512), dtype=float32) at layer "classification_prediction_Input". The following previous layers were accessed without issue: []".

If the idea is sound, how do I connect the models as shown in the diagram?

My current code:

# state convolution
state_input = Input(shape=INPUT_SHAPE, name='state_input', batch_shape=(1, 210, 160, 3))
state_Conv2D_1 = Conv2D(8, kernel_size=(8, 8), strides=(4, 4), activation='relu', name='state_Conv2D_1')(state_input)
state_MaxPooling2D_1 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='state_MaxPooling2D_1')(state_Conv2D_1)
state_outputs = Flatten(name='state_Flatten')(state_MaxPooling2D_1)
state_convolution_model = Model(state_input, state_outputs, name='state_convolution_model')
state_convolution_model.compile(optimizer='adam', loss='mean_squared_error', metrics=['acc'])

state_convolution_model_input = Input(shape=INPUT_SHAPE, name='state_convolution_model_input', batch_shape=(1, 210, 160, 3))
state_convolution = state_convolution_model(state_convolution_model_input)

# classification output
classficication_Input = Input(shape=(1, LSTM_OUTPUT_DIM), batch_shape=(1, LSTM_OUTPUT_DIM), name='classification_prediction_Input')
classficication_Dense_1 = Dense(32, activation='relu', name='classification_prediction_Dense_1')(classficication_Input)
classficication_output_raw = Dense(ACTIONS, activation='sigmoid', name='classification_output_raw')(classficication_Dense_1)
classficication_output = Reshape((ACTIONS,), name='classification_output')(classficication_output_raw)
classficication_model = Model(classficication_Input, classficication_output, name='classificationPrediction_model')
classficication_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])

classficicationPrediction = classficication_model(state_convolution)

i = keras.layers.concatenate([state_outputs, classficication_output], name='concatenate')
d = Dense(32, activation='relu')(i)
o = Dense(1, activation='sigmoid')(d)
model = Model(state_input, o)                  # <-- graph error is here
plot_model(model, to_file='model.png', show_shapes=True)

Upvotes: 0

Views: 52

Answers (1)

Dmytro Prylipko

Reputation: 5064

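Yes, you can build a structure like this and train it end to end. However, it has to be a single model with several branches. In your code, classficication_output is built on classification_prediction_Input, a separate Input that cannot be reached from state_input, so Model(state_input, o) reports a disconnected graph. Another problem is that you compile the sub-models before the full model is defined; only the final model needs to be compiled. Here is working code: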

# Imports and constants were omitted in the original snippet. tensorflow.keras
# is assumed here; INPUT_SHAPE and ACTIONS are inferred from the summary below.
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Reshape, concatenate
from tensorflow.keras.models import Model

INPUT_SHAPE = (210, 160, 3)
ACTIONS = 4

# state convolution
state_input = Input(shape=INPUT_SHAPE, name='state_input')
state_Conv2D_1 = Conv2D(8, kernel_size=(8, 8), strides=(4, 4), activation='relu', name='state_Conv2D_1')(state_input)
state_MaxPooling2D_1 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='state_MaxPooling2D_1')(state_Conv2D_1)
state_outputs = Flatten(name='state_Flatten')(state_MaxPooling2D_1)

# classification output, built directly on the convolution features
classification_Dense_1 = Dense(32, activation='relu', name='classification_prediction_Dense_1')(state_outputs)
classification_output_raw = Dense(ACTIONS,
                                  activation='sigmoid',
                                  name='classification_output_raw')(classification_Dense_1)
classification_output = Reshape((ACTIONS,), name='classification_output')(classification_output_raw)

# third branch: both inputs to concatenate trace back to state_input
i = concatenate([state_outputs, classification_output], name='concatenate')
d = Dense(32, activation='relu')(i)
o = Dense(1, activation='sigmoid')(d)
model = Model(state_input, o)                  # <-- no graph error here anymore
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc'])
model.summary()

Output:

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
state_input (InputLayer)        (None, 210, 160, 3)  0                                            
__________________________________________________________________________________________________
state_Conv2D_1 (Conv2D)         (None, 51, 39, 8)    1544        state_input[0][0]                
__________________________________________________________________________________________________
state_MaxPooling2D_1 (MaxPoolin (None, 25, 19, 8)    0           state_Conv2D_1[0][0]             
__________________________________________________________________________________________________
state_Flatten (Flatten)         (None, 3800)         0           state_MaxPooling2D_1[0][0]       
__________________________________________________________________________________________________
classification_prediction_Dense (None, 32)           121632      state_Flatten[0][0]              
__________________________________________________________________________________________________
classification_output_raw (Dens (None, 4)            132         classification_prediction_Dense_1
__________________________________________________________________________________________________
classification_output (Reshape) (None, 4)            0           classification_output_raw[0][0]  
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 3804)         0           state_Flatten[0][0]              
                                                                 classification_output[0][0]      
__________________________________________________________________________________________________
dense (Dense)                   (None, 32)           121760      concatenate[0][0]                
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 1)            33          dense[0][0]                      
==================================================================================================
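
The shapes and parameter counts in the summary can be checked by hand. Below is a rough sketch in plain Python, assuming 'valid' padding (the Keras default for Conv2D and MaxPooling2D) and a 3-channel 210x160 input; the helper names conv_out and dense_params are made up for this illustration.

```python
def conv_out(size, kernel, stride):
    # Output length along one axis for a 'valid' (no padding) convolution or pooling.
    return (size - kernel) // stride + 1

def dense_params(n_in, n_out):
    # Weight matrix plus one bias per output unit.
    return n_in * n_out + n_out

filters, channels = 8, 3

# Conv2D: 8x8 kernel, stride 4 on a 210x160 input
conv_h, conv_w = conv_out(210, 8, 4), conv_out(160, 8, 4)
# MaxPooling2D: 2x2 window, stride 2
pool_h, pool_w = conv_out(conv_h, 2, 2), conv_out(conv_w, 2, 2)
flat = pool_h * pool_w * filters

conv_params = 8 * 8 * channels * filters + filters
clf_dense_params = dense_params(flat, 32)        # classification_prediction_Dense_1
clf_out_params = dense_params(32, 4)             # classification_output_raw
head_dense_params = dense_params(flat + 4, 32)   # dense, fed by concatenate
head_out_params = dense_params(32, 1)            # dense_1

print((conv_h, conv_w), (pool_h, pool_w), flat)
print(conv_params, clf_dense_params, clf_out_params, head_dense_params, head_out_params)
```

The concatenate layer has 3804 inputs because it joins the 3800 flattened features with the 4 classification outputs.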

See Guide to the Functional API for more examples.
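
If you still want the convolution and classification parts as separate, reusable Model objects, one possible approach is to call each sub-model on a tensor that comes from the main input and to concatenate the results of those calls, not the sub-models' internal tensors. The following is a minimal sketch, assuming tensorflow.keras and the same INPUT_SHAPE and ACTIONS values as the summary; the variable names and the random smoke-test data are illustrative only.

```python
import numpy as np
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, concatenate
from tensorflow.keras.models import Model

INPUT_SHAPE = (210, 160, 3)  # assumed, taken from the summary
ACTIONS = 4                  # assumed, taken from the summary

# Convolution sub-model: image -> flat feature vector.
conv_in = Input(shape=INPUT_SHAPE, name='conv_in')
x = Conv2D(8, kernel_size=(8, 8), strides=(4, 4), activation='relu')(conv_in)
x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)
flat = Flatten()(x)
conv_model = Model(conv_in, flat, name='state_convolution_model')

# Classification sub-model: feature vector -> one score per action.
feature_dim = int(flat.shape[-1])
clf_in = Input(shape=(feature_dim,), name='clf_in')
h = Dense(32, activation='relu')(clf_in)
clf_out = Dense(ACTIONS, activation='sigmoid')(h)
clf_model = Model(clf_in, clf_out, name='classification_model')

# Main model: every tensor below is derived from state_input,
# so the graph stays connected.
state_input = Input(shape=INPUT_SHAPE, name='state_input')
features = conv_model(state_input)   # output of the call, not conv_model's own `flat`
actions = clf_model(features)        # output of the call, not `clf_out`
merged = concatenate([features, actions], name='concatenate')
d = Dense(32, activation='relu')(merged)
o = Dense(1, activation='sigmoid')(d)
model = Model(state_input, o)
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Smoke test on random data: one training step, then a prediction.
x_dummy = np.random.rand(2, *INPUT_SHAPE).astype('float32')
y_dummy = np.array([[0.0], [1.0]], dtype='float32')
model.fit(x_dummy, y_dummy, epochs=1, verbose=0)
pred = model.predict(x_dummy, verbose=0)
print(pred.shape)
```

Only the outer model is compiled. The sub-models share their weights with it, so training the outer model updates them as well, and they can still be called on their own for inference.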

Upvotes: 0
