Reputation: 711
My question is: how can I predict the label of an image with TensorFlow Federated?
After evaluating the model, I would like to predict the label of a given image. In Keras we would do this:
# new instance where we do not know the answer
Xnew = array([[0.89337759, 0.65864154]])
# make a prediction
ynew = model.predict_classes(Xnew)
# show the inputs and predicted outputs
print("X=%s, Predicted=%s" % (Xnew[0], ynew[0]))
Output:
X=[0.89337759 0.65864154], Predicted=[0]
Here is how state and model_fn were created:
def model_fn():
    keras_model = create_compiled_keras_model()
    return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    client_weight_fn=None)
state = iterative_process.initialize()
I get this error:
list(self._name_to_index.keys())[:10]))
AttributeError: The tuple of length 2 does not have named field "assign_weights_to". Fields (up to first 10): ['trainable', 'non_trainable']
Thanks
Upvotes: 3
Views: 715
Reputation: 2941
(Requires TFF 0.16.0 or newer)
Since the code is building a tff.learning.Model from a tf.keras.Model, you may be able to use the assign_weights_to method on the tff.learning.ModelWeights object (the type of state.model).
This method is used in the Federated Learning for Text Generation tutorial.
This might look something like the following. The relevant part is near the bottom; the earlier portion is an example FL training loop:
def create_keras_model() -> tf.keras.Model:
    ...

def model_fn():
    ...
    return tff.learning.from_keras_model(create_keras_model())

training_process = tff.learning.build_federated_averaging_process(model_fn, ...)
state = training_process.initialize()
for _ in range(NUM_ROUNDS):
    state, metrics = training_process.next(state, ...)

model_for_inference = create_keras_model()
state.model.assign_weights_to(model_for_inference)
Once the weights from state have been assigned back into the Keras model, the code can use the standard Keras APIs, such as tf.keras.Model.predict_on_batch:
predictions = model_for_inference.predict_on_batch(batch)
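Since the question asks for a label rather than raw scores, here is a minimal sketch of the last step. It assumes the model outputs one score per class for each example (for instance a softmax layer); a model with a single sigmoid output would need a threshold instead. The helper name predicted_labels and the example scores are hypothetical, used only for illustration.

```python
from typing import List, Sequence


def predicted_labels(predictions: Sequence[Sequence[float]]) -> List[int]:
    """Return, for each row of scores, the index of the largest score."""
    labels = []
    for row in predictions:
        # Argmax over the class axis; ties resolve to the lowest index.
        best = max(range(len(row)), key=lambda i: row[i])
        labels.append(best)
    return labels


# Hypothetical scores for a batch of two images and three classes,
# standing in for the array returned by predict_on_batch.
example_predictions = [[0.1, 0.7, 0.2], [0.8, 0.15, 0.05]]
print(predicted_labels(example_predictions))  # prints [1, 0]
```

With the NumPy array that predict_on_batch returns, np.argmax(predictions, axis=-1) is the usual one-line equivalent.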
Upvotes: 4