I'm building a GAN in TensorFlow for image deblurring; it's an implementation of DeblurGANv2. I set up the GAN so that it has two inputs: a batch of blurred images and a batch of sharp images. Following these lines, I designed the input to be a Python dictionary with two keys, `['sharp', 'blur']`, each holding a tensor of shape `[batch_size, 512, 512, 3]`. This makes it easy to feed the blurred image batch to the generator, and then feed the generator's output and the sharp image batch to the discriminator.
Based on these requirements, I created a `tf.data.Dataset` that outputs exactly that: a dict containing the two tensors, each with its own batch dimension. This fits perfectly with my GAN implementation, and everything works fine and smoothly.
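For reference, here is a stripped-down sketch of this kind of pipeline. It assumes the decoded images are already available as arrays. The zero-filled arrays and the `BATCH_SIZE` value are placeholders and not part of the real pipeline, which reads image files:

```python
import numpy as np
import tensorflow as tf

BATCH_SIZE = 4  # placeholder value
IMG_SIZE = 512

# Placeholder data: 8 blurred/sharp pairs of zeros standing in for decoded images.
blur = np.zeros((8, IMG_SIZE, IMG_SIZE, 3), dtype=np.float32)
sharp = np.zeros((8, IMG_SIZE, IMG_SIZE, 3), dtype=np.float32)

# Slicing a dict yields one {'sharp': ..., 'blur': ...} element per image pair;
# batch() then adds the batch dimension to every tensor inside the dict.
dataset = tf.data.Dataset.from_tensor_slices({"sharp": sharp, "blur": blur}).batch(BATCH_SIZE)

batch = next(iter(dataset))
blur_batch = batch["blur"]    # would be fed to the generator
sharp_batch = batch["sharp"]  # would go to the discriminator with the generator output
```

The batch dimension lives inside each dict value (`[batch_size, 512, 512, 3]`), not on the dict itself.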
Keep in mind that my input is not a tensor but a Python dict, which has no batch dimension of its own. This will be relevant when I explain my problem later.
Recently, I decided to add support for distributed training using TensorFlow distribution strategies. This TensorFlow feature lets you distribute training over multiple devices, and even over multiple machines. Some of the implementations, for example `MirroredStrategy`, take the input tensor, split it into equal parts, and feed each slice to a different device. So if you have a batch size of 16 and 4 GPUs, each GPU ends up with a local batch of 4 data points. After that, some magic aggregates the results, along with other steps that are not relevant to my problem.
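The splitting described above can be observed with a small sketch. It uses dummy 32x32 images instead of real data, and it assumes the default `MirroredStrategy()` device selection. On a machine without GPUs, this falls back to a single CPU replica, so the "split" is just one local batch of 16:

```python
import numpy as np
import tensorflow as tf

GLOBAL_BATCH_SIZE = 16

# With no arguments, MirroredStrategy uses every visible GPU (or the CPU if none).
strategy = tf.distribute.MirroredStrategy()

# 64 dummy 32x32 images -> 4 global batches of 16.
images = np.zeros((64, 32, 32, 3), dtype=np.float32)
dataset = tf.data.Dataset.from_tensor_slices(images).batch(GLOBAL_BATCH_SIZE)

# Each global batch is split along its first (batch) dimension across replicas.
dist_dataset = strategy.experimental_distribute_dataset(dataset)

first = next(iter(dist_dataset))
local_batches = strategy.experimental_local_results(first)
local_sizes = [int(t.shape[0]) for t in local_batches]

# With 4 GPUs this would be four local batches of 4 images each.
print(strategy.num_replicas_in_sync, local_sizes)
```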
As you may have noticed, it is critical for distribution strategies to have a tensor as input, or at least some kind of input with an outer batch dimension. What I have is a Python dict, with the batch dimension inside the dictionary's tensor values. This is a huge problem: my current implementation is not compatible with distributed training.
I was looking for workarounds, but I can't quite wrap my head around this. Maybe I could just make the input one huge tensor of shape `[batch_size, 2, 512, 512, 3]` and slice it? I'm not sure; this idea just came to mind. Either way, I find it very ambiguous: I can't tell the two inputs apart, at least not with the clarity of the dictionary keys. Edit: The problem with this solution is that it makes my dataset transformations very expensive, which slows down the dataset throughput a lot. Since this is an image loading pipeline, that is a major concern.
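Here is a rough sketch of that stacked-tensor idea, for comparison. The `BLUR`/`SHARP` index constants are hypothetical names that recover some of the readability of the dict keys. The random tensors stand in for real image batches:

```python
import tensorflow as tf

# Hypothetical index constants; only a naming convention keeps the two inputs apart.
BLUR, SHARP = 0, 1

# Dummy batch of 4 blurred and 4 sharp 512x512 RGB images.
blur = tf.random.uniform((4, 512, 512, 3))
sharp = tf.random.uniform((4, 512, 512, 3))

# Stack along a new axis 1: shape [batch_size, 2, 512, 512, 3].
stacked = tf.stack([blur, sharp], axis=1)

# Inside the training step, slice the pair back apart.
blur_back = stacked[:, BLUR]    # shape [batch_size, 512, 512, 3]
sharp_back = stacked[:, SHARP]  # shape [batch_size, 512, 512, 3]
```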
Maybe my explanation of how distribution strategies work is not the most rigorous one; if I'm missing something, please feel free to correct me.
PS: This is not a bug report or a code error, mostly a "system design" question. I hope that's allowed here.
Instead of using a dictionary as the input to the GAN, you can try mapping a function that returns both images as a tuple, in the following way:
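```python
import glob
import tensorflow as tf

batch_size = 16  # global batch size (example value)
mirrored_strategy = tf.distribute.MirroredStrategy()

def load_image(fileA, fileB):
    # fileA is a blurred image, fileB its sharp counterpart
    imageA = tf.io.read_file(fileA)
    imageA = tf.image.decode_jpeg(imageA, channels=3)
    imageB = tf.io.read_file(fileB)
    imageB = tf.image.decode_jpeg(imageB, channels=3)
    return imageA, imageB

# sorted() so the i-th blurred file lines up with the i-th sharp file
# (assumes matching file names in both folders)
trainA = sorted(glob.glob('blur/*.jpg'))
trainB = sorted(glob.glob('sharp/*.jpg'))
train_dataset = tf.data.Dataset.from_tensor_slices((trainA, trainB))
train_dataset = train_dataset.map(load_image).batch(batch_size)

# for mirrored strategy
dist_dataset = mirrored_strategy.experimental_distribute_dataset(train_dataset)
```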
You can iterate over the dataset and update the network by passing both images.
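Below is a minimal sketch of such a loop. It rests on several assumptions:

- The tiny `generator`/`discriminator` models are hypothetical stand-ins, not the DeblurGANv2 architectures.
- Random 32x32 arrays replace the JPEG files.
- The losses are a plain non-saturating GAN loss plus an L1 content term, not DeblurGANv2's actual loss mix.

It follows the usual custom-training-loop pattern for `tf.distribute`: build models and optimizers inside `strategy.scope()`, run the per-replica step with `strategy.run`, and average losses over the global batch with `tf.nn.compute_average_loss`.

```python
import numpy as np
import tensorflow as tf

GLOBAL_BATCH_SIZE = 8
IMG = 32  # tiny images keep the sketch fast; the question uses 512x512

strategy = tf.distribute.MirroredStrategy()

# Variables must be created under the strategy scope so they are mirrored.
with strategy.scope():
    # Hypothetical toy models standing in for the real generator/discriminator.
    generator = tf.keras.Sequential([
        tf.keras.Input(shape=(IMG, IMG, 3)),
        tf.keras.layers.Conv2D(3, 3, padding="same"),
    ])
    discriminator = tf.keras.Sequential([
        tf.keras.Input(shape=(IMG, IMG, 3)),
        tf.keras.layers.Conv2D(8, 3, strides=2, activation="relu"),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(1),  # real/fake logit
    ])
    g_opt = tf.keras.optimizers.Adam(1e-4)
    d_opt = tf.keras.optimizers.Adam(1e-4)


def bce(labels, logits):
    # Per-example binary cross-entropy on [batch, 1] logits -> shape [batch].
    return tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)[:, 0]


def average(per_example_loss):
    # Divide by the GLOBAL batch size so summing over replicas gives the mean.
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)


def train_step(blur, sharp):
    # Runs on each replica with its local slice of both image batches.
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fake = generator(blur, training=True)
        real_logits = discriminator(sharp, training=True)
        fake_logits = discriminator(fake, training=True)
        d_loss = average(bce(tf.ones_like(real_logits), real_logits)
                         + bce(tf.zeros_like(fake_logits), fake_logits))
        l1 = tf.reduce_mean(tf.abs(fake - sharp), axis=[1, 2, 3])
        g_loss = average(bce(tf.ones_like(fake_logits), fake_logits) + l1)
    g_grads = g_tape.gradient(g_loss, generator.trainable_variables)
    d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
    g_opt.apply_gradients(zip(g_grads, generator.trainable_variables))
    d_opt.apply_gradients(zip(d_grads, discriminator.trainable_variables))
    return g_loss, d_loss


@tf.function
def distributed_step(blur, sharp):
    g_loss, d_loss = strategy.run(train_step, args=(blur, sharp))
    # Sum the per-replica partial averages into the global mean loss.
    return (strategy.reduce(tf.distribute.ReduceOp.SUM, g_loss, axis=None),
            strategy.reduce(tf.distribute.ReduceOp.SUM, d_loss, axis=None))


# Random stand-in data shaped like the (blur, sharp) tuples produced by load_image.
blur_imgs = np.random.rand(32, IMG, IMG, 3).astype(np.float32)
sharp_imgs = np.random.rand(32, IMG, IMG, 3).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices((blur_imgs, sharp_imgs)).batch(GLOBAL_BATCH_SIZE)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# Each element unpacks into the blurred and sharp (per-replica) batches.
for blur, sharp in dist_dataset:
    g_loss, d_loss = distributed_step(blur, sharp)
```

Because each element is a tuple, every component is split along its own batch dimension, and the step function receives the blurred and sharp slices as separate arguments.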
I hope this helps!