Iamnotperfect

Reputation: 119

Is it possible to add a custom function to transforms.Compose in PyTorch?

I am using a pre-trained AlexNet model and running it on an image dataset. I want to convert the RGB images to YCbCr before training.

Is it possible to add my own function to transforms.Compose? For example:

transform = transforms.Compose([
  ycbcr(), #something like this
  transforms.Resize((224, 224)),
  transforms.ToTensor(),
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

where,

def ycbcr(img):
    img = cv2.imread(img)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2ycbcr)
    t = torch.from_numpy(img)
    return t

training_dataset = datasets.ImageFolder(link_train, transform=transform)

training_loader = torch.utils.data.DataLoader(training_dataset, batch_size=96, shuffle=True)

Is this approach correct? If not, how should I proceed?

Upvotes: 2

Views: 5588

Answers (1)

David

Reputation: 8318

You can pass a custom transformation to transforms.Compose by defining a class with a __call__ method.

For more detail, I suggest reading the torchvision.transforms documentation.
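The reason this works is that Compose calls each entry in its list on the image, in order, passing the result of one to the next. Here is a minimal pure-Python sketch of that idea. MiniCompose, add_one and Scale are hypothetical names used only for illustration and are not part of torchvision:

```python
class MiniCompose:
    """Hypothetical stand-in that mimics how transforms.Compose chains callables."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        # Feed the output of each transform into the next one, in order.
        for t in self.transforms:
            img = t(img)
        return img


def add_one(x):
    # A plain function: it goes in the list as `add_one`, not `add_one()`.
    return x + 1


class Scale:
    # A callable class: it is instantiated first, so `Scale(10)` goes in the list.
    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        return x * self.factor


pipeline = MiniCompose([add_one, Scale(10)])
result = pipeline(2)  # computes (2 + 1) * 10
```

This also shows why `ycbcr()` in the question fails when ycbcr is a plain function: the parentheses call it immediately with no image, instead of handing the callable itself to the pipeline.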

In your case it will be something like the following:

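import cv2
import numpy as np
from PIL import Image

class ycbcr(object):
    def __call__(self, img):
        """
        :param img: (PIL): RGB image

        :return: YCbCr color space image (PIL)
        """
        # ImageFolder has already loaded the file, so img is a PIL image,
        # not a path; convert it to an H x W x 3 uint8 array for OpenCV.
        img = np.array(img)
        # PIL images are RGB (not BGR). OpenCV orders the output channels as Y, Cr, Cb.
        img = cv2.cvtColor(img, cv2.COLOR_RGB2YCrCb)

        return Image.fromarray(img)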

    def __repr__(self):
        return self.__class__.__name__+'()'

Note that the transform receives a PIL image and must return a PIL image, so adjust your code accordingly. This is the general way to define a custom transformation.
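If OpenCV is not a requirement, the same PIL-in, PIL-out contract can be met without leaving PIL. The following is only a sketch, assuming Pillow's built-in Image.convert with the "YCbCr" mode is acceptable for your use case. The class name ToYCbCr is hypothetical:

```python
from PIL import Image


class ToYCbCr:
    """Hypothetical transform: PIL image in, PIL image in YCbCr mode out."""

    def __call__(self, img):
        # Force RGB first so grayscale or RGBA inputs convert consistently.
        return img.convert("RGB").convert("YCbCr")

    def __repr__(self):
        return self.__class__.__name__ + "()"


# Quick check on a synthetic image instead of a file from disk.
sample = Image.new("RGB", (8, 6), color=(255, 0, 0))
converted = ToYCbCr()(sample)
```

In the pipeline from the question, an instance would go first in the list, as `ToYCbCr()`, ahead of Resize and ToTensor. Keep in mind that the Normalize values and the pre-trained weights were chosen for RGB input, so they may need revisiting for YCbCr data.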

Upvotes: 7
