19 | 19 | # pylint: disable= arguments-differ |
20 | 20 | """Basic neural network layers.""" |
21 | 21 | __all__ = ['Sequential', 'HybridSequential', 'Dense', 'Dropout', 'Embedding', |
22 | | - 'BatchNorm', 'InstanceNorm', 'LayerNorm', 'Flatten', 'Lambda', 'HybridLambda'] |
| 22 | + 'BatchNorm', 'InstanceNorm', 'LayerNorm', 'GroupNorm', |
| 23 | + 'Flatten', 'Lambda', 'HybridLambda'] |
23 | 24 | import warnings |
24 | 25 | import numpy as np |
25 | 26 |
@@ -616,6 +617,94 @@ def __repr__(self): |
616 | 617 | for k, v in self._kwargs.items()])) |
617 | 618 |
618 | 619 |
| 620 | +class GroupNorm(HybridBlock): |
| 621 | + r""" |
| 622 | + Applies group normalization to the n-dimensional input array. |
| 623 | +    This operator takes an n-dimensional input array whose two leftmost axes are
| 624 | +    `batch` and `channel`, respectively:
| 625 | +
| 626 | + .. math:: |
| 627 | +
| 628 | + x = x.reshape((N, num_groups, C // num_groups, ...)) |
| 629 | + axis = (2, ...) |
| 630 | +        out = \frac{x - mean[x, axis]}{\sqrt{Var[x, axis] + \epsilon}} * gamma + beta
| 631 | +
| 632 | + Parameters |
| 633 | + ---------- |
| 634 | + num_groups: int, default 1 |
| 635 | + Number of groups to separate the channel axis into. |
| 636 | + epsilon: float, default 1e-5 |
| 637 | + Small float added to variance to avoid dividing by zero. |
| 638 | + center: bool, default True |
| 639 | + If True, add offset of `beta` to normalized tensor. |
| 640 | + If False, `beta` is ignored. |
| 641 | + scale: bool, default True |
| 642 | + If True, multiply by `gamma`. If False, `gamma` is not used. |
| 643 | + beta_initializer: str or `Initializer`, default 'zeros' |
| 644 | + Initializer for the beta weight. |
| 645 | + gamma_initializer: str or `Initializer`, default 'ones' |
| 646 | + Initializer for the gamma weight. |
| 647 | +
| 648 | +
| 649 | + Inputs: |
| 650 | + - **data**: input tensor with shape (N, C, ...). |
| 651 | +
| 652 | + Outputs: |
| 653 | + - **out**: output tensor with the same shape as `data`. |
| 654 | +
| 655 | + References |
| 656 | + ---------- |
| 657 | + `Group Normalization |
| 658 | + <https://arxiv.org/pdf/1803.08494.pdf>`_ |
| 659 | +
| 660 | + Examples |
| 661 | + -------- |
| 662 | + >>> # Input of shape (2, 3, 4) |
| 663 | + >>> x = mx.nd.array([[[ 0, 1, 2, 3], |
| 664 | + [ 4, 5, 6, 7], |
| 665 | + [ 8, 9, 10, 11]], |
| 666 | + [[12, 13, 14, 15], |
| 667 | + [16, 17, 18, 19], |
| 668 | + [20, 21, 22, 23]]]) |
| 669 | + >>> # Group normalization is calculated with the above formula |
| 670 | + >>> layer = GroupNorm() |
| 671 | + >>> layer.initialize(ctx=mx.cpu(0)) |
| 672 | + >>> layer(x) |
| 673 | + [[[-1.5932543 -1.3035717 -1.0138891 -0.7242065] |
| 674 | + [-0.4345239 -0.1448413 0.1448413 0.4345239] |
| 675 | + [ 0.7242065 1.0138891 1.3035717 1.5932543]] |
| 676 | + [[-1.5932543 -1.3035717 -1.0138891 -0.7242065] |
| 677 | + [-0.4345239 -0.1448413 0.1448413 0.4345239] |
| 678 | + [ 0.7242065 1.0138891 1.3035717 1.5932543]]] |
| 679 | + <NDArray 2x3x4 @cpu(0)> |
| 680 | + """ |
| 681 | + def __init__(self, num_groups=1, epsilon=1e-5, center=True, scale=True, |
| 682 | + beta_initializer='zeros', gamma_initializer='ones', |
| 683 | + prefix=None, params=None): |
| 684 | + super(GroupNorm, self).__init__(prefix=prefix, params=params) |
| 685 | + self._kwargs = {'eps': epsilon, 'num_groups': num_groups, 'center': center, 'scale': scale} |
| 686 | + self._num_groups = num_groups |
| 687 | + self._epsilon = epsilon |
| 688 | + self._center = center |
| 689 | + self._scale = scale |
| 690 | + self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null', |
| 691 | + shape=(num_groups,), init=gamma_initializer, |
| 692 | + allow_deferred_init=True) |
| 693 | + self.beta = self.params.get('beta', grad_req='write' if center else 'null', |
| 694 | + shape=(num_groups,), init=beta_initializer, |
| 695 | + allow_deferred_init=True) |
| 696 | + |
| 697 | + def hybrid_forward(self, F, data, gamma, beta): |
| 698 | + norm_data = F.GroupNorm(data, gamma=gamma, beta=beta, num_groups=self._num_groups, eps=self._epsilon) |
| 699 | + return norm_data |
| 700 | + |
| 701 | + def __repr__(self): |
| 702 | + s = '{name}({content})' |
| 703 | + return s.format(name=self.__class__.__name__, |
| 704 | + content=', '.join(['='.join([k, v.__repr__()]) |
| 705 | + for k, v in self._kwargs.items()])) |
| 706 | + |
| 707 | + |
619 | 708 | class Lambda(Block): |
620 | 709 | r"""Wraps an operator or an expression as a Block object. |
621 | 710 |
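
To sanity-check the documented formula, here is a minimal NumPy reference sketch of the same computation; `group_norm_ref` is a hypothetical helper written for this note, not part of the library. With the defaults (`num_groups=1`, `gamma` of ones, `beta` of zeros) it reproduces the docstring example above:

```python
import numpy as np

def group_norm_ref(x, num_groups=1, gamma=None, beta=None, eps=1e-5):
    """Hypothetical NumPy reference for the docstring formula."""
    n, c = x.shape[:2]
    assert c % num_groups == 0, "channel count must split evenly into groups"
    # Fold each group and all trailing axes into one axis so that the mean and
    # variance are taken per (sample, group) pair, as in the docstring formula.
    xg = x.reshape(n, num_groups, -1)
    mean = xg.mean(axis=2, keepdims=True)
    var = xg.var(axis=2, keepdims=True)
    out = (xg - mean) / np.sqrt(var + eps)
    # gamma and beta scale/shift whole groups -- note they have shape
    # (num_groups,), matching the parameter shapes in __init__ above.
    if gamma is not None:
        out = out * gamma.reshape(1, num_groups, 1)
    if beta is not None:
        out = out + beta.reshape(1, num_groups, 1)
    return out.reshape(x.shape)

x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)  # same input as the docstring
print(group_norm_ref(x))  # first entry is about -1.5932543, matching the example
```

One design point this makes visible: the affine parameters act per group rather than per channel, which is why `gamma` and `beta` are declared with shape `(num_groups,)`.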
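
And a short usage sketch, assuming this lands as `mxnet.gluon.nn.GroupNorm` per the `__all__` change above: because the layer is a `HybridBlock`, it composes and hybridizes like the other normalization layers in this module:

```python
import mxnet as mx
from mxnet.gluon import nn

net = nn.HybridSequential()
net.add(nn.Conv2D(8, kernel_size=3, padding=1),
        nn.GroupNorm(num_groups=4),  # 8 channels split into 4 groups of 2
        nn.Activation('relu'))
net.initialize(ctx=mx.cpu(0))
net.hybridize()  # works because GroupNorm implements hybrid_forward

y = net(mx.nd.random.uniform(shape=(2, 3, 16, 16)))
print(y.shape)  # (2, 8, 16, 16)
```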