Reputation: 481
Judging by the examples on the TensorFlow website: https://github.com/tensorflow/docs/blob/r1.15/site/en/guide/distribute_strategy.ipynb, there seem to be no resources on how to adapt your code to use a distribute strategy. My original code manipulates Tensors, for example tf.expand_dims(x, axis=1). However, when a distribute strategy is used I get the above-mentioned error, because expand_dims() cannot operate on a PerReplica object. More details of the error below:
Contents: PerReplica:{ 0 /replica:0/task:0/device:GPU:0: Tensor("IteratorGetNext:0", shape=(?, 2, 3), dtype=float32, device=/replica:0/task:0/device:GPU:0), 1 /replica:0/task:0/device:GPU:1: Tensor("IteratorGetNext_1:0", shape=(?, 2, 3), dtype=float32, device=/replica:0/task:0/device:GPU:1) }
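For context, here is a simplified, eager-style sketch of the kind of pattern that triggers this (the dataset, shapes, and names here are placeholders, not my actual pipeline):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# toy dataset standing in for the real input pipeline
dataset = tf.data.Dataset.from_tensor_slices(
    tf.zeros([8, 2, 3], tf.float32)).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

for x in dist_dataset:
    # with more than one replica, x is a PerReplica object, not a Tensor,
    # so a plain tensor op like this fails on it
    y = tf.expand_dims(x, axis=1)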
Does anyone have any ideas for a solution to this?
Upvotes: 1
Views: 1490
Reputation: 113
A PerReplica object is usually returned by running strategy.experimental_run_v2/run(...). You can think of it as a special dict that wraps together the pairs {i-th GPU name: tensors returned by the i-th GPU} for all of your available devices. It looks like a dict but is not a real dict; the PerReplica class defines additional methods/properties for many use cases, e.g. reducing tensors across devices in a distributed context. For your case:
x = strategy.experimental_run_v2(...)
if strategy.num_replicas_in_sync > 1:
    # x is a PerReplica object when running on multiple devices
    tensors_list = x.values  # a list [x_from_dev_a, x_from_dev_b, ...]
    y = tf.concat(tensors_list, axis=0)  # axis=0 is the batch dimension
else:
    y = x  # x is a plain tensor on a single device
tf.expand_dims(y, axis=1)
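For completeness, a rough end-to-end sketch of how this can fit together (the step_fn, dataset, and shapes are placeholders, and this is eager-style code): strategy.experimental_local_results is an alternative that returns a tuple of per-replica tensors in both the single- and multi-device cases, so no num_replicas_in_sync branch is needed.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def step_fn(x):
    # per-replica computation; just pass the batch through for illustration
    return x

dataset = tf.data.Dataset.from_tensor_slices(
    tf.zeros([8, 2, 3], tf.float32)).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

for batch in dist_dataset:
    x = strategy.experimental_run_v2(step_fn, args=(batch,))
    # experimental_local_results always returns a tuple of local tensors,
    # whether there is one replica or many
    y = tf.concat(strategy.experimental_local_results(x), axis=0)
    y = tf.expand_dims(y, axis=1)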
Upvotes: 2