-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Conversation
python/mxnet/ndarray/numpy/_op.py
Outdated
@@ -39,8 +39,7 @@ | |||
'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', | |||
'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', | |||
'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff', 'resize', | |||
'nan_to_num'] | |||
|
|||
'nan_to_num', 'diagflat'] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
move this to other places to avoid conflicts, probably you can move to after full
.
Remember to move the function definitions too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same for all other files.
python/mxnet/ndarray/numpy/_op.py
Outdated
def diagflat(arr, k=0): | ||
""" | ||
Create a two-dimensional array with the flattened input as a diagonal. | ||
Parameters: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Parameters
----------
python/mxnet/ndarray/numpy/_op.py
Outdated
|
||
Diagonal to set; 0, the default, corresponds to the "main" diagonal, a positive (negative) k giving the number of the diagonal above (below) the main. | ||
|
||
Returns: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Returns
-------
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For docs you can almost copy from : https://github.com/numpy/numpy/blob/v1.17.0/numpy/lib/twodim_base.py#L288-L345.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For docs you can almost copy from : https://github.com/numpy/numpy/blob/v1.17.0/numpy/lib/twodim_base.py#L288-L345.
get
python/mxnet/symbol/numpy/_symbol.py
Outdated
[0,0,0,0,0,0,0,8]] | ||
""" | ||
return _npi.diagflat(arr, k=k) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
one less blank line.
@@ -0,0 +1,213 @@ | |||
/* |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
put those in np_matrix_op
files
struct NumpyDiagflatParam : public dmlc::Parameter<NumpyDiagflatParam> { | ||
int k; | ||
DMLC_DECLARE_PARAMETER(NumpyDiagflatParam) { | ||
DMLC_DECLARE_FIELD(k) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2-space indents
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same for everywhere else.
|
||
|
||
template<typename xpu, bool back> | ||
void NumpyDiagflatOpProcess(const TBlob& in_data, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NumpyDiagflatOpImpl
src/operator/numpy/np_diagflat_op.cc
Outdated
DMLC_REGISTER_PARAMETER(NumpyDiagflatParam); | ||
|
||
NNVM_REGISTER_OP(_npi_diagflat) | ||
.describe(R"code(contructs a diagonal array. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need for doc here.
|
||
NumpyDiagflatOpProcess<xpu, true>(in_data, out_data, oshape, ishape, in_data.Size(), param, s, req); | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1 less blank line here.
python/mxnet/ndarray/numpy/_op.py
Outdated
Create a two-dimensional array with the flattened input as a diagonal. | ||
Parameters: | ||
|
||
v : array_like |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
v : array_like | |
arr : ndarray |
or you can fall back to numpy when arr is not ndarray type
python/mxnet/ndarray/numpy/_op.py
Outdated
Input data, which is flattened and set as the k-th diagonal of the output. | ||
k : int, optional | ||
|
||
Diagonal to set; 0, the default, corresponds to the "main" diagonal, a positive (negative) k giving the number of the diagonal above (below) the main. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
too long
python/mxnet/symbol/numpy/_symbol.py
Outdated
@@ -4558,6 +4558,61 @@ def rot90(m, k=1, axes=(0, 1)): | |||
return _npi.rot90(m, k=k, axes=axes) | |||
|
|||
|
|||
@set_module('mxnet.ndarray.numpy') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@set_module('mxnet.ndarray.numpy') | |
@set_module('mxnet.symbol.numpy') |
python/mxnet/symbol/numpy/_symbol.py
Outdated
|
||
Returns: | ||
|
||
out : ndarray |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
out : ndarray | |
out : _Symbol |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same for arr
vals_f = _np.array((100 * get_mat(5) + 1), order = 'F', dtype = 'l') | ||
vals_f = np.array(vals_f) | ||
|
||
OpArgMngr.add_workload('diagflat', A, k= 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
get rid of all spaces after the =
266b9aa
to
73fa4d4
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, will merge as soon as CI passes
Description
(Brief description on what this PR is about)
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments