Reputation: 165
Here is batch norm in TF:
model = BatchNormalization(momentum=0.15, axis=-1)(model)
And here is batch norm in Torch:
torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)
You can see there is one extra parameter, num_features, which is very annoying. Suppose I don't want affine in Torch; then batch norm in TF and Torch should behave the same. Is there a way to avoid specifying "num_features" in PyTorch's batch norm, just like in TensorFlow?
Upvotes: 1
Views: 549
Reputation: 3283
If you really hate specifying this parameter, you might want to look at lazy batch norm.
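For reference, torch.nn.LazyBatchNorm1d infers num_features from the shape of the first input it sees, so you never have to write it yourself. A minimal sketch (the batch and feature sizes here are just for illustration):

import torch
import torch.nn as nn

# num_features is inferred from the first forward pass
bn = nn.LazyBatchNorm1d(momentum=0.15)

x = torch.randn(8, 32)   # batch of 8 samples, 32 features
out = bn(x)              # after this call the module behaves like BatchNorm1d(32)
print(bn.num_features)   # should report 32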
Otherwise, you can specify num_features as whatever you like (None?), as long as BOTH affine and track_running_stats are False. If you look at the base class for the batch norm modules (in torch/nn/modules/batchnorm.py):
class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm"""

    _version = 2
    __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"]
    num_features: int
    eps: float
    momentum: float
    affine: bool
    track_running_stats: bool
    # WARNING: weight and bias purposely not defined here.
    # See https://github.com/pytorch/pytorch/issues/39670

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(_NormBase, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
            self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
            self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))
            self.running_mean: Optional[Tensor]
            self.running_var: Optional[Tensor]
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long,
                                              **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
            self.num_batches_tracked: Optional[Tensor]
        else:
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)
            self.register_buffer("num_batches_tracked", None)
        self.reset_parameters()
You can see that num_features is only used to allocate self.weight and self.bias when affine is True, and the running_mean and running_var buffers when track_running_stats is True.
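So with both flags set to False, the value you pass as num_features is stored but never used to allocate anything, and normalization is computed purely from the current batch. A quick sketch of that idea (works in recent PyTorch versions as far as I can tell; passing None is not officially supported, just tolerated):

import torch
import torch.nn as nn

# nothing per-feature is allocated, so num_features is never actually used
bn = nn.BatchNorm1d(None, affine=False, track_running_stats=False)

x = torch.randn(8, 32)
out = bn(x)                                  # normalized with the batch's own mean/var
print(out.mean(dim=0).abs().max())           # close to 0
print(out.std(dim=0, unbiased=False).mean()) # close to 1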
Upvotes: 1