Reputation: 33
I'm trying to define a custom data augmentation layer. My goal is to apply the existing `tf.keras.layers.RandomZoom` layer only with a given probability (for example, 0.2).
This is what I did:
# Question's original (failing) version: wraps RandomZoom and applies it
# only when a uniform draw falls below `probability`.
class random_zoom_layer(tf.keras.layers.Layer):
def __init__(self, probability=0.5, **kwargs):
super().__init__(**kwargs)
self.probability = probability
def call(self, x):
# Python `if` on a tensor; when traced, this shows up as a `cond` op
# (see the `random_zoom_layer/cond/...` node names in the traceback).
if tf.random.uniform([]) < self.probability:
# NOTE(review): a brand-new RandomZoom layer is constructed on every call,
# so its random-number state (the `stateful_uniform/RngReadAndSkip` node
# in the traceback) is recreated each time the function is traced. This is
# the likely cause of the NOT_FOUND "Resource ... does not exist" error;
# creating the layer once in __init__ avoids it.
return tf.keras.layers.RandomZoom(height_factor=(-0.1, 0.1), width_factor=(-0.1, 0.1), fill_mode='constant')(x)
else:
return x
# Augmentation pipeline: normalize, then randomly zoom with probability 0.2.
# `Normalization` is taken from the core `tf.keras.layers` namespace (the same
# one this code already uses for `RandomZoom`); the
# `experimental.preprocessing` path is only a legacy alias.
# NOTE(review): Normalization is neither adapted nor given mean/variance here;
# presumably `adapt()` is called elsewhere — confirm.
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.Normalization(),
    random_zoom_layer(probability=0.2),
])
But during training, I receive this error:
tensorflow.python.framework.errors_impl.NotFoundError: 2 root error(s) found.
(0) NOT_FOUND: 2 root error(s) found.
(0) NOT_FOUND: Resource localhost/_AnonymousVar10/class tensorflow::Var does not exist.
[[{{node sequential_1/random_zoom_layer/cond/random_zoom/stateful_uniform/RngReadAndSkip}}]]
[[sequential_1/random_zoom_layer/cond/then/_0/sequential_1/random_zoom_layer/cond/random_zoom/stateful_uniform_1/RngReadAndSkip/_15]]
(1) NOT_FOUND: Resource localhost/_AnonymousVar10/class tensorflow::Var does not exist.
[[{{node sequential_1/random_zoom_layer/cond/random_zoom/stateful_uniform/RngReadAndSkip}}]]
I would really appreciate some help!
Upvotes: 1
Views: 836
Reputation: 26708
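The likely problem is that your `call` creates a new `RandomZoom` layer every time it runs. Create the layer once in `__init__` instead, and choose between the zoomed and the unchanged input with `tf.cond`. Maybe you could try something like this: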
import tensorflow as tf
class random_zoom_layer(tf.keras.layers.Layer):
    """Apply ``tf.keras.layers.RandomZoom`` to the input with a given probability.

    The wrapped ``RandomZoom`` layer is created once in ``__init__`` and reused
    by every ``call``, so its internal random state is created only once.

    Args:
        probability: Threshold compared against one scalar uniform draw per
            call; the zoom is applied when the draw is below it.
        height_factor: Forwarded to ``RandomZoom`` (default ``(-0.1, 0.1)``).
        width_factor: Forwarded to ``RandomZoom`` (default ``(-0.1, 0.1)``).
        fill_mode: Forwarded to ``RandomZoom`` (default ``'constant'``).
        **kwargs: Standard ``tf.keras.layers.Layer`` keyword arguments.
    """

    def __init__(self, probability=0.5, height_factor=(-0.1, 0.1),
                 width_factor=(-0.1, 0.1), fill_mode='constant', **kwargs):
        super().__init__(**kwargs)
        self.probability = probability
        # Kept as attributes so get_config() can reproduce this layer.
        self.height_factor = height_factor
        self.width_factor = width_factor
        self.fill_mode = fill_mode
        self.layer = tf.keras.layers.RandomZoom(
            height_factor=height_factor,
            width_factor=width_factor,
            fill_mode=fill_mode,
        )

    def call(self, x, training=None):
        """Return ``RandomZoom(x)`` if the random draw is below ``probability``, else ``x``."""
        # One scalar draw per call: the whole input is either zoomed or
        # passed through unchanged.
        apply_zoom = tf.less(tf.random.uniform([]), self.probability)
        return tf.cond(
            apply_zoom,
            lambda: self.layer(x, training=training),
            lambda: x,
        )

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created when a model is loaded."""
        config = super().get_config()
        config.update({
            'probability': self.probability,
            'height_factor': self.height_factor,
            'width_factor': self.width_factor,
            'fill_mode': self.fill_mode,
        })
        return config
# Augmentation pipeline: normalize, then randomly zoom with probability 0.2.
# `Normalization` is taken from the core `tf.keras.layers` namespace (the same
# one this code already uses for `RandomZoom`); the
# `experimental.preprocessing` path is only a legacy alias.
# NOTE(review): Normalization is neither adapted nor given mean/variance here;
# presumably `adapt()` is called elsewhere — confirm.
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.Normalization(),
    random_zoom_layer(probability=0.2),
])

# Smoke test: run the pipeline on one random 32x32 image with 4 channels.
print(data_augmentation(tf.random.normal((1, 32, 32, 4))))
import matplotlib.pyplot as plt


def _to_displayable(batch):
    """Drop the leading batch axis of a single-image batch and min-max scale it to [0, 1].

    imshow clips float RGB(A) data to [0, 1] (and reads a 4th channel as
    alpha), so raw normally-distributed values would mostly render clipped.
    """
    img = tf.squeeze(batch, axis=0)
    lo = tf.reduce_min(img)
    hi = tf.reduce_max(img)
    return (img - lo) / tf.maximum(hi - lo, 1e-8)


# Visual check: feed the SAME tensor through the pipeline and show input and
# output side by side (separate axes, so neither plot overwrites the other).
# training=True forces the augmentation path; with probability=0.2 most runs
# will still show an unzoomed output.
image = tf.random.normal((1, 32, 32, 4))
augmented = data_augmentation(image, training=True)

fig, (ax_in, ax_out) = plt.subplots(1, 2)
ax_in.imshow(_to_displayable(image))
ax_in.set_title('input')
ax_out.imshow(_to_displayable(augmented))
ax_out.set_title('augmented')
plt.show()
Upvotes: 3