Reputation: 338
This code generates an error that I don't understand. Can someone explain it to me, please?
import numpy as np
import tensorflow as tf

def augment(img):
    data_augmentation = tf.keras.Sequential([
        tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
        tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
    ])
    img = tf.expand_dims(img, 0)
    return data_augmentation(img)

# generate 10 images 8x8 RGB
data = np.random.randint(0, 255, size=(10, 8, 8, 3))
dataset = tf.data.Dataset.from_tensor_slices(data)
# and augment... -> bug
dataset = dataset.map(augment)
# note that the following works
for im in dataset:
    augment(im)
and I get:
ValueError: Tensor-typed variable initializers must either be wrapped in an init_scope or callable (e.g., `tf.Variable(lambda : tf.truncated_normal([10, 40]))`) when building functions. Please file a feature request if this restriction inconveniences you.
I tried on Google Colab and with TensorFlow 2.4.1 on my computer. Note that with the resize or rescale layers it works (as in this example https://www.tensorflow.org/tutorials/images/data_augmentation, but there they don't try RandomRotation with dataset.map; they only use it in a loop).
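For contrast, here is a minimal sketch of the case described as working: a stateless layer (Rescaling) constructed inside the function passed to dataset.map. It assumes TF 2.6 or later, where the layer lives at tf.keras.layers.Rescaling; on 2.4.1 the same layer is tf.keras.layers.experimental.preprocessing.Rescaling. Presumably the difference is that Rescaling holds no variables, whereas the random layers keep random-generator state, which matches the "variable initializers ... when building functions" wording of the error.

```python
import numpy as np
import tensorflow as tf

def rescale(img):
    # Rescaling has no random state, so constructing it here, while
    # dataset.map traces this function, creates no variables.
    rescaling = tf.keras.layers.Rescaling(1.0 / 255)
    return rescaling(tf.cast(img, tf.float32))

# generate 10 images 8x8 RGB, as in the question
data = np.random.randint(0, 255, size=(10, 8, 8, 3))
dataset = tf.data.Dataset.from_tensor_slices(data).map(rescale)

first = next(iter(dataset))  # one rescaled 8x8x3 image
```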
Upvotes: 0
Views: 9900
Reputation: 338
Here is the answer...
import numpy as np
import tensorflow as tf

# build the augmentation model once, outside the mapped function
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
    tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
])
# generate 10 images 8x8 RGB
data = np.random.randint(0, 255, size=(10, 8, 8, 3))
dataset = tf.data.Dataset.from_tensor_slices(data).batch(5)
# and augment... -> now it works
dataset = dataset.map(lambda x: data_augmentation(x))
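Note that the lambda itself is not what fixes it. The difference from the question is that data_augmentation is now built once, at module level, instead of inside the function passed to dataset.map. A regular def that only calls this prebuilt data_augmentation works just as well; what fails is constructing the random layers inside the mapped function, because that creates variables while the function is being traced.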
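A minimal sketch using a named def instead of a lambda; the only requirement is that the Sequential model is created outside the mapped function. It assumes TF 2.6 or later (tf.keras.layers.RandomFlip / RandomRotation; on 2.4.1 use the experimental.preprocessing paths), and casts the images to float32, which is a choice made here rather than something the question requires.

```python
import numpy as np
import tensorflow as tf

# Built ONCE, eagerly, so the layers' internal random state exists
# before dataset.map traces anything.
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.2),
])

def augment(batch):
    # Only *calls* the prebuilt model; nothing is constructed here.
    # training=True makes explicit that the random layers should transform
    # the inputs (they only do so in training mode).
    return data_augmentation(batch, training=True)

# generate 10 images 8x8 RGB, batched in groups of 5
data = np.random.randint(0, 255, size=(10, 8, 8, 3)).astype("float32")
dataset = tf.data.Dataset.from_tensor_slices(data).batch(5).map(augment)

shapes = [tuple(b.shape) for b in dataset]  # two augmented batches
```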
Upvotes: 4
Reputation: 11333
I think you've confused the purpose of tf.keras.layers.experimental.preprocessing.*. They are meant to be used in conjunction with your model, so that data augmentation is streamlined with the model itself.
In other words, these layers are part of your model, not your data pipeline (as when you use them with dataset.map, for example). If you'd like to use these layers with a tf.data.Dataset, here's a working example.
import tensorflow as tf
import numpy as np

def augment(img):
    data_augmentation = tf.keras.Sequential([
        tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
        tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
    ])
    return data_augmentation(img)

# generate 10 images 8x8 RGB
data = np.random.randint(0, 255, size=(10, 8, 8, 3))
dataset = tf.data.Dataset.from_tensor_slices(data).batch(5)
# iterate eagerly and augment each batch of 5 images
for d in dataset:
    aug_d = augment(d)
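To illustrate the point that these layers belong in the model, here is a minimal sketch that puts them at the front of a small classifier. It assumes TF 2.6 or later (layers under tf.keras.layers); the Flatten/Dense head, the 2 classes and the random labels are hypothetical placeholders for illustration only. During fit() the random layers augment each batch; at inference they pass inputs through unchanged.

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(8, 8, 3)),
    # augmentation is part of the model: active in fit(), identity in predict()
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.2),
    # hypothetical 2-class head, just to have something to train
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(2),
])
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

# generate 10 images 8x8 RGB plus random (hypothetical) labels
data = np.random.randint(0, 255, size=(10, 8, 8, 3)).astype("float32")
labels = np.random.randint(0, 2, size=(10,))

history = model.fit(data, labels, epochs=1, batch_size=5, verbose=0)
preds = model.predict(data, verbose=0)  # no augmentation at inference
```

With this layout the dataset only needs to yield raw images, and exporting the model keeps the preprocessing together with it.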
Upvotes: 1