I'm trying to get the cats vs. dogs example from Keras up and running, but so far without success. This is the output I get:
Found 23410 files belonging to 2 classes.
Using 4682 files for validation.
2021-02-19 10:05:56.625856: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
2021-02-19 10:05:56.640618: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2801090000 Hz
Corrupt JPEG data: 2226 extraneous bytes before marker 0xd9
And this is the code:
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers
import os

num_skipped = 0
total = 0
for folder_name in ("Cat", "Dog"):
    folder_path = os.path.join("PetImages", folder_name)
    for fname in os.listdir(folder_path):
        fpath = os.path.join(folder_path, fname)
        try:
            total += 1
            fobj = open(fpath, "rb")
            is_jfif = tf.compat.as_bytes("JFIF") in fobj.peek(10)
        finally:
            fobj.close()
        if not is_jfif:
            num_skipped += 1
            # Delete corrupted image
            os.remove(fpath)
print("Total %d Deleted %d images" % (total, num_skipped))

image_size = (180, 180)
batch_size = 32

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "PetImages",
    validation_split=0.2,
    subset="training",
    seed=1337,
    image_size=image_size,
    batch_size=batch_size,
)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "PetImages",
    validation_split=0.2,
    subset="validation",
    seed=1337,
    image_size=image_size,
    batch_size=batch_size,
)

plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(int(labels[i]))
        plt.axis("off")
Any idea how to proceed further with this? Could it be related to the installed versions of Python, Keras and TensorFlow?
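The clean-up loop in the question keeps only files whose first bytes contain the JFIF marker. That check inspects the header only, so a file with a valid JFIF header but damaged image data still passes it, which is presumably where the "Corrupt JPEG data" warning in the log comes from. A minimal TensorFlow-free sketch of the same check, where is_jfif_file and remove_non_jfif are hypothetical helper names:

```python
import os

# In a JFIF-encoded JPEG the bytes b"JFIF" sit at offsets 6-9, right after
# the SOI marker (FF D8), the APP0 marker (FF E0) and the 2-byte segment length.
JFIF_MARKER = b"JFIF"


def is_jfif_file(path, header_size=10):
    """Return True if the first `header_size` bytes of `path` contain b"JFIF".

    Hypothetical helper mirroring tf.compat.as_bytes("JFIF") in fobj.peek(10),
    but using read() so exactly `header_size` bytes are inspected
    (BufferedReader.peek() may return more bytes than requested).
    Only the header is checked; corrupt image data further on is not detected.
    """
    with open(path, "rb") as fobj:
        return JFIF_MARKER in fobj.read(header_size)


def remove_non_jfif(root, folders=("Cat", "Dog")):
    """Delete files without a JFIF header under root/<folder>.

    Returns a (total, deleted) tuple of file counts.
    """
    total = deleted = 0
    for folder_name in folders:
        folder_path = os.path.join(root, folder_name)
        for fname in os.listdir(folder_path):
            fpath = os.path.join(folder_path, fname)
            total += 1
            if not is_jfif_file(fpath):
                deleted += 1
                os.remove(fpath)  # delete the file in place
    return total, deleted
```

Using a with block also guarantees the file is closed even if reading fails, without the try/finally.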
Upvotes: 0
Not a solution, just information: this issue only occurs with tf.keras.preprocessing.image_dataset_from_directory or tf.keras.utils.image_dataset_from_directory. There is no issue when tf.keras.preprocessing.image.ImageDataGenerator is used. Unfortunately, ImageDataGenerator is deprecated.
Upvotes: 0
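I had the same issue. When the code runs as a plain Python script (rather than in a Jupyter notebook, which renders figures inline), Matplotlib does not display the figure until plt.show() is called. To resolve it, simply add:
plt.show()
after the for loop, e.g.: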
plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(int(labels[i]))
        plt.axis("off")
plt.show()
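plt.show() can only open a window when Matplotlib has an interactive backend and a display to draw on. If the script runs on a headless machine (an assumption about the setup, e.g. a remote server), a sketch like the following writes the same 3x3 grid to a PNG file using the non-interactive Agg backend instead; save_image_grid and the file name grid.png are hypothetical:

```python
import matplotlib

# Assumption: no display is available, so select the non-interactive
# Agg backend before pyplot is imported.
matplotlib.use("Agg")

import matplotlib.pyplot as plt
import numpy as np


def save_image_grid(images, labels, path, rows=3, cols=3):
    """Plot the first rows*cols images in a grid and write the figure to `path`.

    `images` is assumed to be a NumPy array of shape (N, H, W, 3) with values
    in 0-255, such as images.numpy() from one batch of train_ds.
    """
    fig = plt.figure(figsize=(10, 10))
    for i in range(rows * cols):
        ax = fig.add_subplot(rows, cols, i + 1)
        ax.imshow(np.asarray(images[i]).astype("uint8"))
        ax.set_title(str(int(labels[i])))
        ax.axis("off")
    fig.savefig(path)
    plt.close(fig)  # release the figure instead of keeping it open
```

With the datasets from the question, it could be called as save_image_grid(images.numpy(), labels.numpy(), "grid.png") inside the for images, labels in train_ds.take(1): loop.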
Upvotes: 2