[MXNET-1426] Fix the wrong result of sum, mean, argmin, argmax when inputs contain inf or nan#16234
[MXNET-1426] Fix the wrong result of sum, mean, argmin, argmax when inputs contain inf or nan#16234wkcn merged 21 commits intoapache:masterfrom
Conversation
iblislin
left a comment
There was a problem hiding this comment.
The Julia part looks fine for me.
|
Hi @marcoabreu @access2rohit , could you please help take a review? |
|
Hi @reminisce and @haojin2 , could you please help take a review? This PR makes the following functions consistent with NumPy. Thank you! |
|
Hi @eric-haibin-lin , could you please help take a review? Thank you so much! |
|
Would you mind also add what is the result before this fix? |
|
Hi @eric-haibin-lin , I have updated the test result : ) |
|
cc @reminisce |
|
Ping : ) |
3rdparty/mshadow/mshadow/base.h
Outdated
| } | ||
| template<> | ||
| MSHADOW_XINLINE bool IsNan(volatile mshadow::half::half_t val) { | ||
| return (val.half_ & 0x7fff) > 0x7c00; |
There was a problem hiding this comment.
Can you turn these magic values into constants with documentation? While I get 0x7ffff, 0x7c00 for example, looks quite arbitrary.
There was a problem hiding this comment.
Hi @marcoabreu , I add two constants MSHADOW_HALF_SIGN_BIT and MSHADOW_HALF_EXPONENT_BITS in 3rdparty/mshadow/mshadow/half.h, and replace these two magic values.
| %(ndarray_ret.shape, numpy_ret.shape) | ||
| err = np.square(ndarray_ret - numpy_ret).mean() | ||
| assert err < 1E-4 | ||
| if check_dtype: |
There was a problem hiding this comment.
Could you elaborate why you're introducing so much branching into a test? If the results are inconsistent, we should rather improve the test instead of skipping the checks. I'd love to have more detail
There was a problem hiding this comment.
Hi @marcoabreu , here is the explanation.
-
So much branching
We need to test all reduce operators, likemin, max, argmin, argmax, sum, meanwhen the inputs contain-inf, +inf, nan. -
Skipping the checks
I replace the old check with a new one. : )
|
I'll merge after the feedback has been addressed :) Sorry for the delay |
Description
Hi, there.
I fix the wrong result of sum(inf, inf) and mean(inf, inf).
Test Case:
The wrong result in
mxnet_mkl-1.6.0b20191015-py2.py3-none-manylinux1_x86_64The correct result in this PR:
If we modify the
testfunction,Here is the result of NumPy.
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
isnan_typedandisinf_typedin mshadowisnan_typedinsrc/operator/mshadow_op.hmshadow/extension/reduce_with_axis.hto support NaNargminandargmax.