user8188120

Reputation: 915

Loading FITS images with PyTorch

I'm trying to create a CNN using PyTorch but my images need importing from the FITS format rather than conventional .png or .jpeg etc.

Is there a way to accomplish this easily using torch.utils.data.DataLoader or is there a place in the source code where I can put in a clause which will handle FITS files while loading in?

I have looked in the documentation and the most relevant thing I've found is the ToPILImage transformer which converts a tensor or ndarray into a PIL Image.

Currently I'm using an image loading routine as follows:

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision

batch_size = 4

transform = transforms.Compose(
                   [transforms.Resize((32,32)),
                    transforms.ToTensor(),
                    ])

trainset = dset.ImageFolder(root="Documents/Image_data",transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,shuffle=True)

Astropy: http://www.astropy.org/

Pytorch: https://pytorch.org/

torch.utils: https://pytorch.org/docs/master/data.html

UPDATE: Perhaps using torchvision.datasets.DatasetFolder instead of DataLoader, and inserting my own FITS handler, would work?

When trying to use this class I get the following error:

AttributeError: module 'torchvision.datasets' has no attribute 'DatasetFolder'

Is DatasetFolder actually supported by torchvision at this point in time?

Upvotes: 0

Views: 1309

Answers (3)

sara

Reputation: 135

I encountered the same problem as @user8188120 a few weeks ago. @Iguananaut's answer works great when the labels come from the folder structure. If someone stumbles upon this and needs to read the labels from a CSV file instead, this might also work:

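import numpy as np
import pandas as pd
import torch.utils.data as data
import torchvision.transforms as transforms
# astropy.io.fits is the maintained successor of the standalone pyfits package
from astropy.io import fits as pyfits

# Placeholder values: set these to match your own images
IMG_HEIGHT, IMG_WIDTH, CHANNELS = 32, 32, 1

labels = []
transform = transforms.Compose([
    # here go your transforms
    ])


class MyFitsDataset(data.Dataset):
    def __init__(self, csv_path):
        # Read the csv file (no header row expected)
        self.data_info = pd.read_csv(csv_path, header=None)
        # First column contains the image paths
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])
        # The rest contain the labels; keep only the line that matches your data
        # self.label_arr = np.asarray(self.data_info.iloc[:, 1:])  # for multi-label
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])  # for single-label
        labels.append(self.label_arr)
        self.data_len = len(self.data_info.index)

    def __getitem__(self, index):
        single_image_name = self.image_arr[index]

        # Read the primary HDU; astype() copies the pixels before the file closes
        with pyfits.open(single_image_name) as hdul:
            image = hdul[0].data.astype('float32')
        # FITS arrays are indexed (row, column), so height comes first
        image = image.reshape(IMG_HEIGHT, IMG_WIDTH, CHANNELS)

        img = transform(image)

        # Get label(class) of the image based on the pandas column
        single_image_label = self.label_arr[index]

        return (img, single_image_label)

    def __len__(self):
        return self.data_len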

This also avoids the DatasetFolder class, which at the time of writing still isn't available in the newest released version of torchvision. I hope this helps someone.
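For reference, here is a small sketch of the CSV layout the class above appears to assume: no header row, the file path in the first column and a single integer label in the second. The file names and labels are made up for illustration.

```python
import io

import numpy as np
import pandas as pd

# Hypothetical file contents: no header row, path first, then the label
csv_text = "images/galaxy_001.fits,0\nimages/star_001.fits,1\n"

# header=None makes pandas name the columns 0, 1, ... instead of using row 0
data_info = pd.read_csv(io.StringIO(csv_text), header=None)
image_arr = np.asarray(data_info.iloc[:, 0])  # file paths
label_arr = np.asarray(data_info.iloc[:, 1])  # one label per image

for path, label in zip(image_arr, label_arr):
    print(path, label)
```

If your CSV does have a header row, drop header=None, otherwise the header ends up being treated as the first sample.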

Upvotes: 0

Iguananaut

Reputation: 23376

From reading some combination of the docs and the code, I don't think you necessarily want to be using ImageFolder since it doesn't know anything about FITS.

Instead, you should try using the more generic DatasetFolder class (which is in fact the parent class of ImageFolder). You would pass it a list of extensions it should handle (i.e. ['.fits']) and a "loader" function that takes a FITS file and, it seems, should return a PIL.Image.

You could even make your own subclass following the example of ImageFolder. E.g.

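from astropy.io import fits
from PIL import Image
from torchvision.datasets import DatasetFolder


class FitsFolder(DatasetFolder):

    EXTENSIONS = ['.fits']

    def __init__(self, root, transform=None, target_transform=None,
                 loader=None):
        # Fall back to the built-in FITS loader unless the caller supplies one
        if loader is None:
            loader = self.__fits_loader

        super(FitsFolder, self).__init__(root, loader, self.EXTENSIONS,
                                         transform=transform,
                                         target_transform=target_transform)

    @staticmethod
    def __fits_loader(filename):
        # Read the first image array in the file and wrap it as a PIL image
        data = fits.getdata(filename)
        return Image.fromarray(data)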

The exact details of __fits_loader may depend on the details of your FITS files. This basic example just uses the high-level fits.getdata() function which returns the first image array in the FITS file (some FITS files may have many extensions with many images, or have tables etc.). So that part would be up to you.
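As a rough sketch of what a more defensive loader might look like, the function below walks the HDUs and returns the first one holding a 2-D array, converted to float32. The name load_first_image_hdu and the "first 2-D array wins" rule are my own assumptions, not anything torchvision or astropy prescribes. The conversion also puts the pixels in native byte order, which matters because FITS stores data big-endian and torch.from_numpy refuses arrays whose byte order is not native.

```python
import numpy as np
from astropy.io import fits


def load_first_image_hdu(filename):
    """Return the first 2-D image in a FITS file as native-endian float32."""
    with fits.open(filename) as hdul:
        for hdu in hdul:
            data = hdu.data
            # Skip empty primary HDUs and tables (table data is a 1-D record array)
            if data is not None and data.ndim == 2:
                # astype() copies the pixels, so the array survives closing the file
                return data.astype(np.float32)
    raise ValueError(f"no 2-D image HDU found in {filename!r}")
```

Something like this could stand in for the body of __fits_loader, either returned as-is for tensor-based transforms or wrapped in Image.fromarray() if your transforms expect a PIL image.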

Upvotes: 3

MadeOfAir

Reputation: 3193

You can export a FITS image to any format supported by pyplot.imsave() using this method:

from astropy.io import fits
import matplotlib.pyplot as plt

image_data = fits.getdata(r"/path/to/image.fits")
plt.imsave("/path/to/image.png", image_data, cmap="gray")
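If you want to go this route for a whole training set, a sketch along these lines might do, assuming one sub-folder per class as ImageFolder expects. The helper name convert_fits_tree and the directory layout are assumptions for illustration. Keep in mind that imsave() applies the colormap and writes 8-bit channels, so the dynamic range of the original FITS data is quantized.

```python
from pathlib import Path

import matplotlib
matplotlib.use("Agg")  # non-interactive backend, works on headless machines
import matplotlib.pyplot as plt
from astropy.io import fits


def convert_fits_tree(src_root, dst_root):
    """Mirror src_root under dst_root, writing one PNG per .fits file."""
    src_root, dst_root = Path(src_root), Path(dst_root)
    written = []
    for fits_path in sorted(src_root.rglob("*.fits")):
        # Keep the class sub-folder, swap the extension
        png_path = (dst_root / fits_path.relative_to(src_root)).with_suffix(".png")
        png_path.parent.mkdir(parents=True, exist_ok=True)
        image_data = fits.getdata(str(fits_path))  # first image array in the file
        plt.imsave(str(png_path), image_data, cmap="gray")
        written.append(png_path)
    return written
```

The resulting folder can then be passed to ImageFolder as in the question.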

Upvotes: 0
