@@ -293,17 +293,22 @@ static bool ConvolutionType(const nnvm::NodeAttrs& attrs,
293293}
294294
295295inline static bool ConvStorageType (const nnvm::NodeAttrs& attrs,
296- const int dev_mask,
297- DispatchMode* dispatch_mode,
298- std::vector<int > *in_attrs,
299- std::vector<int > *out_attrs) {
296+ const int dev_mask,
297+ DispatchMode* dispatch_mode,
298+ std::vector<int > *in_attrs,
299+ std::vector<int > *out_attrs) {
300300 const ConvolutionParam& param = nnvm::get<ConvolutionParam>(attrs.parsed );
301301 uint32_t in_expected = param.no_bias ? 2 : 3 ;
302302 CHECK_EQ (in_attrs->size (), in_expected);
303303 CHECK_EQ (out_attrs->size (), 1 );
304304
305305#if MXNET_USE_MKLDNN == 1
306- if (dev_mask == mshadow::cpu::kDevMask ) {
306+ if (dev_mask == mshadow::cpu::kDevMask
307+ // We should allow MKLDNN conv to apply to the default storage as well.
308+ // Even with format conversion, MKLDNN conv should still be faster than
309+ // the native implementation.
310+ && (in_attrs->at (0 ) == kMKLDNNStorage
311+ || in_attrs->at (0 ) == kDefaultStorage )) {
307312 *dispatch_mode = DispatchMode::kFComputeEx ;
308313 (*out_attrs)[0 ] = kMKLDNNStorage ;
309314 return true ;
@@ -326,7 +331,12 @@ inline static bool backward_ConvStorageType(const nnvm::NodeAttrs& attrs,
326331 CHECK_EQ (out_attrs->size (), out_expected);
327332
328333#if MXNET_USE_MKLDNN == 1
329- if (dev_mask == mshadow::cpu::kDevMask ) {
334+ if (dev_mask == mshadow::cpu::kDevMask
335+ // We should allow MKLDNN conv to apply to the default storage as well.
336+ // Even with format conversion, MKLDNN conv should still be faster than
337+ // the native implementation.
338+ && (in_attrs->at (0 ) == kMKLDNNStorage
339+ || in_attrs->at (0 ) == kDefaultStorage )) {
330340 *dispatch_mode = DispatchMode::kFComputeEx ;
331341 for (size_t i = 0 ; i < out_attrs->size (); i++)
332342 (*out_attrs)[i] = kMKLDNNStorage ;
0 commit comments