Patrick Lee

Reputation: 165

Can we avoid specifying "num_features" in PyTorch's batch norm, as in TensorFlow?

Here is batch norm in TF:

model = BatchNormalization(momentum=0.15, axis=-1)(model)

And here is batch norm in Torch:

torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)

As you can see, there is one extra parameter, num_features, which is annoying to specify.

Suppose I don't want affine in Torch; then the batch norm layers in TF and Torch should behave the same. Is there a way to avoid specifying "num_features" in PyTorch's batch norm, just like in TensorFlow?

Upvotes: 1

Views: 549

Answers (1)

jhso

Reputation: 3283

If you really hate specifying this parameter, you might want to look at lazy batch norm (torch.nn.LazyBatchNorm1d), which infers num_features from the first forward pass.
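A minimal sketch of what that looks like, assuming a recent PyTorch version (the input shape here is just an arbitrary illustration):

```python
import torch

# LazyBatchNorm1d does not take num_features; it is inferred from
# the first batch that passes through the module.
bn = torch.nn.LazyBatchNorm1d(momentum=0.15)

x = torch.randn(8, 32)   # batch of 8 samples with 32 features
out = bn(x)              # num_features is now fixed at 32
print(bn.num_features)   # 32
```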

Otherwise, you can pass whatever you like as num_features (even None), as long as BOTH affine and track_running_stats are False. If you look at the base class for the batch norm modules in the PyTorch source:

class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm"""

    _version = 2
    __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"]
    num_features: int
    eps: float
    momentum: float
    affine: bool
    track_running_stats: bool
    # WARNING: weight and bias purposely not defined here.
    # See https://github.com/pytorch/pytorch/issues/39670

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(_NormBase, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
            self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
            self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))
            self.running_mean: Optional[Tensor]
            self.running_var: Optional[Tensor]
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long,
                                              **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
            self.num_batches_tracked: Optional[Tensor]
        else:
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)
            self.register_buffer("num_batches_tracked", None)
        self.reset_parameters()

You can see that num_features is used to create self.weight and self.bias when affine is True, and the running_mean and running_var buffers when track_running_stats is True. When both flags are False, the value is stored but never used to allocate anything.
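A small sketch illustrating this, assuming a recent PyTorch version (the shapes and the dummy num_features value are just examples):

```python
import torch

# With affine=False and track_running_stats=False, no weight/bias
# parameters or running-stat buffers are created, so the num_features
# argument is effectively ignored.
bn = torch.nn.BatchNorm1d(1, affine=False, track_running_stats=False)

x = torch.randn(8, 32)   # 32 features, despite num_features=1 above
out = bn(x)              # works: statistics are computed from the batch itself
print(out.shape)         # torch.Size([8, 32])
```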

Upvotes: 1
