Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit eec0fb4

Browse files
haojin2sxjscience
authored and committed
Group Normalization (#14959)
* GroupNorm * add to amp list * re-write forward
1 parent b887c06 commit eec0fb4

File tree

7 files changed

+706
-1
lines changed

7 files changed

+706
-1
lines changed

python/mxnet/contrib/amp/lists/symbol.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,7 @@
471471
'log_softmax',
472472
'InstanceNorm',
473473
'LayerNorm',
474+
'GroupNorm',
474475
'L2Normalization',
475476
'LRN',
476477
'SoftmaxActivation',

python/mxnet/gluon/nn/basic_layers.py

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
# pylint: disable= arguments-differ
2020
"""Basic neural network layers."""
2121
__all__ = ['Sequential', 'HybridSequential', 'Dense', 'Dropout', 'Embedding',
22-
'BatchNorm', 'InstanceNorm', 'LayerNorm', 'Flatten', 'Lambda', 'HybridLambda']
22+
'BatchNorm', 'InstanceNorm', 'LayerNorm', 'GroupNorm',
23+
'Flatten', 'Lambda', 'HybridLambda']
2324
import warnings
2425
import numpy as np
2526

@@ -616,6 +617,94 @@ def __repr__(self):
616617
for k, v in self._kwargs.items()]))
617618

618619

620+
class GroupNorm(HybridBlock):
    r"""
    Applies group normalization to the n-dimensional input array.

    The two leftmost axes of the input are interpreted as `batch` and
    `channel`; the channel axis is split into ``num_groups`` groups and each
    group is normalized over its own statistics:

    .. math::

      x = x.reshape((N, num_groups, C // num_groups, ...))
      axis = (2, ...)
      out = \frac{x - mean[x, axis]}{ \sqrt{Var[x, axis] + \epsilon}} * gamma + beta

    Parameters
    ----------
    num_groups: int, default 1
        Number of groups to separate the channel axis into.
    epsilon: float, default 1e-5
        Small float added to variance to avoid dividing by zero.
    center: bool, default True
        If True, add offset of `beta` to normalized tensor.
        If False, `beta` is ignored.
    scale: bool, default True
        If True, multiply by `gamma`. If False, `gamma` is not used.
    beta_initializer: str or `Initializer`, default 'zeros'
        Initializer for the beta weight.
    gamma_initializer: str or `Initializer`, default 'ones'
        Initializer for the gamma weight.


    Inputs:
        - **data**: input tensor with shape (N, C, ...).

    Outputs:
        - **out**: output tensor with the same shape as `data`.

    References
    ----------
        `Group Normalization
        <https://arxiv.org/pdf/1803.08494.pdf>`_

    Examples
    --------
    >>> # Input of shape (2, 3, 4)
    >>> x = mx.nd.array([[[ 0,  1,  2,  3],
                          [ 4,  5,  6,  7],
                          [ 8,  9, 10, 11]],
                         [[12, 13, 14, 15],
                          [16, 17, 18, 19],
                          [20, 21, 22, 23]]])
    >>> # Group normalization is calculated with the above formula
    >>> layer = GroupNorm()
    >>> layer.initialize(ctx=mx.cpu(0))
    >>> layer(x)
    [[[-1.5932543 -1.3035717 -1.0138891 -0.7242065]
      [-0.4345239 -0.1448413  0.1448413  0.4345239]
      [ 0.7242065  1.0138891  1.3035717  1.5932543]]
     [[-1.5932543 -1.3035717 -1.0138891 -0.7242065]
      [-0.4345239 -0.1448413  0.1448413  0.4345239]
      [ 0.7242065  1.0138891  1.3035717  1.5932543]]]
    <NDArray 2x3x4 @cpu(0)>
    """
    def __init__(self, num_groups=1, epsilon=1e-5, center=True, scale=True,
                 beta_initializer='zeros', gamma_initializer='ones',
                 prefix=None, params=None):
        super(GroupNorm, self).__init__(prefix=prefix, params=params)
        # Kept both as individual attributes (used by hybrid_forward) and as a
        # kwargs dict (used only by __repr__).
        self._kwargs = {'eps': epsilon, 'num_groups': num_groups,
                        'center': center, 'scale': scale}
        self._num_groups = num_groups
        self._epsilon = epsilon
        self._center = center
        self._scale = scale
        # One scale/shift scalar per group; gradients are suppressed when the
        # corresponding scale/center flag is off.
        self.gamma = self.params.get('gamma',
                                     grad_req='write' if scale else 'null',
                                     shape=(num_groups,),
                                     init=gamma_initializer,
                                     allow_deferred_init=True)
        self.beta = self.params.get('beta',
                                    grad_req='write' if center else 'null',
                                    shape=(num_groups,),
                                    init=beta_initializer,
                                    allow_deferred_init=True)

    def hybrid_forward(self, F, data, gamma, beta):
        # Delegate the actual normalization to the backend GroupNorm operator.
        return F.GroupNorm(data, gamma=gamma, beta=beta,
                           num_groups=self._num_groups, eps=self._epsilon)

    def __repr__(self):
        # Render as GroupNorm(eps=..., num_groups=..., center=..., scale=...).
        content = ', '.join('{}={!r}'.format(k, v)
                            for k, v in self._kwargs.items())
        return '{}({})'.format(self.__class__.__name__, content)
706+
707+
619708
class Lambda(Block):
620709
r"""Wraps an operator or an expression as a Block object.
621710

0 commit comments

Comments (0)