Reputation: 2109
TensorFlow newbie here, working through a simple tutorial that I can't get right. The goal is to convert an image to grayscale.
Our data is basically an HxWx3 array (height of the picture, width, and color as three values r, g, b). So it should be equivalent to transform each array cell from [r, g, b] to [gray, gray, gray], where gray = mean(r, g, b), right?
So I checked the docs for a mean function and found reduce_mean. I used it on the color axis, i.e. axis=2, then concatenated the result with itself along axis 2 to "replicate" the mean value, so that red, green and blue all end up holding the gray value (= mean).
See the code below:
import tensorflow as tf
import matplotlib.image as mpimg
import matplotlib.pyplot as plt

filename = "MarshOrchid.jpg"
raw_image_data = mpimg.imread(filename)

image = tf.placeholder("uint8", [None, None, 3])

# Reduce axis 2 by mean (= color)
# i.e. image = [[[r,g,b], ...]]
# out = [[[ grayvalue ], ... ]] where grayvalue = mean(r, g, b)
out = tf.reduce_mean(image, 2, keep_dims=True)

# Associate r,g,b to the same mean value = concat mean on axis 2.
# out = [[[ grayvalue, grayvalue, grayvalue], ...]]
out = tf.concat(2, [out, out, out])

with tf.Session() as session:
    result = session.run(out, feed_dict={image: raw_image_data})
    print(result.shape)
    plt.imshow(result)
    plt.show()
(You can get original image here)
This code runs, but the result isn't right.
Wondering what happened, I checked my variables, and it turns out the mean is wrong: as shown in the screenshot below, the value computed for (147, 137, 88) is 38, but mean(147, 137, 88) != 38.
Any ideas? Can't figure out what I did wrong...
Thanks! pltrdy
Upvotes: 1
Views: 592
Reputation: 201
Change dtype before computing mean (because of overflow):
The error comes from the dtype of your placeholder. Because of type inference, the intermediate tensors are uint8 as well and cannot hold values greater than 255 (2^8 - 1). When TensorFlow computes mean(147, 137, 88), it first computes sum(147, 137, 88) = 372, but 372 > 255, so it keeps 372 % 256 = 116.
And so mean(147, 137, 88) = sum(147, 137, 88) / 3 = 116 / 3 = 38 (integer division), which is the value you observed. Change the dtype of your placeholder to "uint16" or "uint32".
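To make the arithmetic above concrete, here is a small pure-Python sketch that imitates an 8-bit unsigned accumulator. The helper name uint8_mean is made up for illustration, and I am assuming the wrap-around happens on the running sum as described above; this is not TensorFlow's actual implementation, only the same arithmetic.

```python
# Simulate an 8-bit unsigned accumulator: every partial sum is kept modulo 256.
def uint8_mean(values):
    total = 0
    for v in values:
        total = (total + v) % 256  # wrap around like uint8 arithmetic
    return total // len(values)    # integer division, as for an integer dtype


pixel = (147, 137, 88)
wrapped = uint8_mean(pixel)        # (372 % 256) // 3
exact = sum(pixel) // len(pixel)   # 372 // 3, no overflow with Python ints
print(wrapped, exact)              # prints: 38 124
```

The wrapped result matches the 38 seen in the question, while the mean without overflow is 124.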
Result when switching to uint16 (not really convincing right?):
Change dtype back to uint8 before plotting it to fit pyplot spec:
The library documentation for imshow mentions that the image must be uint8. For some reason, using uint16 does not work (it looks like it reverses the colors: dark areas come out white in the previous grayscale transformation; not sure why).
Casting back to uint8 with tf.cast just before running (e.g. out = tf.cast(out, tf.uint8)) gives the correct grayscale transformation.
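For reference, the whole "widen, average, narrow" idea can be sketched in NumPy instead of TensorFlow, which avoids depending on a particular TensorFlow version. The function name to_gray_rgb and the tiny 1x2 test image are made up for illustration; in the original code the same steps would be done with a wider placeholder dtype and tf.cast.

```python
import numpy as np


def to_gray_rgb(image_u8):
    """Average the color channels of an HxWx3 uint8 image without overflow."""
    wide = image_u8.astype(np.uint16)            # widen so r + g + b (max 765) fits
    gray = wide.sum(axis=2, keepdims=True) // 3  # integer mean per pixel, shape HxWx1
    gray3 = np.concatenate([gray, gray, gray], axis=2)  # replicate to HxWx3
    return gray3.astype(np.uint8)                # narrow back to uint8 for imshow


# Hypothetical 1x2 image: the pixel from the question plus a white pixel.
image = np.array([[[147, 137, 88], [255, 255, 255]]], dtype=np.uint8)
result = to_gray_rgb(image)
print(result.dtype, result.shape, result[0, 0].tolist())
```

The cast back to uint8 is safe here because a mean of three values in 0..255 is itself in 0..255.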
Upvotes: 2