Reputation: 325
I am trying to convert the MNIST dataset to RGB format. Each image currently has shape (28, 28), but I need (28, 28, 3).
import numpy as np
import tensorflow as tf
mnist = tf.keras.datasets.mnist
(x_train, _), (x_test, _) = mnist.load_data()
X = np.concatenate([x_train, x_test])
X = X / 127.5 - 1
X.reshape((70000, 28, 28, 1))
tf.image.grayscale_to_rgb(
X,
name=None
)
But I get the following error:
ValueError: Dimension 1 in both shapes must be equal, but are 84 and 3. Shapes are [28,84] and [28,3].
Upvotes: 5
Views: 7082
Reputation: 1612
In addition to the answers from @DMolony and @Aqwis01, another simple solution is to use numpy.repeat
to repeat the last axis of the array three times:
X = X.reshape((70000, 28, 28, 1))
X = X.repeat(3, -1) # repeat the last (-1) dimension three times
X_t = tf.convert_to_tensor(X)
assert X_t.shape == (70000, 28, 28, 3)
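As a quick sanity check on the repeat approach, the sketch below uses a small random array as a stand-in for MNIST (the array size and the names `gray`, `rgb`, and `channels_match` are illustrative assumptions, not part of the original answer) and confirms that each of the three resulting channels is an exact copy of the grayscale input:

```python
import numpy as np

# Hypothetical stand-in for the MNIST array: 4 random 28x28 "images".
rng = np.random.default_rng(0)
gray = rng.random((4, 28, 28)).astype(np.float32)

# Add a channel axis of size 1, then repeat it three times along the last axis.
rgb = np.repeat(gray[..., np.newaxis], 3, axis=-1)

# Every channel should be identical to the original grayscale data.
channels_match = all(np.array_equal(rgb[..., c], gray) for c in range(3))
```

Because the three channels are identical, the image still looks gray when displayed; only the shape changes.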
Upvotes: 0
Reputation: 66
X.reshape() returns a new array and does not modify X in place, so you should assign the reshaped [28x28x1] images back to a variable:
X = X.reshape((70000, 28, 28, 1))
Likewise, assign the return value of tf.image.grayscale_to_rgb()
to another variable:
X3 = tf.image.grayscale_to_rgb(
X,
name=None
)
Finally, to plot one example from the resulting tensor with matplotlib
and tf.Session() (TensorFlow 1.x), where image is a single example taken from X3 as in the complete code below:
import matplotlib.pyplot as plt
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    image_to_plot = sess.run(image)
plt.figure()
plt.imshow(image_to_plot)
plt.grid(False)
The complete code:
import numpy as np
import tensorflow as tf
mnist = tf.keras.datasets.mnist
(x_train, _), (x_test, _) = mnist.load_data()
X = np.concatenate([x_train, x_test])
X = X / 127.5 - 1
# Set reshaped array to X
X = X.reshape((70000, 28, 28, 1))
# Convert images and store them in X3
X3 = tf.image.grayscale_to_rgb(
X,
name=None
)
# Get one image from the 3D image array to var. image
image = X3[0,:,:,:]
# Plot it out with matplotlib.pyplot
import matplotlib.pyplot as plt
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    image_to_plot = sess.run(image)
plt.figure()
plt.imshow(image_to_plot)
plt.grid(False)
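One thing to watch when plotting: the question scales pixels to [-1, 1] with X / 127.5 - 1, while plt.imshow clips floating-point RGB data to [0, 1], so the negative half of the range would be shown as black. A minimal sketch of mapping the values back before display follows; the helper name `to_display_range` is an illustrative assumption and not part of the original answer:

```python
import numpy as np

def to_display_range(image):
    """Map pixel values from [-1, 1] back to [0, 1] for display."""
    return np.clip((image + 1.0) / 2.0, 0.0, 1.0)

# The same scaling as in the question, applied to three sample pixel values.
pixels = np.array([0.0, 127.5, 255.0]) / 127.5 - 1
display = to_display_range(pixels)
```

The rescaled array can then be passed to plt.imshow in place of the raw [-1, 1] values.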
Upvotes: 4
Reputation: 643
If you print the shape of X just before calling tf.image.grayscale_to_rgb, you will see it is still (70000, 28, 28), because X.reshape returns a new array and the result is never assigned. Inputs to tf.image.grayscale_to_rgb must have size 1 as their final dimension.
Expand the final dimension of X to make it compatible with the function:
tf.image.grayscale_to_rgb(tf.expand_dims(X, axis=3))
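The shape problem can be reproduced with NumPy alone. The sketch below uses a small zero array in place of the real data (the sizes and variable names are illustrative assumptions). It also shows that tiling the last axis of a (…, 28, 28) array three times gives a last dimension of 84, which matches the 84 in the error message; that the function tiles internally is an inference from the message, not something confirmed here.

```python
import numpy as np

# Hypothetical stand-in for the MNIST array: 5 images instead of 70000.
X = np.zeros((5, 28, 28), dtype=np.float32)

# reshape returns a new array; discarding the result leaves X unchanged.
X.reshape((5, 28, 28, 1))
shape_after_discarded_reshape = X.shape

# Tiling the last axis three times without a channel axis gives 28 * 3 = 84.
tiled_shape = np.tile(X, (1, 1, 3)).shape

# Adding an explicit channel axis gives the layout the conversion expects.
X4 = np.expand_dims(X, axis=-1)
shape_with_channel = X4.shape
```

Here np.expand_dims plays the same role as tf.expand_dims in the line above, just applied before the data reaches TensorFlow.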
Upvotes: 0