tidy

Why Reshape and Permute for segmentation with unet?

I am doing image semantic segmentation with U-Net and am confused by the last layers used for pixel classification. The U-Net code looks like this:

...
reshape = Reshape((n_classes,self.img_rows * self.img_cols))(conv9)
permute = Permute((2,1))(reshape)
activation = Activation('softmax')(permute)
model = Model(input = inputs, output = activation) 
return model
...

Can I just reshape without using Permute like this?

reshape = Reshape((self.img_rows * self.img_cols, n_classes))(conv9)

Update:

I found that the training result is wrong when I reshape directly:

reshape = Reshape((self.img_rows * self.img_cols, n_classes))(conv9)  # the loss does not converge

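My ground truth is generated like this:

import cv2
import numpy as np

# impath, segpaths, height, width and n_classes are defined elsewhere;
# segpaths is assumed to hold one mask file per class, in class order
X = []
Y = []
im = cv2.imread(impath)
X.append(im)
seg_labels = np.zeros((height, width, n_classes))
for c, spath in enumerate(segpaths):  # c = class index of this mask
    mask = cv2.imread(spath, 0)  # grayscale mask, shape (height, width)
    seg_labels[:, :, c] += mask
Y.append(seg_labels.reshape(width*height, n_classes))  # (n_pixels, n_classes)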

Why does reshaping directly not work?


Answers (3)

pitfall

You clearly misunderstand the meaning of each operation and the final goal:

  • final goal: classification for each pixel, i.e. softmax along the semantic class axis
  • how to achieve this goal in the original code? Let's see the code line by line:
reshape = Reshape((n_classes,self.img_rows * self.img_cols))(conv9) # L1
permute = Permute((2,1))(reshape) # L2
activation = Activation('softmax')(permute) # L3
  • L1's output shape = n_classes-by-n_pixs (n_pixs = img_rows x img_cols; the batch axis is omitted throughout, and conv9 is assumed to be channels-first, i.e. n_classes x img_rows x img_cols)
  • L2's output shape = n_pixs-by-n_classes
  • L3's output shape = n_pixs-by-n_classes
  • Note that softmax is applied to the last axis by default, i.e. the axis that n_classes stands for, which is the semantic class axis.

Therefore, this original code fulfills the final goal of semantic segmentation.
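The three lines above can be mimicked in NumPy to check the shapes and the axis the softmax acts on. This is a minimal sketch, assuming conv9 is channels-first with shape (n_classes, img_rows, img_cols) once the batch axis is dropped (which is what the Reshape in L1 implies); the sizes and the `softmax` helper are made up for illustration and only stand in for the Keras layers.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along `axis` (Keras's Activation('softmax') uses the last axis).
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

n_classes, img_rows, img_cols = 3, 4, 5  # hypothetical sizes
rng = np.random.default_rng(0)
# Assumed channels-first logits for one image, batch axis omitted.
conv9 = rng.standard_normal((n_classes, img_rows, img_cols))

l1 = conv9.reshape(n_classes, img_rows * img_cols)  # Reshape: (n_classes, n_pixs)
l2 = l1.transpose(1, 0)                              # Permute((2,1)): (n_pixs, n_classes)
l3 = softmax(l2)                                     # softmax over the class axis

print(l1.shape, l2.shape, l3.shape)  # (3, 20) (20, 3) (20, 3)
# Each row is one pixel's class distribution, so every row sums to 1.
print(np.allclose(l3.sum(axis=1), 1.0))  # True
# Row k holds the class scores of pixel (k // img_cols, k % img_cols).
print(np.allclose(l2[7], conv9[:, 7 // img_cols, 7 % img_cols]))  # True
```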


Let's revisit the code that you want to change, which is

reshape = Reshape((self.img_rows * self.img_cols, n_classes))(conv9) # L4
  • L4's output dim = n_pixs-by-n_class

My guess is that you think L4's output shape matches L2's, so L4 is a shortcut equivalent to executing L1 and L2.

However, matching the shape does not necessarily mean matching the physical meaning of axes. Why? A simple example will explain.

Say you have 2 semantic classes and 3 pixels. To see the difference, assume all three pixels belong to the same class.

In other words, a ground truth tensor will look like this

# cls#1 cls#2
[   [0, 1], # pixel #1
    [0, 1], # pixel #2
    [0, 1], # pixel #3
]

Assume you have a perfect network that generates the exact response for each pixel. Your solution will still create a tensor like the one below

# cls#1 cls#2
[   [0, 0], # pixel #1
    [0, 1], # pixel #2
    [1, 1], # pixel #3
]

whose shape is the same as the ground truth's, but fails to match the physical meaning of axes.
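The toy example can be reproduced in NumPy. This sketch assumes, as above, that the network output is channels-first, i.e. a 2 x 3 array of (classes, pixels) in which the perfect network puts all of its response on class #2; `conv_out` is a hypothetical name for that output.

```python
import numpy as np

# Hypothetical channels-first output: row = class, column = pixel.
# All three pixels respond only to class #2.
conv_out = np.array([[0, 0, 0],   # class #1
                     [1, 1, 1]])  # class #2

ground_truth = np.array([[0, 1], [0, 1], [0, 1]])  # (pixels, classes)

# Reshape then Permute: move the class axis last while keeping its meaning.
permuted = conv_out.reshape(2, 3).transpose(1, 0)
# Direct reshape: same (3, 2) shape, but the flat order 0,0,0,1,1,1 is refilled row by row.
direct = conv_out.reshape(3, 2)

print(permuted.tolist())  # [[0, 1], [0, 1], [0, 1]]
print(direct.tolist())    # [[0, 0], [0, 1], [1, 1]]
print(np.array_equal(permuted, ground_truth), np.array_equal(direct, ground_truth))  # True False
```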

This in turn makes the softmax meaningless: it is supposed to act on the class dimension, but in this tensor the last axis no longer corresponds to classes. Applying softmax row by row gives the following erroneous output:

# cls#1 cls#2
[   [0.5, 0.5], # pixel #1
    [0.27, 0.73], # pixel #2 (softmax of [0, 1], rounded)
    [0.5, 0.5], # pixel #3
]

which completely messes up the training, even under this ideal assumption.
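The softmax values can be checked directly. This sketch applies a row-wise softmax (over the last axis, as Keras does by default) to both 3 x 2 tensors from the toy example; the `softmax` helper is written here for illustration.

```python
import numpy as np

def softmax(x):
    # Softmax over the last (class) axis, one row per pixel.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

direct = np.array([[0, 0], [0, 1], [1, 1]], dtype=float)    # direct reshape
permuted = np.array([[0, 1], [0, 1], [0, 1]], dtype=float)  # Reshape + Permute

print(np.round(softmax(direct), 2).tolist())    # [[0.5, 0.5], [0.27, 0.73], [0.5, 0.5]]
print(np.round(softmax(permuted), 2).tolist())  # [[0.27, 0.73], [0.27, 0.73], [0.27, 0.73]]
```

With the correct layout every pixel favours class #2; with the direct reshape, pixels #1 and #3 carry no class information at all.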


Therefore, it is a good habit to write down the physical meaning of each axis of a tensor. When you do any tensor reshape operation, ask yourself whether the physical meaning of an axis is changed in your expected way.

For example, if you have a tensor T of shape batch_dim x img_rows x img_cols x feat_dim, you can do many things, and not all of them make sense (because some of them scramble the physical meaning of the axes):

  1. (Wrong) reshape it to whatever x feat_dim, because the merged whatever dimension mixes the batch axis with the pixel axes: its size depends on the batch size, which may differ at test time.
  2. (Wrong) reshape it to batch_dim x feat_dim x img_rows x img_cols, because after a plain reshape the 2nd axis is NOT the feature dimension, and the 3rd and 4th axes no longer hold rows and columns either.
  3. (Correct) permute axes (3,1,2) (Keras Permute numbering, which excludes the batch axis). This gives a tensor of shape batch_dim x feat_dim x img_rows x img_cols while keeping the physical meaning of each axis.
  4. (Correct) reshape it to batch_dim x whatever x feat_dim. This is also valid, because whatever = img_rows x img_cols is just the flattened pixel-location dimension, and the meanings of batch_dim and feat_dim are unchanged.
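A quick way to test whether an operation keeps the meaning of each axis is to look up one element before and after it. The sketch below does this for items 2 to 4 on a tensor with distinct values; the sizes are arbitrary. NumPy's `transpose` counts the batch axis, so Keras's Permute((3,1,2)) corresponds to `transpose(0, 3, 1, 2)` here.

```python
import numpy as np

batch_dim, img_rows, img_cols, feat_dim = 2, 3, 4, 5  # arbitrary sizes
T = np.arange(batch_dim * img_rows * img_cols * feat_dim).reshape(
    batch_dim, img_rows, img_cols, feat_dim)

wrong = T.reshape(batch_dim, feat_dim, img_rows, img_cols)  # item 2
permuted = T.transpose(0, 3, 1, 2)                           # item 3
flat = T.reshape(batch_dim, img_rows * img_cols, feat_dim)   # item 4

b, r, c, f = 0, 1, 2, 3  # one arbitrary element
print(T[b, r, c, f] == permuted[b, f, r, c])                  # True: axes keep their meaning
print(np.array_equal(T[b, r, c], flat[b, r * img_cols + c]))  # True: pixel axis is row-major
print(T[b, r, c, f] == wrong[b, f, r, c])                     # False: same shape, wrong meaning
```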


Manoj Mohan

The Reshape and Permute are done to take the softmax at each pixel location. Adding to @meowongac's answer: Reshape preserves the order of the elements (it refills the same flat sequence into a new shape). In this case the channel axis has to be moved to the end, so Reshape followed by Permute is appropriate.

Consider a (2,2) image with 3 values (channels) at each location:

>>> import numpy as np
>>> arr = np.array([[[1,1],[1,1]],[[2,2],[2,2]],[[3,3],[3,3]]])
>>> arr.shape
(3, 2, 2)
>>> arr
array([[[1, 1],
        [1, 1]],

       [[2, 2],
        [2, 2]],

       [[3, 3],
        [3, 3]]])

>>> arr[:,0,0]
array([1, 2, 3])

The channel values at each location are [1,2,3]. The goal is to move the channel axis (length 3) to the end.

>>> arr.reshape((2,2,3))[0,0] 
array([1, 1, 1])   # incorrect

>>> arr.transpose((1,2,0))[0,0] # similar to what permute does.
array([1, 2, 3])  # correct 

More examples at this link: https://discuss.pytorch.org/t/how-to-change-shape-of-a-matrix-without-dispositioning-the-elements/30708


meowongac

Your code still runs because the shapes are the same, but the values end up in different positions, so the result (and therefore backprop) will be different. For example:

>>> import numpy as np
>>> arr = np.array([[[1,1,1],[1,1,1]],[[2,2,2],[2,2,2]],[[3,3,3],[3,3,3]],[[4,4,4],[4,4,4]]])
>>> arr.shape
(4, 2, 3)

# reshape, then permute
>>> reshape_1 = arr.reshape((4, 2*3))
>>> np.swapaxes(reshape_1, 1, 0)
array([[1, 2, 3, 4],
       [1, 2, 3, 4],
       [1, 2, 3, 4],
       [1, 2, 3, 4],
       [1, 2, 3, 4],
       [1, 2, 3, 4]])

# reshape directly
>>> reshape_2 = arr.reshape(2*3, 4)
>>> reshape_2
array([[1, 1, 1, 1],
       [1, 1, 2, 2],
       [2, 2, 2, 2],
       [3, 3, 3, 3],
       [3, 3, 4, 4],
       [4, 4, 4, 4]])
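The same bookkeeping decides whether the network output lines up with the ground truth built in the question. There, Y comes from a (height, width, n_classes) array, so row k of Y is pixel (k // width, k % width). The sketch below assumes the last U-Net layer is channels-first (n_classes, height, width) and uses made-up sizes, with distinct values standing in for the labels; it checks that the Reshape + Permute output uses the same pixel order as Y while the direct reshape does not.

```python
import numpy as np

height, width, n_classes = 2, 3, 4  # hypothetical sizes
# Distinct values standing in for labels, laid out (H, W, C) as in the question.
seg_labels = np.arange(height * width * n_classes).reshape(height, width, n_classes)
Y = seg_labels.reshape(height * width, n_classes)

# A "perfect" channels-first network output carrying the same values: (C, H, W).
conv9 = seg_labels.transpose(2, 0, 1)

pred_permute = conv9.reshape(n_classes, height * width).transpose(1, 0)
pred_direct = conv9.reshape(height * width, n_classes)

print(np.array_equal(pred_permute, Y))  # True
print(np.array_equal(pred_direct, Y))   # False
```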

