diff --git a/unyt/_array_functions.py b/unyt/_array_functions.py index 88eac339..b32d15fa 100644 --- a/unyt/_array_functions.py +++ b/unyt/_array_functions.py @@ -1223,7 +1223,7 @@ def in1d(ar1, ar2, *args, **kwargs): @implements(np.take) -def take(a, indices, axis=None, out=None, *args, **kwargs): +def take(a, indices, axis=None, out=None, mode="raise"): ret_units = getattr(a, "units", NULL_UNIT) if out is not None: @@ -1232,7 +1232,7 @@ def take(a, indices, axis=None, out=None, *args, **kwargs): out_view = None res = np.take._implementation( - np.asarray(a), indices, axis=axis, out=out_view, *args, **kwargs + np.asarray(a), indices, axis=axis, out=out_view, mode=mode ) if getattr(out, "units", None) is not None: diff --git a/unyt/array.py b/unyt/array.py index c3260030..f15d6952 100644 --- a/unyt/array.py +++ b/unyt/array.py @@ -2126,6 +2126,16 @@ def dot(self, b, out=None): return ret def take(self, *args, **kwargs): + """method + + Return an array formed from the elements of `a` at the given indices. + + Refer to :func:`numpy.take` for full documentation. + + See also + -------- + numpy.take : equivalent function + """ from ._array_functions import take return take(self, *args, **kwargs)