gokulan vikash

Reputation: 43

How to transfer weights from baseline model to federated model?

import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout

def create_keras_model():
    model = Sequential([
        Conv2D(16, 3, padding='same', activation='relu'),
        MaxPooling2D(),
        Conv2D(32, 3, padding='same', activation='relu'),
        MaxPooling2D(),
        Conv2D(64, 3, padding='same', activation='relu'),
        MaxPooling2D(),
        Flatten(),
        Dense(512, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(0.001)),
        Dropout(0.5),
        Dense(1, activation='sigmoid')
    ])

    # Load the weights saved from the baseline model.
    model.load_weights('/content/drive/My Drive/localmodel/weights')
    return model

I tried something like this in Colab, but I get Errno 21 (Is a directory).
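One possible cause of this error, assuming the weights were saved in TensorFlow checkpoint format, is that the path passed to load_weights names a folder rather than a checkpoint prefix. Such checkpoints are stored as '<prefix>.index' plus '<prefix>.data-*' files, and load_weights expects the prefix. The sketch below uses a hypothetical helper, find_checkpoint_prefix (not part of Keras or TFF), to look for such a prefix inside a folder. It is only a diagnostic aid, not a confirmed fix for this setup.

```python
import os

def find_checkpoint_prefix(path):
    """Return a checkpoint prefix to try with load_weights, or None.

    Hypothetical helper. Assumes TensorFlow-format checkpoints, which are
    stored as '<prefix>.index' plus '<prefix>.data-*' shard files.
    """
    if not os.path.isdir(path):
        # Not a directory: the path may already be a prefix or a weights file.
        return path
    for name in sorted(os.listdir(path)):
        if name.endswith('.index'):
            # Strip the '.index' suffix to recover the prefix.
            return os.path.join(path, name[:-len('.index')])
    # A directory with no '.index' file: nothing recognizable to load.
    return None

# Possible usage inside create_keras_model():
#   prefix = find_checkpoint_prefix('/content/drive/My Drive/localmodel/weights')
#   if prefix is not None:
#       model.load_weights(prefix)
```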

Then I tried another method, shown below:

tff_model = create_keras_model()  # this version does not load weights; it just returns a Sequential model
tff.learning.assign_weights_to_keras_model(tff_model, model_with_weights)

assign_weights_to_keras_model() transfers weights from the TFF model to a Keras model. I want the reverse: to transfer weights from a Keras model to the TFF model. How can this be done?

Upvotes: 1

Views: 442

Answers (2)

gokulan vikash

Reputation: 43

I found out how this can be done. The idea is to use:

tff.learning.state_with_new_model_weights(state, trainable_weights_numpy, non_trainable_weights_numpy)

Documentation here

where the trainable weights are taken from the baseline model and converted to NumPy arrays:

trainable_weights = []

for weights in baseline_model.trainable_weights:
    trainable_weights.append(weights.numpy())

This is particularly useful when the server holds part of the data and the clients hold similar data. It may also be useful for transfer learning.
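The conversion in this answer covers only the trainable weights, while the call also takes a list of non-trainable weights. Here is a minimal sketch of building both lists with plain Keras. The small model is a hypothetical stand-in for the baseline model, with a BatchNormalization layer added only so that some non-trainable weights exist. The question's model has no such layer, so its non-trainable list would presumably be empty.

```python
import numpy as np
import tensorflow as tf

def build_model():
    # Hypothetical stand-in for the baseline model.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(8, activation='relu'),
        tf.keras.layers.BatchNormalization(),  # adds moving mean/variance
        tf.keras.layers.Dense(1, activation='sigmoid'),
    ])

baseline_model = build_model()

# Convert every variable to a NumPy array, keeping Keras's ordering.
trainable_weights = [w.numpy() for w in baseline_model.trainable_weights]
non_trainable_weights = [w.numpy() for w in baseline_model.non_trainable_weights]

# The order and shapes must match the model built inside model_fn,
# so it is worth checking them against a fresh copy of the architecture.
fresh_model = build_model()
shapes_match = all(
    a.shape == tuple(b.shape)
    for a, b in zip(trainable_weights, fresh_model.trainable_weights)
)
```

These two lists are what would then be passed, in that order, as the second and third arguments of tff.learning.state_with_new_model_weights, following the call shown in this answer.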

Upvotes: 1

Ayness

Reputation: 137

Here, model_with_weights must be a TFF value representing the weights of a model, such as state.model. For example:

def model_fn():
    keras_model = create_keras_model()
    return tff.learning.from_keras_model(keras_model)

fed_avg = tff.learning.build_federated_averaging_process(model_fn, ...)
state = fed_avg.initialize()
state = fed_avg.next(state, ...)
...
# keras_model inside model_fn is local, so build a fresh Keras model here
# and copy the federated weights into it.
keras_model = create_keras_model()
tff.learning.assign_weights_to_keras_model(keras_model, state.model)

Upvotes: 3
