machinery

Reputation: 6290

Keras Dense models with shared weights

I would like to create a Dense model in Keras with shared weights. My approach is described below: I have one Dense network called dense, and dense_x, dense_y, and dense_z should share its weights (i.e. all three use dense). The outputs of these three branches are then concatenated and fed into another dense network.

However, this approach somehow does not work. How can I use shared weights correctly?

import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, BatchNormalization, Activation, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

num_nodes = 310
input_tensor_x = Input(shape=(310,))
input_tensor_y = Input(shape=(310,))
input_tensor_z = Input(shape=(310,))

dense = Dense(num_nodes)
dense = BatchNormalization()(dense)
dense = Activation('relu')(dense)
dense = Dropout(0.4)(dense)

dense = Dense(num_nodes // 2)
dense = BatchNormalization()(dense)
dense = Activation('relu')(dense)
dense = Dropout(0.4)(dense)

dense = Dense(num_nodes // 4)
dense = BatchNormalization()(dense)
dense = Activation('relu')(dense)
dense = Dropout(0.4)(dense)

dense_x = dense(input_tensor_x) #<=== shared above dense network
dense_y = dense(input_tensor_y) #<=== shared above dense network
dense_z = dense(input_tensor_z) #<=== shared above dense network
merge_layer = tf.keras.layers.Concatenate()([dense_x, dense_y, dense_z])

merged_nodes = 3*(num_nodes // 4) // 2
dense2 = Dense(merged_nodes)(merge_layer)
dense2 = BatchNormalization()(dense2)
dense2 = Activation('relu')(dense2)
dense2 = Dropout(0.4)(dense2)

dense2 = Dense(merged_nodes // 2)(dense2)
dense2 = BatchNormalization()(dense2)
dense2 = Activation('relu')(dense2)
dense2 = Dropout(0.4)(dense2)

output_tensor = Dense(3, activation='softmax')(dense2)

fcnn_model = Model(inputs=[input_tensor_x, input_tensor_y, input_tensor_z], outputs=output_tensor)
opt = Adam(lr=learning_rate)
fcnn_model.compile(loss='categorical_crossentropy',
                   optimizer=opt, metrics=['accuracy', tf.keras.metrics.AUC()])

Upvotes: 1

Views: 103

Answers (1)

Marco Cerliani

Reputation: 22031

The simplest way is to initialize the shared layers once at the top and then pass each input through them, taking care not to overwrite them. In your code, every dense = ... line rebinds the name, and layers are applied to layer objects rather than to tensors, so no shared stack is ever built. An equivalent and tidier option, shown below, is to wrap the shared stack in its own Model and call that model on each input.

Note that only Dense and BatchNormalization carry trainable weights (Activation and Dropout have none), so those are the layers whose sharing actually matters.
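For illustration, here is a minimal sketch of the direct layer-reuse variant with a single shared block; shared_d, shared_bn, and branch are hypothetical names, not from the original post:

import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, BatchNormalization, Activation, Dropout

num_nodes = 310
input_tensor_x = Input(shape=(310,))
input_tensor_y = Input(shape=(310,))
input_tensor_z = Input(shape=(310,))

shared_d = Dense(num_nodes)         # one instance = one set of weights
shared_bn = BatchNormalization()    # also holds trainable parameters

def branch(inp):
    x = shared_d(inp)               # reusing the same instance shares its weights
    x = shared_bn(x)
    x = Activation('relu')(x)       # stateless; a fresh one per call is fine
    return Dropout(0.4)(x)

branch_x = branch(input_tensor_x)
branch_y = branch(input_tensor_y)
branch_z = branch(input_tensor_z)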

import tensorflow as tf
from tensorflow.keras.layers import (Input, Dense, BatchNormalization,
                                     Activation, Dropout, Concatenate)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

num_nodes = 310
input_tensor_x = Input(shape=(310,))
input_tensor_y = Input(shape=(310,))
input_tensor_z = Input(shape=(310,))

def build_shared_dense(inp):
    
    dense = Dense(num_nodes)(inp)
    dense = BatchNormalization()(dense)
    dense = Activation('relu')(dense)
    dense = Dropout(0.4)(dense)

    dense = Dense(num_nodes // 2)(dense)
    dense = BatchNormalization()(dense)
    dense = Activation('relu')(dense)
    dense = Dropout(0.4)(dense)

    dense = Dense(num_nodes // 4)(dense)
    dense = BatchNormalization()(dense)
    dense = Activation('relu')(dense)
    dense = Dropout(0.4)(dense)

    return Model(inp, dense, name='shared_dense')

shared_dense = build_shared_dense(Input(shape=(310,)))  # build the sub-model once
dense_x = shared_dense(input_tensor_x) #<=== shared 
dense_y = shared_dense(input_tensor_y) #<=== shared 
dense_z = shared_dense(input_tensor_z) #<=== shared 

merge_layer = Concatenate()([dense_x, dense_y, dense_z])

merged_nodes = 3*(num_nodes // 4) // 2
dense2 = Dense(merged_nodes)(merge_layer)
dense2 = BatchNormalization()(dense2)
dense2 = Activation('relu')(dense2)
dense2 = Dropout(0.4)(dense2)

dense2 = Dense(merged_nodes // 2)(dense2)
dense2 = BatchNormalization()(dense2)
dense2 = Activation('relu')(dense2)
dense2 = Dropout(0.4)(dense2)

output_tensor = Dense(3, activation='softmax')(dense2)

fcnn_model = Model(inputs=[input_tensor_x, input_tensor_y, input_tensor_z], 
                   outputs=output_tensor)
fcnn_model.compile(loss='categorical_crossentropy',
                   optimizer=Adam(lr=1e-3),
                   metrics=['accuracy', tf.keras.metrics.AUC()])
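If the sharing works, each of the sub-model's weight tensors appears exactly once in the parent model. A quick smoke test on random data can confirm the model trains end to end; the batch size and label scheme below are assumptions for illustration, not from the original post:

import numpy as np
import tensorflow as tf

# Each shared weight tensor is counted once, no matter how many calls
print(len(shared_dense.trainable_weights))

# Random inputs: three arrays of shape (batch, 310), 3-class one-hot labels
X = [np.random.rand(32, 310).astype('float32') for _ in range(3)]
y = tf.keras.utils.to_categorical(np.random.randint(0, 3, size=32), num_classes=3)
fcnn_model.fit(X, y, epochs=1, batch_size=8)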

Upvotes: 1
