Srikanth Sharma

Reputation: 2057

Image similarity using TensorFlow or PyTorch

I want to compare two images for similarity. Since my goal is to match a given image against a massive collection of images, I want to run the comparisons on a GPU.
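A minimal sketch of the batched idea, for illustration only: instead of comparing image pairs one at a time in a Python loop, compare the query against a whole stack of same-size images in one vectorized operation. It uses NumPy and PSNR only; `psnr_batch` and `rank_by_psnr` are hypothetical helper names, and it assumes every image in the collection has the same shape as the query. The same broadcasting pattern should carry over to PyTorch or TensorFlow tensors placed on a GPU.

```python
import numpy as np


def psnr_batch(query: np.ndarray, collection: np.ndarray, max_val: float = 255.0) -> np.ndarray:
    """PSNR (in dB) between one query image and every image in a collection.

    query:      shape (H, W, C)
    collection: shape (N, H, W, C), same H, W, C as the query (assumption)
    """
    q = query.astype(np.float64)
    c = collection.astype(np.float64)
    # Broadcasting subtracts the query from all N images at once (no Python loop).
    mse = np.mean((c - q) ** 2, axis=(1, 2, 3))
    # Identical images have MSE 0 and therefore infinite PSNR; silence that warning.
    with np.errstate(divide="ignore"):
        return 10.0 * np.log10(max_val ** 2 / mse)


def rank_by_psnr(query: np.ndarray, collection: np.ndarray) -> np.ndarray:
    """Indices of collection images, most similar (highest PSNR) first."""
    return np.argsort(-psnr_batch(query, collection))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    collection = rng.integers(0, 256, size=(5, 32, 32, 3), dtype=np.uint8)
    query = collection[3].copy()
    print(rank_by_psnr(query, collection)[0])  # 3: the exact copy ranks first
```

With a very large collection you would feed it through in chunks that fit in (GPU) memory and merge the per-chunk rankings.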

I came across the tf.image.ssim and tf.image.psnr functions, but I am unable to find any working examples. Solutions in PyTorch are also appreciated. Since I don't have a good understanding of CUDA and C, I am hesitant to write kernels in PyCUDA.

Would it speed up processing if I read the entire image collection once and stored it as TensorFlow Records (TFRecords) for future use?

Any guidance or solution, greatly appreciated. Thank you.

Edit: I am only matching images of the same size. I don't want a mere histogram match; I want to use SSIM or PSNR for image similarity, so I assume matched images would be similar in color, content, etc.
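For reference, these are the standard definitions of the two metrics. Here x and y are the two images with n pixel values each, L is the dynamic range (max_val: 255 for 8-bit images, 1 for floats in [0, 1]), and K_1 = 0.01, K_2 = 0.03 are the commonly used defaults. In practice the SSIM means, variances and covariance are computed over local (typically Gaussian) windows and the resulting map is averaged.

```latex
\mathrm{MSE}(x, y) = \frac{1}{n}\sum_{i=1}^{n} (x_i - y_i)^2,
\qquad
\mathrm{PSNR}(x, y) = 10 \log_{10}\!\left(\frac{L^2}{\mathrm{MSE}(x, y)}\right)

\mathrm{SSIM}(x, y) = \frac{(2\mu_x\mu_y + C_1)(2\sigma_{xy} + C_2)}
                           {(\mu_x^2 + \mu_y^2 + C_1)(\sigma_x^2 + \sigma_y^2 + C_2)},
\qquad C_1 = (K_1 L)^2,\quad C_2 = (K_2 L)^2
```

Higher PSNR (in dB) means more similar, and it is infinite for identical images; SSIM is at most 1 and equals 1 for identical images.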

Upvotes: 1

Views: 5252

Answers (2)

Donshel

Reputation: 624

There is no implementation of PSNR or SSIM in PyTorch itself. You can either implement them yourself or use a third-party package, such as piqa, which I developed.
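If you want to implement SSIM yourself, here is a minimal sketch in plain NumPy (1.20+ for `sliding_window_view`). It is a simplification, not what piqa or TensorFlow do: it uses a uniform 7x7 box window instead of the usual 11x11 Gaussian window, and it handles a single grayscale channel (for color, average the result over channels). The function name `ssim_uniform` is hypothetical.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def ssim_uniform(x: np.ndarray, y: np.ndarray, max_val: float = 255.0,
                 win: int = 7, k1: float = 0.01, k2: float = 0.03) -> float:
    """Mean SSIM of two same-size grayscale images, using a uniform (box) window.

    Simplification: standard implementations typically use a Gaussian window.
    """
    x = x.astype(np.float64)
    y = y.astype(np.float64)
    c1 = (k1 * max_val) ** 2
    c2 = (k2 * max_val) ** 2

    # Every win x win patch: shape (H - win + 1, W - win + 1, win, win).
    px = sliding_window_view(x, (win, win))
    py = sliding_window_view(y, (win, win))

    # Local means, then deviations from those means within each patch.
    mu_x = px.mean(axis=(-2, -1))
    mu_y = py.mean(axis=(-2, -1))
    dx = px - mu_x[..., None, None]
    dy = py - mu_y[..., None, None]

    # Local variances and covariance (population estimates over each patch).
    var_x = (dx * dx).mean(axis=(-2, -1))
    var_y = (dy * dy).mean(axis=(-2, -1))
    cov_xy = (dx * dy).mean(axis=(-2, -1))

    ssim_map = ((2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)) / (
        (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    )
    return float(ssim_map.mean())


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    img = rng.integers(0, 256, size=(64, 64)).astype(np.uint8)
    noisy = np.clip(img + rng.normal(0, 20, img.shape), 0, 255)
    print(ssim_uniform(img, img))    # approximately 1.0 for identical images
    print(ssim_uniform(img, noisy))  # below 1.0
```

A library implementation will be faster and better tested; the point here is only to show which statistics the formula needs.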

Assuming you already have torch and torchvision installed, you can get it with

pip install piqa

Then, to compare the images:

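import torch
from PIL import Image
from torchvision import transforms
from piqa import PSNR, SSIM

# convert('RGB') drops any alpha channel, so both tensors have 3 channels (SSIM's default)
im1 = Image.open('path/to/im1.png').convert('RGB')
im2 = Image.open('path/to/im2.png').convert('RGB')

transform = transforms.ToTensor()  # PIL image -> float tensor in [0, 1], shape (C, H, W)

device = 'cuda' if torch.cuda.is_available() else 'cpu'  # use the GPU when available
x = transform(im1).unsqueeze(0).to(device)  # add a batch dimension -> (1, C, H, W)
y = transform(im2).unsqueeze(0).to(device)

psnr = PSNR()
ssim = SSIM().to(device)  # SSIM holds a window kernel, so it must live on the same device

print('PSNR:', psnr(x, y))
print('SSIM:', ssim(x, y))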

Upvotes: 2

Sorin

Reputation: 11968

Check out the example on the TensorFlow documentation page for tf.image.ssim:

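import tensorflow as tf

# decode_png expects the encoded file contents, not a path, so read the file first
im1 = tf.io.decode_png(tf.io.read_file('path/to/im1.png'), channels=3)
im2 = tf.io.decode_png(tf.io.read_file('path/to/im2.png'), channels=3)
print(tf.image.ssim(im1, im2, max_val=255))  # uint8 images, so the range is 0..255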

This should work in recent versions of TensorFlow (2.x, with eager execution). In older versions (1.x, graph mode), tf.image.ssim returns a symbolic tensor, so print will not show its value; evaluate it in a session instead, e.g. with sess.run(...) or tensor.eval().

Upvotes: 1
