Skip to content

Commit 3eb3ac6

Browse files
committed
update test
1 parent a28fa53 commit 3eb3ac6

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

tests/python/unittest/test_operator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4332,7 +4332,9 @@ def check(data, idx):
43324332
data = mx.nd.array([212316236123621, -31231236374787,
43334333
-112372937128970, -13782787981728], dtype=dtype)
43344334
idx = mx.nd.array([[0, 0, 0, 0]], dtype='int32')
4335-
assert (mx.nd.scatter_nd_acc(data, idx, shape=(1,)).asnumpy()[0] == data.asnumpy().sum())
4335+
scatter_nd_ret = mx.nd.scatter_nd_acc(data, idx, shape=(1,)).asscalar()
4336+
npy_ret = data.asnumpy().sum()
4337+
assert (scatter_nd_ret == npy_ret), "scatter_nd_acc={}, npy={}".format(scatter_nd_ret, npy_ret)
43364338

43374339
def compare_forw_backw_unary_op(
43384340
name, forward_mxnet_call, forward_numpy_call,

0 commit comments

Comments
 (0)