2525#ifndef MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_
2626#define MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_
2727
28+ #include <algorithm>
2829#include <vector>
2930#include <string>
3031
@@ -381,11 +382,13 @@ void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs,
381382}
382383
// NOTE(review): The text below is a pasted unified-diff (commit) view, not plain
// source. The leading fused digits are old/new diff line numbers, `+`/`-` mark
// added/removed lines, and `@@ ... @@` is a hunk header. Comments added here
// describe the C++ code carried inside the diff.
//
// NumpyBinaryBackwardUseIn (renamed in this commit from MixedBinaryBackwardUseIn,
// see the -/+ pair below): backward pass (grad w.r.t. both operands) for a
// mixed-precision NumPy-style binary broadcast op.
//   inputs  = {ograd, lhs, rhs}   (3, checked below)
//   outputs = {lgrad, rgrad}      (2, checked below)
// LOP / ROP compute the partial derivatives w.r.t. lhs and rhs respectively.
383384template <typename xpu, typename LOP, typename ROP>
384- void MixedBinaryBackwardUseIn (const nnvm::NodeAttrs& attrs,
385+ void NumpyBinaryBackwardUseIn (const nnvm::NodeAttrs& attrs,
385386                              const OpContext& ctx,
386387                              const std::vector<TBlob>& inputs,
387388                              const std::vector<OpReqType>& req,
388389                              const std::vector<TBlob>& outputs) {
390+   using namespace mshadow ;
391+   using namespace mxnet_op ;
389392  CHECK_EQ (inputs.size (), 3U );
390393  CHECK_EQ (outputs.size (), 2U );
391394
// NOTE(review): the diff omits the context lines between the two hunks
// (original file lines 395-398). They presumably declare `lhs`/`rhs` from
// inputs[1]/inputs[2] and open the early-return guard whose tail (`return; }`)
// is visible just below — TODO confirm against the full file.
@@ -396,7 +399,104 @@ void MixedBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
396399    return ;
397400  }
398401
// This `-` line is the old body: the previous revision only reported an error
// for mixed types. The `+` lines replacing it implement the real backward pass.
399-   PrintErrorMessage (attrs.op ->name , lhs.type_flag_ , rhs.type_flag_ );
402+   const TBlob& ograd = inputs[0 ];
403+   const TBlob& lgrad = outputs[0 ];
404+   const TBlob& rgrad = outputs[1 ];
405+
406+   if (common::is_float (lhs.type_flag_ ) || common::is_float (rhs.type_flag_ )) {
407+     // If any of the inputs is a float, it's the same type as the output
408+     // So 2 of the 3 tensors have the same data type
409+     Stream<xpu> *s = ctx.get_stream <xpu>();
410+     mxnet::TShape new_lshape, new_rshape, new_oshape;
411+     using namespace broadcast ;
// need_bc != 0 means lgrad/rgrad shapes differ from ograd's, so the broadcast
// (reduce) backward path is required instead of the plain elementwise one.
412+     const bool need_bc = BinaryBroadcastShapeCompact (lgrad.shape_ , rgrad.shape_ , ograd.shape_ ,
413+                                                      &new_lshape, &new_rshape, &new_oshape) != 0 ;
414+
415+     // Prepare all the temporary memory
416+     size_t workspace_size_l = 0 , workspace_size_r = 0 ;
417+     TBlob temp_tblob;  // The TBlob for casted input data
418+     TBlob temp_igrad;  // The TBlob for casted grad results
// The "odd one out" is whichever side's grad dtype differs from ograd's; both
// temp buffers are sized after that side's tensor.
419+     size_t tensor_size = (lgrad.type_flag_ != ograd.type_flag_ ) ? lgrad.Size () : rgrad.Size ();
420+     Tensor<xpu, 1 , char > workspace;
421+
422+     MSHADOW_TYPE_SWITCH (ograd.type_flag_ , OType, {
423+       BROADCAST_NDIM_SWITCH (new_oshape.ndim (), ndim, {
424+         workspace_size_l = ReduceWorkspaceSize<ndim, OType>(
425+             s, new_lshape, req[0 ], new_oshape, new_lshape, new_rshape);
426+         workspace_size_r = ReduceWorkspaceSize<ndim, OType>(
427+             s, new_rshape, req[1 ], new_oshape, new_lshape, new_rshape);
428+       });
// One reduce workspace is shared by both sides, so take the larger requirement.
429+       size_t workspace_size = std::max (workspace_size_l, workspace_size_r);
430+       size_t cast_tensor_size = tensor_size * sizeof (OType);
431+       // Allocate the temporary memories now
// Single allocation laid out as: [temp_tblob | temp_igrad | reduce workspace].
432+       Tensor<xpu, 1 , char > temp_space =
433+           ctx.requested [0 ].get_space_typed <xpu, 1 , char >(
434+               Shape1 (workspace_size + cast_tensor_size * 2 ), s);
435+       // Tensor for temp_tblob
436+       Tensor<xpu, 1 , OType> temp_tblob_tensor (
437+           reinterpret_cast <OType*>(temp_space.dptr_ ),
438+           Shape1 (tensor_size), s);
439+       // Tensor for temp_igrad
440+       Tensor<xpu, 1 , OType> temp_igrad_tensor (
441+           reinterpret_cast <OType*>(temp_space.dptr_ ) + tensor_size,
442+           Shape1 (tensor_size), s);
443+       temp_tblob =
444+           TBlob (temp_tblob_tensor)
445+               .reshape (((lgrad.type_flag_ != ograd.type_flag_ ) ? lhs.shape_ : rhs.shape_ ));
446+       temp_igrad =
447+           TBlob (temp_igrad_tensor)
448+               .reshape (((lgrad.type_flag_ != ograd.type_flag_ ) ? lhs.shape_ : rhs.shape_ ));
// Zero-fill the temp grad buffer so accumulation below starts from a clean slate.
449+       if (temp_igrad.Size () != 0 ) {
450+         Kernel<set_zero, xpu>::Launch (s, temp_igrad.Size (), temp_igrad.dptr <OType>());
451+       }
452+       workspace =
453+           Tensor<xpu, 1 , char >(temp_space.dptr_ + 2 * cast_tensor_size, Shape1 (workspace_size), s);
454+     });
455+     // Cast the input that does not have consistent type to temp_tblob
456+     CastCompute<xpu>(
457+         attrs, ctx, {((lgrad.type_flag_ != ograd.type_flag_ ) ? lhs : rhs)}, {kWriteTo }, {temp_tblob});
// No broadcast needed: plain elementwise backward. The mismatched side's grad
// is written into temp_igrad (in ograd's dtype); the other side writes directly
// to its real output with the caller's req.
458+     if (!need_bc) {
459+       if (lhs.type_flag_ != ograd.type_flag_ ) {
460+         ElemwiseBinaryOp::BackwardUseIn<xpu, LOP, ROP>(
461+             attrs, ctx, {ograd, temp_tblob, rhs}, {kWriteTo , req[1 ]}, {temp_igrad, rgrad});
462+       } else {
463+         ElemwiseBinaryOp::BackwardUseIn<xpu, LOP, ROP>(
464+             attrs, ctx, {ograd, lhs, temp_tblob}, {req[0 ], kWriteTo }, {lgrad, temp_igrad});
465+       }
// Broadcast path: same substitution of temp_tblob/temp_igrad, but through the
// workspace-based broadcast backward impl.
466+     } else {
467+       if (lhs.type_flag_ != ograd.type_flag_ ) {
468+         MSHADOW_TYPE_SWITCH (ograd.type_flag_ , DType, {
469+           BROADCAST_NDIM_SWITCH (new_oshape.ndim (), NDim, {
470+             BinaryBroadcastBackwardUseInImplWithWorkspace<xpu, NDim, DType, LOP, ROP>(
471+                 ctx, {ograd, temp_tblob, rhs}, {kWriteTo , req[1 ]}, {temp_igrad, rgrad},
472+                 workspace, new_lshape, new_rshape, new_oshape);
473+           });
474+         });
475+       } else {
476+         MSHADOW_TYPE_SWITCH (ograd.type_flag_ , DType, {
477+           BROADCAST_NDIM_SWITCH (new_oshape.ndim (), NDim, {
478+             BinaryBroadcastBackwardUseInImplWithWorkspace<xpu, NDim, DType, LOP, ROP>(
479+                 ctx, {ograd, lhs, temp_tblob}, {req[0 ], kWriteTo }, {lgrad, temp_igrad},
480+                 workspace, new_lshape, new_rshape, new_oshape);
481+           });
482+         });
483+       }
484+     }
485+
486+     // If both inputs are floating numbers, cast the igrad to the input that has
487+     // the different data type
// NOTE(review): when only one input is float, temp_igrad is apparently left
// uncast — presumably the integer side's gradient is intentionally dropped
// (matching the integer-only error branch below). TODO confirm with callers.
488+     if (common::is_float (lhs.type_flag_ ) && common::is_float (rhs.type_flag_ )) {
489+       if (lhs.type_flag_ != ograd.type_flag_ ) {
490+         CastCompute<xpu>(attrs, ctx, {temp_igrad}, {req[0 ]}, {lgrad});
491+       } else {
492+         CastCompute<xpu>(attrs, ctx, {temp_igrad}, {req[1 ]}, {rgrad});
493+       }
494+     }
495+   } else {
496+     // Case where both inputs are integer types, should not even do
497+     // backward computation for this case.
498+     PrintErrorMessage (attrs.op ->name , lhs.type_flag_ , rhs.type_flag_ );
499+   }
400500}
401501
402502} // namespace op
0 commit comments