Fix dtype inference in arange_like operator #15930
eric-haibin-lin merged 4 commits into apache:master
eric-haibin-lin
left a comment
Great. This will fix the bug with float16 inputs. Would you mind just adding a simple unit test to check the dtype with float16 inputs?

@eric-haibin-lin Did you observe any crash with fp16 input? With the code snippet below, it doesn't seem to crash but just gives numpy.float32 output:

    import mxnet as mx
    import numpy as np

    x = mx.sym.Variable('x', dtype=np.float16)
    y = mx.sym.reshape(x, shape=(0, 0, -1))
    z = mx.sym.contrib.arange_like(y, axis=-1)
    # No gradients are needed for a forward-only check.
    mod = z.simple_bind(ctx=mx.gpu(0), x=(3, 4, 5, 6), grad_req='null')
    mod.arg_arrays[0][:] = np.random.normal(size=mod.arg_arrays[0].shape).astype(np.float16)
    out = mod.forward(is_train=False)
    print(out[0].dtype)

No, I didn't expect a crash. I expect it to copy the dtype attribute, like the other xx_like ops.

@eric-haibin-lin do you think the code snippet below can be used as a test case?

    import mxnet as mx
    import numpy as np

    dtypes = [np.float16, np.float32, np.float64]
    for t in dtypes:
        # Build the same graph for each input dtype.
        x = mx.sym.Variable('x', dtype=t)
        y = mx.sym.reshape(x, shape=(0, 0, -1))
        z = mx.sym.contrib.arange_like(y, axis=-1)
        mod = z.simple_bind(ctx=mx.gpu(0), x=(3, 4, 5, 6), grad_req='null')
        mod.arg_arrays[0][:] = np.random.normal(size=mod.arg_arrays[0].shape).astype(t)
        out = mod.forward(is_train=False)
        assert out[0].dtype == np.float32

Yes. Could you also check the forward output against [0, 1, 2, ...], etc.?

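A value check of that kind needs a reference for what the operator should return. The sketch below computes one in plain Python so it runs without MXNet. It assumes that `arange_like` with an `axis` argument returns a 1-D sequence whose length equals the size of that axis, and that `reshape` is called only with the `0` (copy dimension) and `-1` (infer dimension) codes used in the snippets above. The helper names `resolve_reshape` and `expected_arange_like` are hypothetical and not part of MXNet.

```python
def resolve_reshape(in_shape, spec):
    """Resolve a reshape spec that uses only 0 (copy dim) and -1 (infer dim)."""
    out = []
    infer_at = None
    for i, s in enumerate(spec):
        if s == 0:
            out.append(in_shape[i])      # copy the input dimension
        elif s == -1:
            if infer_at is not None:
                raise ValueError("only one -1 is allowed in the spec")
            infer_at = len(out)
            out.append(1)                # placeholder, filled in below
        else:
            out.append(s)
    total = 1
    for d in in_shape:
        total *= d
    if infer_at is not None:
        known = 1
        for d in out:
            known *= d
        out[infer_at] = total // known   # remaining elements go to the -1 slot
    return tuple(out)


def expected_arange_like(shape, axis, start=0.0, step=1.0):
    """Reference values: start, start + step, ... along the chosen axis."""
    n = shape[axis]
    return [start + step * i for i in range(n)]


# Same shapes as in the snippets above: x is (3, 4, 5, 6), reshaped with (0, 0, -1).
reshaped = resolve_reshape((3, 4, 5, 6), (0, 0, -1))
expected = expected_arange_like(reshaped, axis=-1)
```

In an actual unit test, the operator output (for example `out[0].asnumpy()`) could be compared against `expected`, alongside the dtype assertion.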
I would like to keep the dtype attribute, with a default behavior when dtype is None.
Just want to provide the same user experience for
Description
Remove the dtype argument from parameter structure and use ElemwiseType instead.
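To illustrate the idea behind element-wise type inference, here is a simplified Python model. It is not MXNet's C++ implementation; the function name `elemwise_type` and the `UNKNOWN` marker are assumptions made for this sketch. The rule it models is that all inputs and outputs of the operator share one dtype, so a known input dtype (such as float16) propagates to the output instead of being overridden by an operator parameter.

```python
UNKNOWN = -1  # placeholder for "dtype not inferred yet" (hypothetical marker)


def elemwise_type(in_types, out_types):
    """Unify input and output dtypes.

    Returns (in_types, out_types, fully_known). Raises TypeError if two
    different dtypes are already fixed, since they cannot be unified.
    """
    known = {t for t in list(in_types) + list(out_types) if t != UNKNOWN}
    if len(known) > 1:
        raise TypeError("conflicting dtypes: %s" % sorted(known))
    if not known:
        # Nothing to propagate yet; inference has to be retried later.
        return list(in_types), list(out_types), False
    dtype = known.pop()
    return [dtype] * len(in_types), [dtype] * len(out_types), True


# A float16 input with an unknown output dtype: the output becomes float16.
ins, outs, done = elemwise_type(["float16"], [UNKNOWN])
```

Under this rule the output dtype follows the input, which matches the behavior the reviewer expected from the other `xx_like` operators.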