Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 295fc14

Browse files
ckt624reminisce
authored andcommitted
Numpy det and slogdet operators (#15861)
* Add alias. Add tests. Add slogdet tests. Add docs Change shapes Change tests. Change slogdet tests Change style. * Fix * Fix
1 parent 2df3282 commit 295fc14

File tree

5 files changed

+332
-106
lines changed

5 files changed

+332
-106
lines changed

python/mxnet/_numpy_op_doc.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,128 @@
2020
"""Doc placeholder for numpy ops with prefix _np."""
2121

2222

23+
def _np__linalg_det(a):
24+
"""
25+
det(a)
26+
27+
Compute the determinant of an array.
28+
29+
Parameters
30+
----------
31+
a : (..., M, M) ndarray
32+
Input array to compute determinants for.
33+
34+
Returns
35+
-------
36+
det : (...) ndarray
37+
Determinant of `a`.
38+
39+
See Also
40+
--------
41+
slogdet : Another way to represent the determinant, more suitable
42+
for large matrices where underflow/overflow may occur.
43+
44+
Notes
45+
-----
46+
47+
Broadcasting rules apply, see the `numpy.linalg` documentation for
48+
details.
49+
50+
The determinant is computed via LU factorization using the LAPACK
51+
routine z/dgetrf.
52+
53+
Examples
54+
--------
55+
The determinant of a 2-D array [[a, b], [c, d]] is ad - bc:
56+
57+
>>> a = np.array([[1, 2], [3, 4]])
58+
>>> np.linalg.det(a)
59+
-2.0
60+
61+
Computing determinants for a stack of matrices:
62+
63+
>>> a = np.array([ [[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]] ])
64+
>>> a.shape
65+
(3, 2, 2)
66+
>>> np.linalg.det(a)
67+
array([-2., -3., -8.])
68+
"""
69+
pass
70+
71+
72+
def _np__linalg_slogdet(a):
73+
"""
74+
slogdet(a)
75+
76+
Compute the sign and (natural) logarithm of the determinant of an array.
77+
78+
If an array has a very small or very large determinant, then a call to
79+
`det` may overflow or underflow. This routine is more robust against such
80+
issues, because it computes the logarithm of the determinant rather than
81+
the determinant itself.
82+
83+
Parameters
84+
----------
85+
a : (..., M, M) ndarray
86+
Input array, has to be a square 2-D array.
87+
88+
Returns
89+
-------
90+
sign : (...) ndarray
91+
A number representing the sign of the determinant. For a real matrix,
92+
this is 1, 0, or -1.
93+
logdet : (...) array_like
94+
The natural log of the absolute value of the determinant.
95+
96+
If the determinant is zero, then `sign` will be 0 and `logdet` will be
97+
-Inf. In all cases, the determinant is equal to ``sign * np.exp(logdet)``.
98+
99+
See Also
100+
--------
101+
det
102+
103+
Notes
104+
-----
105+
106+
Broadcasting rules apply, see the `numpy.linalg` documentation for
107+
details.
108+
109+
The determinant is computed via LU factorization using the LAPACK
110+
routine z/dgetrf.
111+
112+
113+
Examples
114+
--------
115+
The determinant of a 2-D array ``[[a, b], [c, d]]`` is ``ad - bc``:
116+
117+
>>> a = np.array([[1, 2], [3, 4]])
118+
>>> (sign, logdet) = np.linalg.slogdet(a)
119+
>>> (sign, logdet)
120+
(-1., 0.69314718055994529)
121+
>>> sign * np.exp(logdet)
122+
-2.0
123+
124+
Computing log-determinants for a stack of matrices:
125+
126+
>>> a = np.array([ [[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]] ])
127+
>>> a.shape
128+
(3, 2, 2)
129+
>>> sign, logdet = np.linalg.slogdet(a)
130+
>>> (sign, logdet)
131+
(array([-1., -1., -1.]), array([ 0.69314718, 1.09861229, 2.07944154]))
132+
>>> sign * np.exp(logdet)
133+
array([-2., -3., -8.])
134+
135+
This routine succeeds where ordinary `det` does not:
136+
137+
>>> np.linalg.det(np.eye(500) * 0.1)
138+
0.0
139+
>>> np.linalg.slogdet(np.eye(500) * 0.1)
140+
(1., -1151.2925464970228)
141+
"""
142+
pass
143+
144+
23145
def _np_ones_like(a):
24146
"""
25147
Return an array of ones with the same shape and type as a given array.

src/operator/tensor/la_op.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -941,6 +941,7 @@ NNVM_REGISTER_OP(_backward_linalg_inverse)
941941

942942
NNVM_REGISTER_OP(_linalg_det)
943943
.add_alias("linalg_det")
944+
.add_alias("_np__linalg_det")
944945
.describe(R"code(Compute the determinant of a matrix.
945946
Input is a tensor *A* of dimension *n >= 2*.
946947
@@ -991,6 +992,7 @@ NNVM_REGISTER_OP(_backward_linalg_det)
991992
.set_attr<FCompute>("FCompute<cpu>", LaOpDetBackward<cpu, 1, det_backward>);
992993

993994
NNVM_REGISTER_OP(_linalg_slogdet)
995+
.add_alias("_np__linalg_slogdet")
994996
.add_alias("linalg_slogdet")
995997
.describe(R"code(Compute the sign and log of the determinant of a matrix.
996998
Input is a tensor *A* of dimension *n >= 2*.

src/operator/tensor/la_op.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,14 @@ NNVM_REGISTER_OP(_backward_linalg_inverse)
100100
.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 2, 1, inverse_backward>);
101101

102102
NNVM_REGISTER_OP(_linalg_det)
103+
.add_alias("_np__linalg_det")
103104
.set_attr<FCompute>("FCompute<gpu>", LaOpDetForward<gpu, 1, det>);
104105

105106
NNVM_REGISTER_OP(_backward_linalg_det)
106107
.set_attr<FCompute>("FCompute<gpu>", LaOpDetBackward<gpu, 1, det_backward>);
107108

108109
NNVM_REGISTER_OP(_linalg_slogdet)
110+
.add_alias("_np__linalg_slogdet")
109111
.set_attr<FCompute>("FCompute<gpu>", LaOpDetForward<gpu, 2, slogdet>);
110112

111113
NNVM_REGISTER_OP(_backward_linalg_slogdet)

src/operator/tensor/la_op.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#define MXNET_OPERATOR_TENSOR_LA_OP_H_
2727

2828
#include <mxnet/operator_util.h>
29+
#include <mxnet/imperative.h>
2930
#include <vector>
3031
#include <algorithm>
3132
#include "../mshadow_op.h"
@@ -428,7 +429,11 @@ inline bool DetShape(const nnvm::NodeAttrs& attrs,
428429
CHECK_EQ(in[ndim-2], in[ndim-1]) << "Input A's last two dimension must be equal";
429430
mxnet::TShape out;
430431
if (ndim == 2) {
431-
out = mxnet::TShape(1, 1);
432+
if (Imperative::Get()->is_np_shape()) {
433+
out = mxnet::TShape(0, 1);
434+
} else {
435+
out = mxnet::TShape(1, 1);
436+
}
432437
} else {
433438
out = mxnet::TShape(in.begin(), in.end() - 2);
434439
}

0 commit comments

Comments
 (0)