seni

Reputation: 711

How to make predictions with TFF?

My question is: how can I predict the label of an image with TensorFlow Federated?

After completing the evaluation of the model, I would like to predict the label of a given image, like we do in Keras:

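from numpy import array

# `model` is an already trained Keras Sequential classifier
# new instance where we do not know the answer
Xnew = array([[0.89337759, 0.65864154]])
# make a prediction (predict_classes exists only on Sequential models in older Keras/TF releases)
ynew = model.predict_classes(Xnew)
# show the inputs and predicted outputs
print("X=%s, Predicted=%s" % (Xnew[0], ynew[0]))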

Output:

X=[0.89337759 0.65864154], Predicted=[0]

Here is how state and model_fn were created:


import tensorflow as tf
import tensorflow_federated as tff

def model_fn():
    # create_compiled_keras_model() and sample_batch are defined elsewhere
    keras_model = create_compiled_keras_model()
    return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn, client_weight_fn=None,
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
state = iterative_process.initialize()

I get this error:

list(self._name_to_index.keys())[:10]))
AttributeError: The tuple of length 2 does not have named field "assign_weights_to". Fields (up to first 10): ['trainable', 'non_trainable']

Thanks

Upvotes: 3

Views: 715

Answers (1)

Zachary Garrett

Reputation: 2941

(Requires TFF 0.16.0 or newer)

Since the code is building a tff.learning.Model from a tf.keras.Model, you may be able to use the assign_weights_to method on the tff.learning.ModelWeights object (the type of state.model). This method is used in the Federated Learning for Text Generation tutorial.

This might look something like the following (the first part is an example FL training loop; the inference step is near the bottom):


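import tensorflow as tf
import tensorflow_federated as tff

def create_keras_model() -> tf.keras.Model:
  ...

def model_fn():
  ...
  # from_keras_model also needs the loss and the input_spec of the client data
  return tff.learning.from_keras_model(
      create_keras_model(), loss=..., input_spec=...)

training_process = tff.learning.build_federated_averaging_process(model_fn, ...)

# NUM_ROUNDS and the client data passed to next() are placeholders
state = training_process.initialize()
for _ in range(NUM_ROUNDS):
  state, metrics = training_process.next(state, ...)

# Build a fresh Keras model with the same architecture and copy the
# global (server) weights from the training state into it.
model_for_inference = create_keras_model()
state.model.assign_weights_to(model_for_inference)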
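On TFF releases older than 0.16.0, state.model may be a plain structure with only the trainable and non_trainable fields (which matches the error in the question), so assign_weights_to is not available. A rough workaround is to copy those two lists into the Keras model's variables yourself. This is only a sketch, not a TFF API: the helper name assign_state_weights is hypothetical, and it assumes each field is a flat sequence of tensors ordered exactly like keras_model.trainable_weights and keras_model.non_trainable_weights, which is what you would expect when model_fn wraps the same Keras architecture, but it is worth checking.

```python
def assign_state_weights(keras_model, trainable, non_trainable):
    """Copy server weights into a Keras model (hypothetical helper for older TFF).

    Assumes `trainable` / `non_trainable` are flat sequences ordered exactly like
    `keras_model.trainable_weights` / `keras_model.non_trainable_weights`.
    """
    pairs = [
        (keras_model.trainable_weights, list(trainable)),
        (keras_model.non_trainable_weights, list(non_trainable)),
    ]
    for variables, values in pairs:
        # Refuse to guess if the two structures do not line up one-to-one.
        if len(variables) != len(values):
            raise ValueError(
                "expected %d tensors, got %d" % (len(variables), len(values)))
        for variable, value in zip(variables, values):
            # tf.Variable.assign copies the value in place (shapes must match).
            variable.assign(value)
```

With the names used here, the call would be assign_state_weights(model_for_inference, state.model.trainable, state.model.non_trainable), after building model_for_inference with the same architecture used inside model_fn.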

Once the weights from state have been assigned back into the Keras model, the code can use the standard Keras APIs, such as tf.keras.Model.predict_on_batch:

predictions = model_for_inference.predict_on_batch(batch)
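Note that predict_on_batch returns the raw model outputs (for example probabilities), not class labels like predict_classes in the Keras example from the question. predict_classes was a Sequential-only method that has since been removed from tf.keras; it took the argmax over the last axis for multi-class outputs and thresholded a single-unit output at 0.5. A minimal NumPy sketch of that conversion, assuming predictions is the array returned above (the helper name to_class_labels is hypothetical):

```python
import numpy as np

def to_class_labels(predictions):
    """Turn model outputs into class labels, mimicking Keras' old predict_classes."""
    proba = np.asarray(predictions)
    if proba.shape[-1] > 1:
        # Multi-class (e.g. softmax) output: pick the highest-scoring class.
        return proba.argmax(axis=-1)
    # Single-unit (e.g. sigmoid) output: threshold at 0.5, keeping the trailing axis.
    return (proba > 0.5).astype("int32")
```

For example, labels = to_class_labels(predictions). If the model outputs logits instead of probabilities, the argmax branch is unaffected, but a single sigmoid logit should be compared with 0 rather than 0.5.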

Upvotes: 4
