Reputation: 213
I'm trying to convert some Keras (TensorFlow) code to PyTorch, and I'm unable to reproduce Keras's `MaxPooling3D` with PyTorch's `MaxPool3d`.
The following code:
```python
import torch
import torch.nn as nn
import tensorflow.keras.layers as layers
import matplotlib.pyplot as plt

kernel_size = (10, 10, 2)
strides = (32, 32, 2)
in_tensor = torch.randn(1, 1, 256, 256, 64)
tf_out = layers.MaxPooling3D(data_format='channels_first', pool_size=kernel_size,
                             strides=strides, padding='same')(in_tensor.detach().numpy())
pt_out = nn.MaxPool3d(kernel_size=kernel_size, stride=strides)(in_tensor)

fig = plt.figure(figsize=(10, 5))
axs = fig.subplots(1, 2)
axs[0].matshow(pt_out[0, 0, :, :, 0].detach().numpy())
axs[0].set_title('PyTorch')
axs[1].matshow(tf_out.numpy()[0, 0, :, :, 0])
axs[1].set_title('TensorFlow')
```
Gives very different results:
What could be the problem?
Is the padding in the PyTorch version incorrect?
Upvotes: 0
Views: 528
Reputation: 2066
The padding is not the same in the two layers, which is why you're not getting the same results.
You set `padding='same'` in the TensorFlow `MaxPooling3D` layer, but no padding is set in the PyTorch `MaxPool3d` layer.
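To see how the two padding modes affect the output shape, the per-axis size formulas can be compared directly. This is a small pure-Python sketch; the helper names are hypothetical, and it assumes dilation 1 and PyTorch's default `ceil_mode=False`.

```python
import math


def tf_pool_output_size(size: int, k: int, s: int, padding: str) -> int:
    """Output length along one axis for TF/Keras pooling (dilation 1)."""
    if padding == "same":
        # 'same' covers every input element: ceil(size / stride) windows
        return math.ceil(size / s)
    if padding == "valid":
        # 'valid' only keeps windows that fit entirely inside the input
        return math.ceil((size - k + 1) / s)
    raise ValueError(f"unknown padding: {padding!r}")


def torch_pool_output_size(size: int, k: int, s: int, pad: int = 0) -> int:
    """Output length for nn.MaxPool3d with dilation=1 and ceil_mode=False."""
    return (size + 2 * pad - k) // s + 1


# Axes from the question, in (D, H, W) order
for size, k, s in zip((256, 256, 64), (10, 10, 2), (32, 32, 2)):
    print(size,
          tf_pool_output_size(size, k, s, "same"),
          tf_pool_output_size(size, k, s, "valid"),
          torch_pool_output_size(size, k, s))
```

For the sizes in the question all three formulas give 8, 8 and 32, so the shapes agree. A length-5 axis pooled with size 2 and stride 2 shows where they part ways: `'same'` gives 3 outputs, while `'valid'` and unpadded PyTorch give 2.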
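Unfortunately, PyTorch's `MaxPool3d` has no `'same'` padding option like TensorFlow's. Its `padding` argument is symmetric (and limited to at most half the kernel size), while TensorFlow's `'same'` padding can be asymmetric. So you will need to pad the tensor manually before passing it to the PyTorch `MaxPool3d` layer, using `-inf` as the fill value so that padded positions can never become the maximum.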
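To make the manual padding concrete, here is a minimal 1-D sketch in plain Python of what TensorFlow's `'same'` max pooling does along a single axis, assuming dilation 1. The helper names are hypothetical. It computes the total padding from `out = ceil(size / stride)`, places any odd extra element after the data, and pads with `-inf`. The same per-axis `(before, after)` amounts, listed starting from the last dimension, are what `torch.nn.functional.pad` expects for a 5-D tensor.

```python
import math
from typing import List, Sequence, Tuple


def tf_same_padding(size: int, k: int, s: int) -> Tuple[int, int]:
    """(before, after) padding TF uses for 'same' along one axis (dilation 1)."""
    out = math.ceil(size / s)
    total = max((out - 1) * s + k - size, 0)
    # TF puts the odd extra element of padding *after* the data
    return total // 2, total - total // 2


def max_pool_1d_same(xs: Sequence[float], k: int, s: int) -> List[float]:
    """1-D max pooling that mimics TF padding='same'."""
    before, after = tf_same_padding(len(xs), k, s)
    # Pad with -inf so padded cells never win the max; zero padding would
    # change the result whenever a window holds only negative values.
    padded = [-math.inf] * before + list(xs) + [-math.inf] * after
    return [max(padded[i:i + k]) for i in range(0, len(padded) - k + 1, s)]


print(max_pool_1d_same([1.0, -5.0, 3.0, -2.0, -7.0], k=2, s=2))  # [1.0, 3.0, -7.0]
print(tf_same_padding(256, 10, 32), tf_same_padding(64, 2, 2))   # (0, 0) (0, 0)
```

With zero padding instead of `-inf`, the last window above would return `0.0` rather than `-7.0`. For the axis sizes in the question the computed `'same'` padding is zero on every axis.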
Try this code:
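```python
import math
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import tensorflow.keras.layers as layers

kernel_size = (10, 10, 2)
strides = (32, 32, 2)
in_tensor = torch.randn(1, 1, 256, 256, 64)

# Keras 'same' max pooling; channels_first means (N, C, D, H, W)
tf_out = layers.MaxPooling3D(data_format='channels_first', pool_size=kernel_size,
                             strides=strides, padding='same')(in_tensor.detach().numpy())

# Reproduce TF 'same' padding: per axis, output = ceil(size / stride),
# and any odd leftover padding goes after the data (TF convention)
pads = []
for size, k, s in zip(in_tensor.shape[2:], kernel_size, strides):
    total = max((math.ceil(size / s) - 1) * s + k - size, 0)
    pads.append((total // 2, total - total // 2))
# F.pad lists pads starting from the LAST dim: (W_before, W_after, H_before, ...)
pad_args = tuple(p for dim_pad in reversed(pads) for p in dim_pad)
# Pad with -inf so padded cells can never be the maximum (TF ignores them too)
padded = F.pad(in_tensor, pad_args, value=float('-inf'))
pt_out = nn.MaxPool3d(kernel_size=kernel_size, stride=strides)(padded)

fig = plt.figure(figsize=(10, 5))
axs = fig.subplots(1, 2)
axs[0].matshow(pt_out[0, 0, :, :, 0].detach().numpy())
axs[0].set_title('PyTorch')
axs[1].matshow(tf_out.numpy()[0, 0, :, :, 0])
axs[1].set_title('TensorFlow')
plt.show()
```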
Output:
Upvotes: 1