Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 8b58b78

Browse files
haojin2 authored and ptrendx committed
Mixed precison binary op backward (use in) for numpy (#16791)
* mixed precison binary op backward * reduce unix cpu runtime
1 parent e3e63fe commit 8b58b78

File tree

6 files changed

+162
-13
lines changed

6 files changed

+162
-13
lines changed

src/operator/numpy/np_elemwise_broadcast_op.cc

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,22 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply)
147147
"FCompute<cpu>",
148148
NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::mul>)
149149
#endif
150-
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"});
150+
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_mul"});
151+
152+
NNVM_REGISTER_OP(_backward_npi_broadcast_mul)
153+
.set_num_inputs(3)
154+
.set_num_outputs(2)
155+
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
156+
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
157+
[](const NodeAttrs& attrs){
158+
return std::vector<std::pair<int, int> >{{0, 1}};
159+
})
160+
.set_attr<FResourceRequest>("FResourceRequest",
161+
[](const NodeAttrs& attrs) {
162+
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
163+
})
164+
.set_attr<FCompute>("FCompute<cpu>", NumpyBinaryBackwardUseIn<cpu, mshadow_op::right,
165+
mshadow_op::left>);
151166

152167
MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod)
153168
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::mod>)

src/operator/numpy/np_elemwise_broadcast_op.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ NNVM_REGISTER_OP(_npi_multiply)
6464
NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::mul>);
6565
#endif
6666

67+
NNVM_REGISTER_OP(_backward_npi_broadcast_mul)
68+
.set_attr<FCompute>("FCompute<gpu>", NumpyBinaryBackwardUseIn<gpu, mshadow_op::right,
69+
mshadow_op::left>);
70+
6771
NNVM_REGISTER_OP(_npi_mod)
6872
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::mod>);
6973

src/operator/numpy/np_elemwise_broadcast_op.h

Lines changed: 102 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#ifndef MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_
2626
#define MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_
2727

28+
#include <algorithm>
2829
#include <vector>
2930
#include <string>
3031

@@ -381,11 +382,13 @@ void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs,
381382
}
382383

383384
template<typename xpu, typename LOP, typename ROP>
384-
void MixedBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
385+
void NumpyBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
385386
const OpContext& ctx,
386387
const std::vector<TBlob>& inputs,
387388
const std::vector<OpReqType>& req,
388389
const std::vector<TBlob>& outputs) {
390+
using namespace mshadow;
391+
using namespace mxnet_op;
389392
CHECK_EQ(inputs.size(), 3U);
390393
CHECK_EQ(outputs.size(), 2U);
391394

@@ -396,7 +399,104 @@ void MixedBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
396399
return;
397400
}
398401

399-
PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
402+
const TBlob& ograd = inputs[0];
403+
const TBlob& lgrad = outputs[0];
404+
const TBlob& rgrad = outputs[1];
405+
406+
if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) {
407+
// If any of the inputs is a float, it's the same type as the output
408+
// So 2 of the 3 tensors have the same data type
409+
Stream<xpu> *s = ctx.get_stream<xpu>();
410+
mxnet::TShape new_lshape, new_rshape, new_oshape;
411+
using namespace broadcast;
412+
const bool need_bc = BinaryBroadcastShapeCompact(lgrad.shape_, rgrad.shape_, ograd.shape_,
413+
&new_lshape, &new_rshape, &new_oshape) != 0;
414+
415+
// Prepare all the temporary memory
416+
size_t workspace_size_l = 0, workspace_size_r = 0;
417+
TBlob temp_tblob; // The TBlob for casted input data
418+
TBlob temp_igrad; // The TBlob for casted grad results
419+
size_t tensor_size = (lgrad.type_flag_ != ograd.type_flag_) ? lgrad.Size() : rgrad.Size();
420+
Tensor<xpu, 1, char> workspace;
421+
422+
MSHADOW_TYPE_SWITCH(ograd.type_flag_, OType, {
423+
BROADCAST_NDIM_SWITCH(new_oshape.ndim(), ndim, {
424+
workspace_size_l = ReduceWorkspaceSize<ndim, OType>(
425+
s, new_lshape, req[0], new_oshape, new_lshape, new_rshape);
426+
workspace_size_r = ReduceWorkspaceSize<ndim, OType>(
427+
s, new_rshape, req[1], new_oshape, new_lshape, new_rshape);
428+
});
429+
size_t workspace_size = std::max(workspace_size_l, workspace_size_r);
430+
size_t cast_tensor_size = tensor_size * sizeof(OType);
431+
// Allocate the temporary memories now
432+
Tensor<xpu, 1, char> temp_space =
433+
ctx.requested[0].get_space_typed<xpu, 1, char>(
434+
Shape1(workspace_size + cast_tensor_size * 2), s);
435+
// Tensor for temp_tblob
436+
Tensor<xpu, 1, OType> temp_tblob_tensor(
437+
reinterpret_cast<OType*>(temp_space.dptr_),
438+
Shape1(tensor_size), s);
439+
// Tensor for temp_igrad
440+
Tensor<xpu, 1, OType> temp_igrad_tensor(
441+
reinterpret_cast<OType*>(temp_space.dptr_) + tensor_size,
442+
Shape1(tensor_size), s);
443+
temp_tblob =
444+
TBlob(temp_tblob_tensor)
445+
.reshape(((lgrad.type_flag_ != ograd.type_flag_) ? lhs.shape_ : rhs.shape_));
446+
temp_igrad =
447+
TBlob(temp_igrad_tensor)
448+
.reshape(((lgrad.type_flag_ != ograd.type_flag_) ? lhs.shape_ : rhs.shape_));
449+
if (temp_igrad.Size() != 0) {
450+
Kernel<set_zero, xpu>::Launch(s, temp_igrad.Size(), temp_igrad.dptr<OType>());
451+
}
452+
workspace =
453+
Tensor<xpu, 1, char>(temp_space.dptr_ + 2 * cast_tensor_size, Shape1(workspace_size), s);
454+
});
455+
// Cast the input that does not have consistent type to temp_tblob
456+
CastCompute<xpu>(
457+
attrs, ctx, {((lgrad.type_flag_ != ograd.type_flag_) ? lhs : rhs)}, {kWriteTo}, {temp_tblob});
458+
if (!need_bc) {
459+
if (lhs.type_flag_ != ograd.type_flag_) {
460+
ElemwiseBinaryOp::BackwardUseIn<xpu, LOP, ROP>(
461+
attrs, ctx, {ograd, temp_tblob, rhs}, {kWriteTo, req[1]}, {temp_igrad, rgrad});
462+
} else {
463+
ElemwiseBinaryOp::BackwardUseIn<xpu, LOP, ROP>(
464+
attrs, ctx, {ograd, lhs, temp_tblob}, {req[0], kWriteTo}, {lgrad, temp_igrad});
465+
}
466+
} else {
467+
if (lhs.type_flag_ != ograd.type_flag_) {
468+
MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
469+
BROADCAST_NDIM_SWITCH(new_oshape.ndim(), NDim, {
470+
BinaryBroadcastBackwardUseInImplWithWorkspace<xpu, NDim, DType, LOP, ROP>(
471+
ctx, {ograd, temp_tblob, rhs}, {kWriteTo, req[1]}, {temp_igrad, rgrad},
472+
workspace, new_lshape, new_rshape, new_oshape);
473+
});
474+
});
475+
} else {
476+
MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
477+
BROADCAST_NDIM_SWITCH(new_oshape.ndim(), NDim, {
478+
BinaryBroadcastBackwardUseInImplWithWorkspace<xpu, NDim, DType, LOP, ROP>(
479+
ctx, {ograd, lhs, temp_tblob}, {req[0], kWriteTo}, {lgrad, temp_igrad},
480+
workspace, new_lshape, new_rshape, new_oshape);
481+
});
482+
});
483+
}
484+
}
485+
486+
// If both inputs are floating numbers, cast the igrad to the input that has
487+
// the different data type
488+
if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
489+
if (lhs.type_flag_ != ograd.type_flag_) {
490+
CastCompute<xpu>(attrs, ctx, {temp_igrad}, {req[0]}, {lgrad});
491+
} else {
492+
CastCompute<xpu>(attrs, ctx, {temp_igrad}, {req[1]}, {rgrad});
493+
}
494+
}
495+
} else {
496+
// Case where both inputs are integer types, should not even do
497+
// backward computation for this case.
498+
PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
499+
}
400500
}
401501

402502
} // namespace op

src/operator/tensor/elemwise_binary_broadcast_op.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,32 @@ BinaryBroadcastBackwardUseNone(const nnvm::NodeAttrs& attrs,
671671
const std::vector<OpReqType>& req,
672672
const std::vector<TBlob>& outputs);
673673

674+
template<typename xpu, int ndim, typename DType, typename LOP, typename ROP>
675+
void BinaryBroadcastBackwardUseInImplWithWorkspace(const OpContext& ctx,
676+
const std::vector<TBlob>& inputs,
677+
const std::vector<OpReqType>& req,
678+
const std::vector<TBlob>& outputs,
679+
const mshadow::Tensor<xpu, 1, char>& workspace,
680+
const mxnet::TShape& new_lshape,
681+
const mxnet::TShape& new_rshape,
682+
const mxnet::TShape& new_oshape) {
683+
using namespace mshadow;
684+
using namespace mshadow::expr;
685+
using namespace broadcast;
686+
Stream<xpu> *s = ctx.get_stream<xpu>();
687+
const TBlob lgrad = outputs[0].reshape(new_lshape);
688+
const TBlob rgrad = outputs[1].reshape(new_rshape);
689+
const TBlob ograd = inputs[0].reshape(new_oshape);
690+
const TBlob lhs = inputs[1].reshape(new_lshape);
691+
const TBlob rhs = inputs[2].reshape(new_rshape);
692+
if (ograd.Size() != 0) {
693+
Reduce<red::sum, ndim, DType, op::mshadow_op::mul, LOP>(s, lgrad, req[0], workspace,
694+
ograd, lhs, rhs);
695+
Reduce<red::sum, ndim, DType, op::mshadow_op::mul, ROP>(s, rgrad, req[1], workspace,
696+
ograd, lhs, rhs);
697+
}
698+
}
699+
674700
template<typename xpu, int ndim, typename DType, typename LOP, typename ROP>
675701
inline void BinaryBroadcastBackwardUseInImpl(const OpContext& ctx,
676702
const std::vector<TBlob>& inputs,

src/operator/tensor/elemwise_unary_op.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -453,8 +453,8 @@ void CastCompute(const nnvm::NodeAttrs& attrs,
453453
Tensor<xpu, 1, DstDType> out = outputs[0].FlatTo1D<xpu, DstDType>(s);
454454
MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, SrcDType, {
455455
Tensor<xpu, 1, SrcDType> data = inputs[0].FlatTo1D<xpu, SrcDType>(s);
456-
if (outputs[0].type_flag_ != inputs[0].type_flag_ ||
457-
req[0] != kWriteInplace) {
456+
if ((outputs[0].type_flag_ != inputs[0].type_flag_ ||
457+
req[0] != kWriteInplace) && outputs[0].Size() != 0) {
458458
Assign(out, req[0], tcast<DstDType>(data));
459459
}
460460
});

tests/python/unittest/test_numpy_op.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1683,7 +1683,9 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
16831683
@with_seed()
16841684
@use_np
16851685
def test_np_mixed_precision_binary_funcs():
1686-
def check_mixed_precision_binary_func(func, low, high, lshape, rshape, ltype, rtype):
1686+
itypes = [np.bool, np.int8, np.int32, np.int64]
1687+
ftypes = [np.float16, np.float32, np.float64]
1688+
def check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, ltype, rtype):
16871689
class TestMixedBinary(HybridBlock):
16881690
def __init__(self, func):
16891691
super(TestMixedBinary, self).__init__()
@@ -1717,13 +1719,15 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
17171719
use_broadcast=False, equal_nan=True)
17181720

17191721
funcs = {
1720-
'add': (-1.0, 1.0),
1721-
'subtract': (-1.0, 1.0),
1722-
'multiply': (-1.0, 1.0),
1722+
'add': (-1.0, 1.0, None, None),
1723+
'subtract': (-1.0, 1.0, None, None),
1724+
'multiply': (-1.0, 1.0, lambda y, x1, x2: _np.broadcast_to(x2, y.shape),
1725+
lambda y, x1, x2: _np.broadcast_to(x1, y.shape))
17231726
}
17241727

17251728
shape_pairs = [((3, 2), (3, 2)),
17261729
((3, 2), (3, 1)),
1730+
((3, 0), (3, 0)),
17271731
((3, 1), (3, 0)),
17281732
((0, 2), (1, 2)),
17291733
((2, 3, 4), (3, 1)),
@@ -1733,16 +1737,16 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
17331737
itypes = [np.bool, np.int8, np.int32, np.int64]
17341738
ftypes = [np.float16, np.float32, np.float64]
17351739
for func, func_data in funcs.items():
1736-
low, high = func_data
1740+
low, high, lgrad, rgrad = func_data
17371741
for lshape, rshape in shape_pairs:
17381742
for type1, type2 in itertools.product(itypes, ftypes):
1739-
check_mixed_precision_binary_func(func, low, high, lshape, rshape, type1, type2)
1740-
check_mixed_precision_binary_func(func, low, high, lshape, rshape, type2, type1)
1743+
check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2)
1744+
check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type2, type1)
17411745

17421746
for type1, type2 in itertools.product(ftypes, ftypes):
17431747
if type1 == type2:
17441748
continue
1745-
check_mixed_precision_binary_func(func, low, high, lshape, rshape, type1, type2)
1749+
check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2)
17461750

17471751

17481752
@with_seed()

0 commit comments

Comments (0)