Michael24

Reputation: 59

How can I save a trained TensorFlow Federated model as a .h5 model?

I want to save a TensorFlow Federated model, trained with the FedAvg algorithm, as a Keras/.h5 model. I couldn't find any documentation on this and would like to know how it can be done. If possible, I'd also like access to both the aggregated server model and the individual client models.

The code I use to train the federated model is below:

import time

import tensorflow as tf
import tensorflow_federated as tff

def model_fn():
    model = tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(segment_size,num_input_channels)),
      tf.keras.layers.Flatten(), 
      tf.keras.layers.Dense(units=400, activation='relu'),
      tf.keras.layers.Dropout(dropout_rate),
      tf.keras.layers.Dense(units=100, activation='relu'),
      tf.keras.layers.Dropout(dropout_rate),
      tf.keras.layers.Dense(activityCount, activation='softmax'),
    ])
    return tff.learning.from_keras_model(
      model,
      dummy_batch=batch,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

trainer = tff.learning.build_federated_averaging_process(
    model_fn, client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learningRate))

def evaluate(num_rounds=communicationRound):
  state = trainer.initialize()
  roundMetrics = []
  evaluation = tff.learning.build_federated_evaluation(model_fn)

  for round_num in range(num_rounds):
    t1 = time.time()
    state, metrics = trainer.next(state, train_data)
    t2 = time.time()
    # Note: this evaluates on train_data, so the "test accuracy" below is measured on training data.
    test_metrics = evaluation(state.model, train_data)

    roundMetrics.append('round {:2d}, metrics={}, loss={}'.format(round_num, metrics.sparse_categorical_accuracy , metrics.loss))
    roundMetrics.append("The test accuracy is " + str(test_metrics.sparse_categorical_accuracy))
    roundMetrics.append('round time={}'.format(t2 - t1))
    print('round {:2d}, accuracy={}, loss={}'.format(round_num, metrics.sparse_categorical_accuracy , metrics.loss))
    print("The test accuracy is " + str(test_metrics.sparse_categorical_accuracy))
    print('round time={}'.format(t2 - t1))
  # Write the per-round log lines; the with-block closes the file even if a write fails.
  with open(filepath+'stats'+architectureType+'.txt', "w") as outF:
    for line in roundMetrics:
      outF.write(line)
      outF.write("\n")

Upvotes: 0

Views: 408

Answers (1)

Eliza

Reputation: 584

Roughly, you can use the save_checkpoint/load_checkpoint methods: instantiate a FileCheckpointManager and ask it to save state (almost) directly.

state in your example is an instance of tff.python.common_libs.anonymous_tuple.AnonymousTuple (IIRC), which is not compatible with tf.convert_to_tensor, as save_checkpoint requires and as its docstring declares. The general solution often used in TFF research code is to introduce a Python attrs class and convert away from the anonymous tuple as soon as state is returned.
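
To make the conversion step concrete, here is a minimal sketch of the shape such a class might take. It is not the TFF research code: it uses a standard-library dataclass in place of attrs so that it runs without extra packages, and a SimpleNamespace stands in for the anonymous tuple. The field name optimizer_state is an assumption (only state.model appears in the question), and this sketch does not check whether a dataclass is accepted by save_checkpoint, so for real use an attrs class as described above is the safer choice.

```python
import dataclasses
from types import SimpleNamespace
from typing import Any


@dataclasses.dataclass(frozen=True)
class ServerState:
    """Plain container mirroring the fields of the server state (illustrative)."""
    model: Any            # `state.model` is used in the question's code
    optimizer_state: Any  # assumed field name; check it against your TFF version

    @classmethod
    def from_anon_tuple(cls, anon_tuple):
        # Copy each field over by attribute access, as `state.model` is read
        # in the question's code.
        return cls(model=anon_tuple.model,
                   optimizer_state=anon_tuple.optimizer_state)


# Stand-in for the object returned by trainer.next(); not a real TFF type.
fake_state = SimpleNamespace(model={"trainable": [1.0, 2.0]},
                             optimizer_state=[0])
server_state = ServerState.from_anon_tuple(fake_state)
print(server_state.model["trainable"])
```

With a real training loop, the call would be ServerState.from_anon_tuple(state) right after trainer.next(...) returns, so that everything downstream sees the converted object rather than the anonymous tuple.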

Assuming the above, the following sketch should work:

# state assumed an anonymous tuple, previously created
# N some integer 

ckpt_manager = FileCheckpointManager(...)
ckpt_manager.save_checkpoint(ServerState.from_anon_tuple(state), round_num=N)

And to restore from this checkpoint, at any time you can call:

state = iterative_process.initialize()
ckpt_manager = FileCheckpointManager(...)
restored_state = ckpt_manager.load_latest_checkpoint(
    ServerState.from_anon_tuple(state))

One thing to note: the code pointers above are generally in tff.python.research..., which is not included in the pip package, so the preferred way to get at them is either to fork the code into your own project or to pull down the repo and build it from source.

Upvotes: 3
