Reputation: 311
I want to apply skimage’s Local Binary Pattern feature extraction to my data, and was wondering whether it is possible to do this inside my torch transforms, which currently look like this:
    data_transforms = {
        'train': transforms.Compose([
            transforms.CenterCrop(178),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ]),
        'val': transforms.Compose([
            transforms.CenterCrop(178),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ]),
    }
If not, how would I implement it? Would I have to do it when importing the data?
Upvotes: 0
Views: 774
Reputation: 1317
You can implement the transform using a lambda function, as @dhananjay correctly pointed out. Building on that comment, the implementation would be as follows:
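    import numpy as np
    from skimage.feature import local_binary_pattern
    from torchvision import transforms

    def lbp(img):
        # local_binary_pattern expects a 2D grayscale array, so convert the PIL image first
        gray = np.array(img.convert('L'))
        radius = 2
        n_points = 8 * radius
        # 'uniform' yields codes in [0, n_points + 1]; rescale them to [0, 1]
        codes = local_binary_pattern(gray, n_points, radius, 'uniform')
        return (codes / (n_points + 1)).astype(np.float32)

    data_transforms = {
        'train': transforms.Compose([
            transforms.CenterCrop(178),
            transforms.RandomHorizontalFlip(),
            transforms.Lambda(lbp),
            transforms.ToTensor(),
            # the LBP map has a single channel, so Normalize takes one mean and one std
            transforms.Normalize([0.5], [0.5])
        ]),
        'val': transforms.Compose([
            transforms.CenterCrop(178),
            transforms.Lambda(lbp),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ]),
    }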
However, this is a bad idea because it defeats the very purpose of a PyTorch transform. A transform is ideal for an operation that either:
1. Can be computed trivially (at low compute cost) from the original data, so there is no advantage to applying it to your data and storing a copy. Normalize is one such transform.
2. Introduces an element of stochasticity or random perturbation into the original data, e.g. RandomHorizontalFlip.
The key thing to remember is that your transform is applied to every sample each time it is loaded, so it is recomputed for every batch in every epoch during training.
Considering the above, you absolutely do not want to implement your LBP as a transform. It is better to compute it offline and store it. Otherwise you will significantly slow down your batch loading.
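As a rough sketch of what "compute it offline and store it" could look like: the helper names `lbp_features` and `cache_lbp` below are made up for illustration, and storing one `.npy` file per image is just one possible layout. It assumes the input is already a 2D grayscale uint8 array (cropped the same way as in the transform pipeline).

```python
from pathlib import Path

import numpy as np
from skimage.feature import local_binary_pattern


def lbp_features(gray, radius=2):
    """Uniform LBP codes for a 2D grayscale array, scaled to [0, 1]."""
    n_points = 8 * radius
    codes = local_binary_pattern(gray, n_points, radius, "uniform")
    # 'uniform' codes lie in [0, n_points + 1]
    return (codes / (n_points + 1)).astype(np.float32)


def cache_lbp(gray, cache_path):
    """Compute LBP once; later calls reuse the stored .npy file (hypothetical helper)."""
    cache_path = Path(cache_path).with_suffix(".npy")
    if cache_path.exists():
        return np.load(cache_path)
    features = lbp_features(gray)
    np.save(cache_path, features)
    return features
```

A custom Dataset's `__getitem__` could then load the stored array instead of recomputing LBP, so the expensive step is paid once per image rather than once per epoch.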
Upvotes: 1
Reputation: 190
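You can use torchvision.transforms.Lambda.
It allows you to apply a custom lambda function as a transform.
Something like transforms.Lambda(lambda x: local_binary_pattern(x, P, R)), where P is the number of neighbour points and R is the radius (both are required arguments), and x must be a 2D grayscale array.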
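If the lambda becomes awkward (for example, a plain lambda cannot be pickled, which can matter when DataLoader workers are started with the spawn method), a small callable class is one alternative. This is only a sketch: `LBPTransform` is a hypothetical name, and averaging the channels is a simplifying assumption for the grayscale conversion.

```python
import numpy as np
from skimage.feature import local_binary_pattern


class LBPTransform:
    """Callable transform: image in, float32 LBP map out (hypothetical helper)."""

    def __init__(self, radius=2, method="uniform"):
        self.radius = radius
        self.n_points = 8 * radius
        self.method = method

    def __call__(self, img):
        # Works for PIL images and NumPy arrays alike
        arr = np.asarray(img)
        if arr.ndim == 3:
            # Assumption: a plain channel average is an acceptable grayscale conversion
            arr = arr.mean(axis=2).astype(np.uint8)
        codes = local_binary_pattern(arr, self.n_points, self.radius, self.method)
        return codes.astype(np.float32)
```

An instance can be dropped into transforms.Compose in place of transforms.Lambda(...), e.g. `LBPTransform(radius=2)` placed before `transforms.ToTensor()`.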
Upvotes: 0