Conversation
CC @tqchen if you have bandwidth
Force-pushed from 9fc3389 to ad97a7e
contrib/tvmop/basic/ufunc.py (Outdated)

    return b, c

    def reduce_axes(X, axes, reducer):
Can we add some comments to elaborate on the idea, e.g., the meaning of axes? Also, can we move it somewhere else so that other operators can reuse it?
Yes. Added in ufunc.py
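For readers following along, here is a minimal NumPy sketch of what a `reduce_axes` helper of this shape might do. The function name and parameters come from the quoted diff, but the exact semantics of `axes` are an assumption (here, a per-dimension reduce flag), not the PR's actual TVM code:

```python
import numpy as np

def reduce_axes(X, axes, reducer=np.add):
    # Assumed semantics: `axes` holds one flag per dimension of X;
    # a True flag means that dimension is reduced (collapsed to size 1).
    out = X
    for dim, is_reduced in enumerate(axes):
        if is_reduced:
            out = reducer.reduce(out, axis=dim, keepdims=True)
    return out

x = np.arange(24).reshape(2, 3, 4)
y = reduce_axes(x, [False, True, False])  # reduce only axis 1
```

Keeping the reduced dimensions as size 1 (rather than dropping them) makes the output broadcast-compatible with the input, which is convenient for gradient computations.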
contrib/tvmop/basic/ufunc.py (Outdated)

    return s, [A, B, C]

    def assign_by_req(a, req):
Shall we use the existing contrib/tvmop/utils.py or create a contrib/tvmop/basic/common.py?
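For context, the `req` dispatch being factored out mirrors MXNet's write-request convention (`kWriteTo` overwrites the destination, `kAddTo` accumulates into it). A hypothetical NumPy sketch of that behavior, not the PR's actual TVM implementation:

```python
import numpy as np

# Assumed mirror of MXNet's OpReqType dispatch:
# kWriteTo overwrites the destination, kAddTo accumulates into it.
K_WRITE_TO = "kWriteTo"
K_ADD_TO = "kAddTo"

def assign_by_req(dst, src, req):
    if req == K_WRITE_TO:
        dst[...] = src       # overwrite destination in place
    elif req == K_ADD_TO:
        dst[...] += src      # accumulate into destination
    return dst

a = np.ones(3)
assign_by_req(a, np.full(3, 2.0), K_ADD_TO)    # a is now [3., 3., 3.]
b = np.zeros(2)
assign_by_req(b, np.ones(2), K_WRITE_TO)       # b is now [1., 1.]
```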
src/operator/contrib/tvmop/ufunc.cc (Outdated)

    funcname += "req_";
    MXNET_ASSIGN_REQ_SWITCH(req[k], req_type, {
      if (req_type == kWriteTo) {
        funcname += "kWriteTo";
src/operator/contrib/tvmop/ufunc.cc (Outdated)

    // dispatch by backward
    std::vector<int> ov, iv;
    const TBlob& ograd = inputs[0], igrad = outputs[k];
    bool flag = ograd.size(0) != igrad.size(0);
Better to use an int and explicitly assign the value.
What about expanding it into an if-else?
src/operator/contrib/tvmop/ufunc.cc (Outdated)

    }
    TShape oshape(ov.begin(), ov.end()), ishape(iv.begin(), iv.end());
    TBlob ograd_tvm(ograd.reshape(oshape).dltensor());
    TBlob igrad_tvm(igrad.reshape(ishape).dltensor());
Please add some comments to elaborate on the idea.
    std::vector<int> ov, iv;
    const TBlob& ograd = inputs[0], igrad = outputs[k];
    bool flag = ograd.size(0) != igrad.size(0);
    for (int i = 0; i < ndim; ++i) {
If my understanding is correct, there seems to be an assumption that ograd.ndim == igrad.ndim, which is not necessarily true. I think you need to prepend axes to igrad if igrad.ndim < ograd.ndim and then use the logic here.
Yes, igrad.ndim == ograd.ndim is assumed.
@yzhliu suggests padding the input to 5 dims, the largest possible dim supported by this op. The padding will 1) reduce the number of kernels (by a factor of 5) and 2) handle the igrad.ndim < ograd.ndim issue, but there may be a loss in performance.
I think prepending axes to igrad to bring it up to ograd.ndim requires more kernels, but the performance is better. It is a tradeoff.
Please correct me if my understanding is wrong, but don't you still need kernels generated for ndim < 5, since you collapse consecutive dimensions over which the reduction is performed? For example, given a 5-d shape (2, 3, 4, 5, 6) with reduction on axes (1, 2), the tblob is first reshaped into (2, 12, 30) and then reduced on axis=1. In this case, do you need a kernel generated for 3-d shapes?
I think we can pad the shape after dimension collapse. In this case, the tblob will be reshaped into (2, 12, 30, 1, 1) and then reduced on axes [1, 3].
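The collapse-then-pad scheme discussed above can be checked numerically. This NumPy sketch reproduces the example from the thread: reducing a (2, 3, 4, 5, 6) tensor over axes (1, 2) directly, via the collapsed (2, 12, 30) shape, and via the padded (2, 12, 30, 1, 1) shape all yield the same values:

```python
import numpy as np

x = np.arange(2 * 3 * 4 * 5 * 6, dtype=np.float64).reshape(2, 3, 4, 5, 6)

# Direct reduction over axes (1, 2): result has shape (2, 5, 6).
direct = x.sum(axis=(1, 2))

# Collapse consecutive reduce/keep runs: (3, 4) -> 12, (5, 6) -> 30,
# then reduce on axis 1 of the 3-d view.
collapsed = x.reshape(2, 12, 30).sum(axis=1)            # shape (2, 30)

# Pad with trailing size-1 axes to 5 dims so a single 5-d kernel
# can serve this case, reducing on axes (1, 3) as in the thread.
padded = x.reshape(2, 12, 30, 1, 1).sum(axis=(1, 3))    # shape (2, 30, 1)
```

All three paths agree element-wise once flattened, which is why a single maximum-dimension kernel can cover the lower-dimensional cases.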
I see. I am in favor of the approach with fewer kernels generated. We can revisit the performance concern if that turns out to be an issue.
I pushed a new version, where the inputs and outputs are padded to 5 dims.
Description
Use TVM to implement vadd backward with broadcast.
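As background for reviewers: the backward pass of a broadcast add reduces the output gradient over the broadcast axes of each input. A NumPy sketch of the assumed semantics (the function name and signature here are illustrative, not the PR's TVM code):

```python
import numpy as np

def vadd_backward(ograd, ishape):
    # Gradient of broadcast add w.r.t. an input of shape `ishape`:
    # sum ograd over every axis introduced or expanded by broadcasting.
    igrad = ograd
    # Sum away axes that broadcasting prepended to the input.
    for _ in range(ograd.ndim - len(ishape)):
        igrad = igrad.sum(axis=0)
    # Sum (keeping dims) over axes where the input had size 1.
    for ax, size in enumerate(ishape):
        if size == 1 and igrad.shape[ax] != 1:
            igrad = igrad.sum(axis=ax, keepdims=True)
    return igrad

og = np.ones((2, 3, 4))
g1 = vadd_backward(og, (3, 4))      # prepended axis summed away
g2 = vadd_backward(og, (2, 1, 4))   # size-1 axis summed with keepdims
```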
Changes
Comments
I think the code for these may be further reused in the future. It would be great if we could have a consistent interface for TVM ops that hides the dispatch of things like req and backward.
Thanks to @yzhliu and @junrushao1994 for the brilliant "compressed bit string" idea.