Reputation: 1753
I am working on my own image augmentation function in TensorFlow 2.0. More specifically, I wrote a function that returns a randomly zoomed image. Its input is image_batch, a multidimensional numpy array with shape:
(no. images, height, width, channel)
which in my specific case is:
(31, 300, 300, 3)
This is the code:
def random_zoom(batch, zoom=0.6):
    '''
    Performs random zoom of a batch of images.
    It starts by zero padding the images one by one, then randomly selects
    a subsample of the padded image of the same size of the original.
    The result is a random zoom
    '''
    # Import from TensorFlow 2.0
    from tensorflow.image import resize_with_pad, random_crop
    # save original image height and width
    height = batch.shape[1]
    width = batch.shape[2]
    # Iterate over every image in the batch
    for i in range(len(batch)):
        # zero pad the image, adding 25-percent to each side
        image_distortion = resize_with_pad(batch[i, :,:,:], int(height*(1+zoom)), int(width*(1+zoom)))
        # take a subset of the image randomly
        image_distortion = random_crop(image_distortion, size=[height, width, 3], seed = 1+i*2)
        # put the distorted image back in the batch
        batch[i, :,:,:] = image_distortion.numpy()
    return batch
I can then call the function:
new_batch = random_zoom(image_batch)
At this point, something strange happens: new_batch is just as I expected and I'm satisfied with it... but now image_batch, the original input object, has also been changed! I don't want that, and I don't understand why it happens.
Upvotes: 1
Views: 162
Reputation: 19885
Well, this line, batch[i, :,:,:] = image_distortion.numpy(), modifies the array that is passed as an argument.
Your confusion likely stems from familiarity with another language such as C++, where objects passed by value are implicitly copied.
In Python, arguments are passed as references to objects (what you might call passing by object reference). No copy is made unless you make one explicitly, and your function returns the very array it was given. Therefore, it's not that both new_batch and image_batch are modified separately; they are two names pointing to the same object, which was changed in place.
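To see the aliasing in isolation, here is a minimal sketch that uses only NumPy. The helper fill_with_ones and the small array shape are hypothetical stand-ins for random_zoom and the real batch; the point is only that writing into the argument changes the caller's array.

```python
import numpy as np

def fill_with_ones(batch):
    # Hypothetical stand-in for random_zoom: it writes into the array
    # it receives, just like batch[i, :,:,:] = ... does in the question.
    batch[:] = 1.0
    return batch

image_batch = np.zeros((2, 4, 4, 3), dtype=np.float32)
new_batch = fill_with_ones(image_batch)

# Both names refer to the same array object, so the "original" changed too.
print(new_batch is image_batch)
print(image_batch.max())
```

Checking new_batch is image_batch in your own session is a quick way to confirm that no copy was made.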
Accordingly, you might want to do something like batch = batch.copy() at the start of your function, so that the loop writes into a copy and the caller's array stays untouched.
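A minimal sketch of that fix, again with NumPy only. The helper fill_with_ones_copy is a hypothetical stand-in for your function, not the real random_zoom; the only line that matters is the copy at the top.

```python
import numpy as np

def fill_with_ones_copy(batch):
    # Work on a private copy so the caller's array is never written to.
    batch = batch.copy()
    # Hypothetical stand-in for the per-image assignment in random_zoom.
    batch[:] = 1.0
    return batch

image_batch = np.zeros((2, 4, 4, 3), dtype=np.float32)
new_batch = fill_with_ones_copy(image_batch)

# The input keeps its original values; only the returned copy was modified.
print(new_batch is image_batch)
print(image_batch.max(), new_batch.min())
```

Rebinding the name batch inside the function only changes what the local name points to, so the caller's image_batch is unaffected. The cost is one extra array the size of the batch in memory.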
Upvotes: 2