[MXNET-58] Layer Normalization in C++ #10029
Conversation
Here's the new doc of InstanceNorm: http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-10029/4/api/python/gluon/nn.html#mxnet.gluon.nn.InstanceNorm @zhanghang1989

@sxjscience fantastic, thank you! We will definitely try this as soon as it's available!

Does anyone have time to review it? The doc page of the latest build is at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-10029/7/index.html

The docs look good to me 👍
using namespace mshadow;
CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
const TShape &dshape = in_shape->at(layernorm::kData);
int axis = param.axis;
def test_layer_norm():
    for dtype in [np.float16, np.float32, np.float64]:
        check_layer_normalization((10, 12, 5), -1, 1E-3)
Is any axis allowed?
Can you check all possibilities (even if they theoretically overlap)? -2, -1, 0, 1, 2 (for 3D).
How about 1D and 2D? Are those relevant for this operator?
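For completeness, such a sweep can be written directly against the test helper; the following is only a sketch, assuming the check_layer_normalization(in_shape, axis, eps) helper defined in this PR's test file:

# Sketch: sweep every valid axis (negative and positive, even though
# they alias the same dimensions) for 1D, 2D, and 3D inputs.
# check_layer_normalization(in_shape, axis, eps) is the helper used
# in this PR's test suite.
for in_shape in [(10,), (10, 12), (10, 12, 5)]:
    ndim = len(in_shape)
    for axis in range(-ndim, ndim):
        check_layer_normalization(in_shape, axis, 1E-3)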
check_l2_normalization((nbatch, nchannel, height, width), mode)
def npy_layer_norm(data, gamma, beta, axis=1, eps=1E-5):
Can this be a nested function in check_layer_normalization?
exe.arg_dict['beta'][:] = beta
out_nd = exe.forward()[0]
out = npy_layer_norm(data, gamma, beta, axis, eps)
assert_allclose(out, out_nd.asnumpy(), 1E-4, 1E-4)
Is this the correctness test?
Yes, it compares the operator's output with a NumPy reference implementation.
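For reference, a self-contained NumPy version along the lines of npy_layer_norm might look like the sketch below (illustrative, not the exact helper from the test file):

import numpy as np

def npy_layer_norm(data, gamma, beta, axis=1, eps=1E-5):
    # Normalize over `axis`, then broadcast the per-channel scale
    # (gamma) and shift (beta) along that axis.
    if axis < 0:
        axis += data.ndim
    broadcast_shape = [1] * data.ndim
    broadcast_shape[axis] = data.shape[axis]
    mean = data.mean(axis=axis, keepdims=True)
    std = np.sqrt(data.var(axis=axis, keepdims=True) + eps)
    return ((data - mean) / std * gamma.reshape(broadcast_shape)
            + beta.reshape(broadcast_shape))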
check_layer_normalization((10, 12, 5), -1, 1E-3)
check_layer_normalization((10, 12, 5), 0, 1E-3)
check_layer_normalization((10, 12, 5), 1, 1E-3)
for in_shape in [(10, 6, 5), (5, 5), (2, 3, 3, 3)]:
             beta_initializer='zeros', gamma_initializer='ones',
             in_channels=0, prefix=None, params=None):
    super(LayerNorm, self).__init__(prefix=prefix, params=params)
    self._kwargs = {'eps': epsilon, 'axis': axis}
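For context, using the Gluon layer this constructor belongs to is straightforward; a minimal sketch (shapes illustrative):

import mxnet as mx
from mxnet.gluon import nn

# Normalize the last axis of a (batch, seq_len, hidden) tensor with
# the gluon.nn.LayerNorm added in this PR.
net = nn.LayerNorm(axis=-1, epsilon=1e-5)
net.initialize()
x = mx.nd.random.normal(shape=(2, 5, 10))
y = net(x)      # same shape as x, normalized along the last axis
print(y.shape)  # (2, 5, 10)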
src/operator/nn/layer_norm-inl.h
DMLC_DECLARE_FIELD(axis).set_default(-1)
.describe("The axis to perform layer normalization. "
          "Usually, this should be the axis of the channel dimension. "
          "Negative values mean indexing from right to left.");
src/operator/nn/layer_norm-inl.h
DMLC_DECLARE_FIELD(eps).set_default(1e-5f)
.describe("An `epsilon` parameter to prevent division by 0.");
DMLC_DECLARE_FIELD(output_mean_var).set_default(false)
.describe("Output the mean and std calculated along the given axis.");
Do you have any benchmarks regarding statement 1?
@marcoabreu Yes, here is the benchmark result. My reference implementation is the following LayerNorm that is implemented by stacking broadcasting/reducing operators:

import mxnet as mx
from mxnet.gluon import HybridBlock

class LayerNormStackSmallOp(HybridBlock):
    """Applies layer normalization to the n-dimensional input array.
    Stack bcast/reduce
    """
    def __init__(self, axis=1, epsilon=1e-5, center=True, scale=True,
                 beta_initializer='zeros', gamma_initializer='ones',
                 in_channels=0, prefix=None, params=None):
        super(LayerNormStackSmallOp, self).__init__(prefix=prefix, params=params)
        self._kwargs = {'eps': epsilon, 'axis': axis}
        self._axis = axis
        self._epsilon = epsilon
        self._center = center
        self._scale = scale
        assert in_channels != 0, "in_channels == 0 is currently not supported"
        # gamma scales and beta shifts; the center/scale flags gate each
        # parameter's gradient via grad_req.
        self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
                                     shape=(in_channels,), init=gamma_initializer,
                                     allow_deferred_init=True)
        self.beta = self.params.get('beta', grad_req='write' if center else 'null',
                                    shape=(in_channels,), init=beta_initializer,
                                    allow_deferred_init=True)

    def moments(self, F, data):
        mean = F.mean(data=data, axis=self._axis, keepdims=True)
        var = F.mean(F.square(F.broadcast_minus(data, mean)),
                     axis=self._axis, keepdims=True)
        return mean, var

    def hybrid_forward(self, F, data, gamma, beta):
        if not self._center and not self._scale:
            return data
        mean, var = self.moments(F, data)
        norm_data = F.broadcast_minus(data, mean)
        # F.rsqrt (rather than mx.sym.rsqrt) keeps the block usable in
        # both imperative (NDArray) and symbolic mode.
        norm_data = F.broadcast_mul(norm_data, F.rsqrt(var + self._epsilon))
        norm_data = F.broadcast_mul(norm_data, gamma)
        norm_data = F.broadcast_add(norm_data, beta)
        return norm_data

I ran the layer normalization on data with shape=(128, 1024, 100), axis=-1.
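The timing setup itself is not quoted in the thread; a minimal harness along the following lines could reproduce the comparison (a sketch with a hypothetical bench helper; LayerNormStackSmallOp is the block above, nn.LayerNorm the native layer this PR adds):

import time
import mxnet as mx
from mxnet.gluon import nn

def bench(block, x, n_repeats=50):
    block.initialize()
    block.hybridize()
    block(x).wait_to_read()   # warm-up; triggers deferred init
    start = time.time()
    for _ in range(n_repeats):
        y = block(x)
    mx.nd.waitall()           # wait for MXNet's async engine to finish
    return (time.time() - start) / n_repeats

x = mx.nd.random.normal(shape=(128, 1024, 100))
print('stacked:', bench(LayerNormStackSmallOp(axis=-1, in_channels=100), x))
print('native :', bench(nn.LayerNorm(axis=-1, in_channels=100), x))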
Great numbers, thanks a lot. Good job!
* add layer_norm + fix batch_norm doc
* add test
* add layer normalization in Gluon
* update
* fix __repr__ + lint
* fix doc
* fix threshold
* fix doc
* fix bug
* enable inplace + fix test
* try to fix test
* fix doc
Is there a way to infer the in_channels? I am implementing a Scale layer, which has the same problem.

assert in_channels != 0, "in_channels == 0 is currently not supported"
Currently no. I'll try to support it soon.

@marvis, would you submit an issue describing the problem with some examples? I've rechecked the code and the LayerNorm layer should support in_channels=0.
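A quick way to confirm that: with the default in_channels=0, the gamma/beta shapes are deferred and inferred from the first input. A sketch (shapes illustrative):

import mxnet as mx
from mxnet.gluon import nn

net = nn.LayerNorm(axis=-1)  # in_channels left at its default of 0
net.initialize()
x = mx.nd.ones((4, 8, 16))
y = net(x)                   # deferred initialization happens here
print(net.gamma.shape)       # (16,), inferred from the last axis of x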
Description

Checklist

Essentials
* Passed code style checking (make lint)

Changes

Comments
We can improve the speed further by fusing the operators. This is left as future work.