Ahmed Sayed

Change of the dataset type in the execution stack

The problem is that the dataset changes from one type to another at different points of the execution stack. For example, if I add a new dataset class with extra member attributes of interest (inheriting from one of the classes in ops.data.dataset_ops, such as UnaryDataset), then at a later point in execution (the client_update function) the dataset has been converted to the _VariantDataset type, and any added attributes are lost. So the question is: how can the member attributes of the newly defined dataset class be retained throughout execution? Below is the EMNIST example, where the type changes from ParallelMapDataset to _VariantDataset.

In the client_datasets function of training_utils.py (line 194), I modified the code to print the type of each dataset as follows:

  def client_datasets(round_num):
    sampled_clients = sample_clients_fn(round_num)
    sampled_client_datasets = []
    for client in sampled_clients:
      # Build the client's dataset once, keep it, and print its Python type.
      d = train_dataset.create_tf_dataset_for_client(client)
      sampled_client_datasets.append(d)
      tf.print('CLIENT DATASETS: ', d, type(d))
    return sampled_client_datasets

The output (the same line is printed for each of the 10 sampled clients) is:

CLIENT DATASETS:  <ParallelMapDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>

Then, in the tf.function client_update, which is invoked by the clients in fed_avg_schedule.py (line 178), the dataset has a different type:

@tf.function
def client_update(model,
                  dataset,
                  initial_weights,
                  client_optimizer,
                  client_weight_fn=None):
  """Updates client model.

  Args:
    model: A `tff.learning.Model`.
    dataset: A `tf.data.Dataset`.
    initial_weights: A `tff.learning.Model.weights` from server.
    client_optimizer: A `tf.keras.optimizers.Optimizer` object.
    client_weight_fn: Optional function that takes the output of
      `model.report_local_outputs` and returns a tensor that provides the
      weight in the federated average of model deltas. If not provided, the
      default is the total number of examples processed on device.

  Returns:
    A `ClientOutput`.
  """

  tf.print('CLIENT UPDATE: ', dataset, type(dataset))
  ...

The output (the same line is printed for each of the 10 clients) is:

CLIENT UPDATE:  <_VariantDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int32)> <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>

I might be wrong, but after some tracing I found that at some point DatasetSpec._to_components(self, value) is called, which performs the conversion:

  def _to_components(self, value):
    return value._variant_tensor  # pylint: disable=protected-access
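The same effect can be reproduced outside TFF. Below is a minimal sketch, assuming TensorFlow 2.x, in which a plain Python attribute is attached to a dataset object (the attribute name `client_name` and the function `inspect_dataset` are hypothetical). When the dataset is passed into a `tf.function`, it is rebuilt from its components on the other side, so the attribute is gone; the `seen` dictionary is filled in by Python code that runs only while the function is being traced. Exact class names may differ between TensorFlow versions.

```python
import tensorflow as tf

# A small dataset carrying an extra, purely Python-level attribute (hypothetical name).
ds = tf.data.Dataset.range(3).map(lambda x: x * 2)
ds.client_name = "client_0"

seen = {}


@tf.function
def inspect_dataset(dataset):
    # These Python statements execute once, during tracing, on the rebuilt dataset.
    seen["type"] = type(dataset).__name__
    seen["has_attr"] = hasattr(dataset, "client_name")
    # Sum the elements so the function also does some real graph work.
    return dataset.reduce(tf.constant(0, tf.int64), lambda acc, x: acc + x)


total = inspect_dataset(ds)
print(type(ds).__name__, seen["type"], seen["has_attr"])
print(int(total))
```

Because only the dataset's variant tensor and element spec cross the boundary, any per-client information that must survive has to travel as a separate, typed argument, which is the approach the answers below take.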

EDIT - following the suggested answer

Below are the changes I introduced to the simple_fedavg example after pulling the latest version of the federated repo.

First, I added/modified the lines below in build_fed_avg_process of simple_fedavg_tff.py:

  server_message_type = server_message_fn.type_signature.result
  tf_dataset_type = tff.SequenceType(dummy_model.input_spec)
  meta_data_type = tff.SequenceType(tf.string)

  @tff.tf_computation(tf_dataset_type, meta_data_type, server_message_type)
  def client_update_fn(tf_dataset, meta_data, server_message):
    model = model_fn()
    client_optimizer = client_optimizer_fn()
    return client_update(model, tf_dataset, meta_data, server_message, client_optimizer)

  @tff.tf_computation((tf_dataset_type, meta_data_type))
  def extract_data_metadata_fn(tf_dataset_metadata_tuple):
    x, y = tf_dataset_metadata_tuple
    return x, y

  federated_server_state_type = tff.FederatedType(server_state_type, tff.SERVER)
  federated_dataset_type = tff.FederatedType((tf_dataset_type, meta_data_type), tff.CLIENTS)

  @tff.federated_computation(federated_server_state_type,
                             federated_dataset_type)
  def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and `tf.Tensor` of average loss.
    """
    server_message = tff.federated_map(server_message_fn, server_state)
    server_message_at_client = tff.federated_broadcast(server_message)

    data_set, meta_data = tff.federated_map(extract_data_metadata_fn, federated_dataset)

    #client_outputs = tff.federated_map(client_update_fn, (federated_dataset, server_message_at_client))
    client_outputs = tff.federated_map(client_update_fn, (data_set, meta_data, server_message_at_client))

In simple_fedavg_tf.py, I added the following line to print meta_data:

@tf.function
def client_update(model, dataset, meta_data, server_message, client_optimizer):
  """Performs client local training of `model` on `dataset`.

  Args:
    model: A `tff.learning.Model`.
    dataset: A `tf.data.Dataset`.
    meta_data: The metadata passed alongside `dataset`.
    server_message: A `BroadcastMessage` from server.
    client_optimizer: A `tf.keras.optimizers.Optimizer`.

  Returns:
    A `ClientOutput`.
  """
  tf.print(meta_data)

  model_weights = model.weights
  initial_weights = server_message.model_weights
  client_ids = server_message.client_ids
  tff.utils.assign(model_weights, initial_weights)

In the main file emnist_simple_fedavg.py, I modified the following lines of the training loop in the main function:

meta_data = ['a','b','c','d']
server_state, train_metrics = iterative_process.next(server_state, (sampled_train_data, meta_data))

This did not work, and I get the following error:

  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.py", line 176, in <module>
    app.run(main)
  File "/usr/local/lib/python3.6/dist-packages/absl/app.py", line 299, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.6/dist-packages/absl/app.py", line 250, in _run_main
    sys.exit(main(argv))
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.py", line 166, in main
    server_state, train_metrics = iterative_process.next(server_state, (sampled_train_data, sampled_clients.tolist()))
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/utils/function_utils.py", line 563, in __call__
    return context.invoke(self, arg)
  File "/usr/local/lib/python3.6/dist-packages/retrying.py", line 49, in wrapped_f
    return Retrying(*dargs, **dkw).call(f, *args, **kw)
  File "/usr/local/lib/python3.6/dist-packages/retrying.py", line 206, in call
    return attempt.get(self._wrap_exception)
  File "/usr/local/lib/python3.6/dist-packages/retrying.py", line 247, in get
    six.reraise(self.value[0], self.value[1], self.value[2])
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/six/__init__.py", line 693, in reraise
    raise value
  File "/usr/local/lib/python3.6/dist-packages/retrying.py", line 200, in call
    attempt = Attempt(fn(*args, **kwargs), attempt_number, False)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/execution_context.py", line 215, in invoke
    _ingest(executor, unwrapped_arg, arg.type_signature)))
  File "/usr/lib/python3.6/asyncio/base_events.py", line 484, in run_until_complete
    return future.result()
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 388, in _wrapped
    return await coro
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/execution_context.py", line 99, in _ingest
    ingested = await asyncio.gather(*ingested)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/execution_context.py", line 104, in _ingest
    return await executor.create_value(val, type_spec)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
    result = await fn(*fn_args, **fn_kwargs)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/reference_resolving_executor.py", line 289, in create_value
    value, type_spec))
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/caching_executor.py", line 245, in create_value
    await cached_value.target_future
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
    result = await fn(*fn_args, **fn_kwargs)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py", line 111, in create_value
    self._target_executor.create_value(value, type_spec))
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py", line 105, in _delegate
    result_value = await _delegate_with_trace_ctx(coro, self._event_loop)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 388, in _wrapped
    return await coro
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
    result = await fn(*fn_args, **fn_kwargs)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/federating_executor.py", line 383, in create_value
    return await self._strategy.compute_federated_value(value, type_spec)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/federated_resolving_strategy.py", line 275, in compute_federated_value
    for v, c in zip(value, children)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
    result = await fn(*fn_args, **fn_kwargs)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/reference_resolving_executor.py", line 282, in create_value
    *[self.create_value(val, t) for (_, val), t in zip(v_el, type_spec)])
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
    result = await fn(*fn_args, **fn_kwargs)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/reference_resolving_executor.py", line 289, in create_value
    value, type_spec))
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/caching_executor.py", line 245, in create_value
    await cached_value.target_future
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
    result = await fn(*fn_args, **fn_kwargs)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py", line 111, in create_value
    self._target_executor.create_value(value, type_spec))
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py", line 105, in _delegate
    result_value = await _delegate_with_trace_ctx(coro, self._event_loop)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 388, in _wrapped
    return await coro
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py", line 200, in async_trace
    result = await fn(*fn_args, **fn_kwargs)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/eager_tf_executor.py", line 464, in create_value
    return EagerValue(value, self._tf_function_cache, type_spec, self._device)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/eager_tf_executor.py", line 367, in __init__
    type_spec, device)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/core/impl/executors/eager_tf_executor.py", line 335, in to_representation_for_type
    type_conversions.TF_DATASET_REPRESENTATION_TYPES)
  File "/root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/py_typecheck.py", line 41, in check_type
    type_string(type_spec), type_string(type(target))))
TypeError: Expected tensorflow.python.data.ops.dataset_ops.DatasetV2 or tensorflow.python.data.ops.dataset_ops.DatasetV1, found str.
E0721 23:53:29.388700 139706363909952 base_events.py:1285] Task was destroyed but it is pending!
task: <Task pending coro=<trace.<locals>.async_trace() running at /root/.cache/bazel/_bazel_root/13f956c768d751b1bc658674921e5be9/execroot/org_tensorflow_federated/bazel-out/k8-opt/bin/tensorflow_federated/python/examples/simple_fedavg/emnist_fedavg_main.runfiles/org_tensorflow_federated/tensorflow_federated/python/common_libs/tracing.py:200> wait_for=<Future pending cb=[_chain_future.<locals>._call_check_cancel() at /usr/lib/python3.6/asyncio/futures.py:403, <TaskWakeupMethWrapper object at 0x7f0f7c07eca8>()]> cb=[<TaskWakeupMethWrapper object at 0x7f0f7c07e648>()]>

Upvotes: 4

Answers (2)

Jakub Konecny

From the updated info and error log, I think the issue is in this part:
iterative_process.next(server_state, (sampled_train_data, meta_data))

I suspect you need the second argument to next to be an iterable of, roughly, (sampled_train_data_element, meta_data_element) tuples: one element per sampled client.

This may be achieved by changing the call to
iterative_process.next(server_state, zip(sampled_train_data, meta_data))
or, if that does not work, perhaps
iterative_process.next(server_state, list(zip(sampled_train_data, meta_data)))
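The difference between the two call shapes can be seen without TFF at all. This is a minimal pure-Python sketch that uses hypothetical string placeholders instead of real `tf.data.Dataset` objects: passing `(datasets, meta_data)` produces a single pair of lists, whereas `zip` produces one `(dataset, metadata)` pair per client, which is the shape a `<Dataset, Metadata>@CLIENTS` argument describes.

```python
# Hypothetical stand-ins for the per-client datasets and their metadata.
sampled_train_data = ["dataset_0", "dataset_1", "dataset_2", "dataset_3"]
meta_data = ["a", "b", "c", "d"]

# Original call shape: one 2-tuple whose first element is the list of all datasets.
as_tuple = (sampled_train_data, meta_data)

# Suggested call shape: one (dataset, metadata) tuple per sampled client.
per_client = list(zip(sampled_train_data, meta_data))

print(len(as_tuple))    # 2, regardless of the number of clients
print(len(per_client))  # 4, one entry per client
print(per_client[0])    # ('dataset_0', 'a')
```

In Python 3, `zip` returns a one-shot iterator rather than a list, which is one plausible reason the `list(...)` variant might be needed if the argument is inspected or consumed more than once.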


Also, assuming you want meta_data to be a single string per client, meta_data_type should be changed to tff.to_type(tf.string). tff.SequenceType is meant for representing sequences in general, such as datasets.

Upvotes: 1

Zachary Garrett

The new dataset Python class would need to support serialization. This is necessary because TensorFlow Federated is designed to run on machines that are not necessarily the same as the machine that wrote the computation (e.g. smartphones in the case of cross-device federated learning). These machines may not be running Python, and hence would not understand the new subclass, so the serialization layer would need to be updated. However, this is pretty low-level, and there may be alternative ways to achieve the desired goal.

Going out on a limb: if the goal is to provide metadata along with each client's dataset, it may be easier to change the signature of the iterative process returned by fed_avg_schedule.build_fed_avg_process so that it accepts a tuple of (dataset, metadata structure) for each client.

Currently the signature of the next computation is (in the TFF type shorthand introduced in Custom Federated Algorithms, Part 1: Introduction to the Federated Core):

(<ServerState@SERVER, Datasets@CLIENTS> -> <ServerState@SERVER, Metrics@SERVER>)

(See the definition of ServerState; Datasets and Metrics are determined by the model and dataset.)

Instead, we might want a signature that looks like:

(<ServerState@SERVER, <Datasets, Metadata>@CLIENTS> -> <ServerState@SERVER, Metrics@SERVER>)

To achieve this, we can perform the following:

  1. Update the types of the arguments of run_one_round to be a tuple of tf_dataset_type and the metadata structure.
  2. Plumb the new argument through the tff.federated_map call.
  3. Add a new argument to client_update_fn.
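As a rough, framework-free illustration of what these steps amount to, the sketch below mimics the pointwise semantics of `tff.federated_map` with a plain list comprehension. It is an analogy only, not TFF code: `fake_federated_map`, this `client_update_fn`, and the dictionary it returns are all hypothetical, and the server message is simply repeated for every client to stand in for `tff.federated_broadcast`.

```python
from typing import Any, Callable, List, Sequence, Tuple


def fake_federated_map(fn: Callable[..., Any], per_client_args: Sequence[Tuple]) -> List[Any]:
    """Applies fn independently to each client's argument tuple (pointwise, like federated_map)."""
    return [fn(*args) for args in per_client_args]


def client_update_fn(dataset, metadata, server_message):
    # Step 3 (hypothetical): the client function takes the metadata as an extra argument.
    return {"metadata": metadata, "num_examples": len(dataset), "round": server_message}


def run_one_round(server_message, federated_dataset):
    # Step 1: each client's value is a (dataset, metadata) tuple.
    # Step 2: plumb the metadata through the map alongside the "broadcast" server message.
    per_client_args = [(ds, md, server_message) for ds, md in federated_dataset]
    return fake_federated_map(client_update_fn, per_client_args)


outputs = run_one_round(7, [([1, 2, 3], "a"), ([4, 5], "b")])
print(outputs[1])  # {'metadata': 'b', 'num_examples': 2, 'round': 7}
```

The point of the analogy is that the metadata is never attached to the dataset object itself; it travels as its own per-client value, which is what keeps it serializable in the real TFF version.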

Upvotes: 2
