Skip to content

Commit 75e2ae5

Browse files
committed
Limit MKLDNN ops being used.
1 parent 53eec60 commit 75e2ae5

File tree

5 files changed

+52
-14
lines changed

5 files changed

+52
-14
lines changed

src/operator/nn/activation.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,10 @@ inline static bool ActivationStorageType(const nnvm::NodeAttrs& attrs,
9898
CHECK_EQ(out_attrs->size(), 1);
9999
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
100100
#if MXNET_USE_MKLDNN == 1
101-
if (dev_mask == mshadow::cpu::kDevMask && SupportMKLDNNAct(param)) {
101+
if (dev_mask == mshadow::cpu::kDevMask && SupportMKLDNNAct(param)
102+
// There is no reason to use MKLDNN activation if the input isn't in
103+
// MKLDNN format.
104+
&& in_attrs->at(0) == kMKLDNNStorage) {
102105
*dispatch_mode = DispatchMode::kFComputeEx;
103106
(*out_attrs)[0] = kMKLDNNStorage;
104107
return true;
@@ -121,7 +124,10 @@ inline static bool backward_ActStorageType(const nnvm::NodeAttrs& attrs,
121124
CHECK_EQ(out_attrs->size(), 1U);
122125
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
123126
#if MXNET_USE_MKLDNN == 1
124-
if (dev_mask == mshadow::cpu::kDevMask && SupportMKLDNNAct(param)) {
127+
if (dev_mask == mshadow::cpu::kDevMask && SupportMKLDNNAct(param)
128+
// There is no reason to use MKLDNN activation if the input isn't in
129+
// MKLDNN format.
130+
&& in_attrs->at(0) == kMKLDNNStorage) {
125131
*dispatch_mode = DispatchMode::kFComputeEx;
126132
(*out_attrs)[0] = kMKLDNNStorage;
127133
return true;

src/operator/nn/convolution.cc

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -293,17 +293,22 @@ static bool ConvolutionType(const nnvm::NodeAttrs& attrs,
293293
}
294294

295295
inline static bool ConvStorageType(const nnvm::NodeAttrs& attrs,
296-
const int dev_mask,
297-
DispatchMode* dispatch_mode,
298-
std::vector<int> *in_attrs,
299-
std::vector<int> *out_attrs) {
296+
const int dev_mask,
297+
DispatchMode* dispatch_mode,
298+
std::vector<int> *in_attrs,
299+
std::vector<int> *out_attrs) {
300300
const ConvolutionParam& param = nnvm::get<ConvolutionParam>(attrs.parsed);
301301
uint32_t in_expected = param.no_bias ? 2 : 3;
302302
CHECK_EQ(in_attrs->size(), in_expected);
303303
CHECK_EQ(out_attrs->size(), 1);
304304

305305
#if MXNET_USE_MKLDNN == 1
306-
if (dev_mask == mshadow::cpu::kDevMask) {
306+
if (dev_mask == mshadow::cpu::kDevMask
307+
// We should allow MKLDNN conv to apply to the default storage as well.
308+
// Even with format conversion, MKLDNN conv should still be faster than
309+
// the native implementation.
310+
&& (in_attrs->at(0) == kMKLDNNStorage
311+
|| in_attrs->at(0) == kDefaultStorage)) {
307312
*dispatch_mode = DispatchMode::kFComputeEx;
308313
(*out_attrs)[0] = kMKLDNNStorage;
309314
return true;
@@ -326,7 +331,12 @@ inline static bool backward_ConvStorageType(const nnvm::NodeAttrs& attrs,
326331
CHECK_EQ(out_attrs->size(), out_expected);
327332

328333
#if MXNET_USE_MKLDNN == 1
329-
if (dev_mask == mshadow::cpu::kDevMask) {
334+
if (dev_mask == mshadow::cpu::kDevMask
335+
// We should allow MKLDNN conv to apply to the default storage as well.
336+
// Even with format conversion, MKLDNN conv should still be faster than
337+
// the native implementation.
338+
&& (in_attrs->at(0) == kMKLDNNStorage
339+
|| in_attrs->at(0) == kDefaultStorage)) {
330340
*dispatch_mode = DispatchMode::kFComputeEx;
331341
for (size_t i = 0; i < out_attrs->size(); i++)
332342
(*out_attrs)[i] = kMKLDNNStorage;

src/operator/nn/deconvolution.cc

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,12 @@ inline static bool DeconvStorageType(const nnvm::NodeAttrs& attrs,
267267
CHECK_EQ(out_attrs->size(), 1);
268268

269269
#if MXNET_USE_MKLDNN == 1
270-
if (dev_mask == mshadow::cpu::kDevMask) {
270+
if (dev_mask == mshadow::cpu::kDevMask
271+
// We should allow MKLDNN deconv to apply to the default storage as well.
272+
// Even with format conversion, MKLDNN deconv should still be faster than
273+
// the native implementation.
274+
&& (in_attrs->at(0) == kMKLDNNStorage
275+
|| in_attrs->at(0) == kDefaultStorage)) {
271276
*dispatch_mode = DispatchMode::kFComputeEx;
272277
(*out_attrs)[0] = kMKLDNNStorage;
273278
return true;
@@ -293,7 +298,12 @@ inline static bool backward_DeconvStorageType(const nnvm::NodeAttrs& attrs,
293298
CHECK_EQ(out_attrs->size(), out_expected);
294299

295300
#if MXNET_USE_MKLDNN == 1
296-
if (dev_mask == mshadow::cpu::kDevMask) {
301+
if (dev_mask == mshadow::cpu::kDevMask
302+
// We should allow MKLDNN deconv to apply to the default storage as well.
303+
// Even with format conversion, MKLDNN deconv should still be faster than
304+
// the native implementation.
305+
&& (in_attrs->at(0) == kMKLDNNStorage
306+
|| in_attrs->at(0) == kDefaultStorage)) {
297307
*dispatch_mode = DispatchMode::kFComputeEx;
298308
for (size_t i = 0; i < out_attrs->size(); i++)
299309
(*out_attrs)[i] = kMKLDNNStorage;

src/operator/nn/fully_connected.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,10 @@ inline static bool FCStorageType(const nnvm::NodeAttrs& attrs,
138138
CHECK_EQ(out_attrs->size(), 1);
139139

140140
#if MXNET_USE_MKLDNN == 1
141-
if (dev_mask == mshadow::cpu::kDevMask) {
141+
// The native implementation uses BLAS. It shouldn't be slower than MKLDNN
142+
// FC. If the input data has the default format, there is format conversion
143+
// overhead as well.
144+
if (dev_mask == mshadow::cpu::kDevMask && in_attrs->at(0) == kMKLDNNStorage) {
142145
*dispatch_mode = DispatchMode::kFComputeEx;
143146
(*out_attrs)[0] = kMKLDNNStorage;
144147
return true;
@@ -160,7 +163,10 @@ inline static bool backward_FCStorageType(const nnvm::NodeAttrs& attrs,
160163
CHECK_EQ(out_attrs->size(), out_expected);
161164

162165
#if MXNET_USE_MKLDNN == 1
163-
if (dev_mask == mshadow::cpu::kDevMask) {
166+
// The native implementation uses BLAS. It shouldn't be slower than MKLDNN
167+
// FC. If the input data has the default format, there is format conversion
168+
// overhead as well.
169+
if (dev_mask == mshadow::cpu::kDevMask && in_attrs->at(0) == kMKLDNNStorage) {
164170
*dispatch_mode = DispatchMode::kFComputeEx;
165171
for (size_t i = 0; i < out_attrs->size(); i++)
166172
(*out_attrs)[i] = kMKLDNNStorage;

src/operator/nn/pooling.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,10 @@ inline static bool PoolingStorageType(const nnvm::NodeAttrs &attrs,
300300

301301
#if MXNET_USE_MKLDNN == 1
302302
const PoolingParam &param = nnvm::get<PoolingParam>(attrs.parsed);
303-
if (dev_mask == mshadow::cpu::kDevMask && SupportMKLDNNPooling(param)) {
303+
if (dev_mask == mshadow::cpu::kDevMask && SupportMKLDNNPooling(param)
304+
// There is no reason to use MKLDNN pooling if the input isn't in
305+
// MKLDNN format.
306+
&& in_attrs->at(0) == kMKLDNNStorage) {
304307
*dispatch_mode = DispatchMode::kFComputeEx;
305308
for (size_t i = 0; i < out_attrs->size(); i++)
306309
(*out_attrs)[i] = kMKLDNNStorage;
@@ -322,7 +325,10 @@ inline static bool backward_PoolingStorageType(const nnvm::NodeAttrs &attrs,
322325

323326
#if MXNET_USE_MKLDNN == 1
324327
const PoolingParam &param = nnvm::get<PoolingParam>(attrs.parsed);
325-
if (dev_mask == mshadow::cpu::kDevMask && SupportMKLDNNPooling(param)) {
328+
if (dev_mask == mshadow::cpu::kDevMask && SupportMKLDNNPooling(param)
329+
// There is no reason to use MKLDNN pooling if the input isn't in
330+
// MKLDNN format.
331+
&& in_attrs->at(0) == kMKLDNNStorage) {
326332
*dispatch_mode = DispatchMode::kFComputeEx;
327333
for (size_t i = 0; i < out_attrs->size(); i++)
328334
(*out_attrs)[i] = kMKLDNNStorage;

0 commit comments

Comments (0)